Сравните решатели LDA

В этом примере показано, как сравнить решатели скрытого выделения Дирихле (LDA) путем сравнения качества подгонки и время, потраченное, чтобы подбирать модель.

Импортируйте текстовые данные

Импортируйте набор кратких обзоров и подписей категорий из математических бумаг с помощью arXiv API. Задайте количество записей, чтобы импортировать использование importSize переменная.

importSize = 50000;

Создайте URL, который запрашивает записи с набором "math" и метаданные снабжают префиксом "arXiv".

url = "https://export.arxiv.org/oai2?verb=ListRecords" + ...
    "&set=math" + ...
    "&metadataPrefix=arXiv";

Извлеките абстрактный текст и лексему возобновления, возвращенную запросом URL с помощью parseArXivRecords функция, которая присоединена к этому примеру как к вспомогательному файлу. Чтобы получить доступ к этому файлу, откройте этот пример как live скрипт. Обратите внимание на то, что arXiv API является ограниченным уровнем и требует ожидания между несколькими запросами.

[textData,~,resumptionToken] = parseArXivRecords(url);

Итеративно импортируйте больше фрагментов записей, пока необходимое количество не достигнуто, или больше нет записей. Чтобы продолжить импортировать записи из того, где вы кончили, используйте лексему возобновления от предыдущего результата в запросе URL. Чтобы придерживаться ограничений скорости, наложенных arXiv API, добавьте задержку 20 секунд перед каждым запросом с помощью pause функция.

while numel(textData) < importSize
    
    if resumptionToken == ""
        break
    end
    
    url = "https://export.arxiv.org/oai2?verb=ListRecords" + ...
        "&resumptionToken=" + resumptionToken;
    
    pause(20)
    [textDataNew,labelsNew,resumptionToken] = parseArXivRecords(url);
    
    textData = [textData; textDataNew];
end

Предварительно обработайте текстовые данные

Отложите 10% документов наугад для валидации.

numDocuments = numel(textData);
cvp = cvpartition(numDocuments,'HoldOut',0.1);
textDataTrain = textData(training(cvp));
textDataValidation = textData(test(cvp));

Маркируйте и предварительно обработайте текстовые данные с помощью функционального preprocessText который перечислен в конце этого примера.

documentsTrain = preprocessText(textDataTrain);
documentsValidation = preprocessText(textDataValidation);

Создайте модель сумки слов из учебных материалов. Удалите слова, которые не появляются больше чем два раза всего. Удалите любые документы, содержащие слова.

bag = bagOfWords(documentsTrain);
bag = removeInfrequentWords(bag,2);
bag = removeEmptyDocuments(bag);

Для данных о валидации создайте модель сумки слов из документов валидации. Вы не должны удалять слова из validaiton данных, потому что любые слова, которые не появляются в подбиравших моделях LDA, автоматически проигнорированы.

validationData = bagOfWords(documentsValidation);

Соответствуйте и сравните модели

Для каждого из решателей LDA подберите модель с 40 темами. Чтобы отличить решатели при графическом выводе результатов на тех же осях, задайте различные свойства линии для каждого решателя.

numTopics = 40;
solvers = ["cgs" "avb" "cvb0" "savb"];
lineSpecs = ["+-" "*-" "x-" "o-"];

Подбирайте модель LDA с помощью каждого решателя. Для каждого решателя задайте начальную концентрацию темы 1, чтобы подтвердить модель однажды на передачу данных и не соответствовать параметру концентрации темы. Используя данные в FitInfo свойство подбиравших моделей LDA, постройте недоумение валидации, и время протекло.

Стохастический решатель, по умолчанию, использует мини-пакетный размер 1 000 и подтверждает модель каждые 10 итераций. Для этого решателя, чтобы подтвердить модель однажды на передачу данных, устанавливают частоту валидации на ceil(numObservations/1000), где numObservations количество документов в обучающих данных. Для других решателей, набор частота валидации к 1.

Для итераций, что стохастический решатель не оценивает недоумение валидации, стохастический решатель сообщает о NaN в FitInfo свойство. Чтобы построить недоумение валидации, удалите NaNs из значений, о которых сообщают.

numObservations = bag.NumDocuments;

figure
for i = 1:numel(solvers)
    solver = solvers(i);
    lineSpec = lineSpecs(i);

    if solver == "savb"
        numIterationsPerDataPass = ceil(numObservations/1000);
    else
        numIterationsPerDataPass = 1;
    end

    mdl = fitlda(bag,numTopics, ...
        'Solver',solver, ...
        'InitialTopicConcentration',1, ...
        'FitTopicConcentration',false, ...
        'ValidationData',validationData, ...
        'ValidationFrequency',numIterationsPerDataPass, ...
        'Verbose',0);

    history = mdl.FitInfo.History;

    timeElapsed = history.TimeSinceStart;

    validationPerplexity = history.ValidationPerplexity;

    % Remove NaNs.
    idx = isnan(validationPerplexity);
    timeElapsed(idx) = [];
    validationPerplexity(idx) = [];

    plot(timeElapsed,validationPerplexity,lineSpec)
    hold on
end

hold off
xlabel("Time Elapsed (s)")
ylabel("Validation Perplexity")
ylim([0 inf])
legend(solvers)

Для стохастического решателя существует только одна точка данных. Это вызвано тем, что этот решатель проходит через входные данные однажды. Чтобы задать больше передач данных, используйте 'DataPassLimit' опция. Для пакетных решателей ("cgs", "avb", и "cvb0"), задавать количество итераций раньше подбирало модели, использовало 'IterationLimit' опция.

Более низкое недоумение валидации предлагает лучшую подгонку. Обычно, решатели "savb" и "cgs" сходитесь быстро к хорошей подгонке. Решатель "cvb0" может сходиться к лучшей подгонке, но она может взять намного дольше, чтобы сходиться.

Для FitInfo свойство, fitlda функционируйте оценивает недоумение валидации от вероятностей документа в оценках наибольшего правдоподобия вероятностей на тематику документа. Это обычно быстрее, чтобы вычислить, но может быть менее точным, чем другие методы. В качестве альтернативы вычислите недоумение валидации с помощью logp функция. Эта функция вычисляет более точные значения, но может занять больше времени, чтобы запуститься. Для примера, показывающего, как вычислить недоумение с помощью logp, смотрите Вычисляют Логарифмические Вероятности Документа из Матрицы Для подсчета количества слов.

Предварительная обработка функции

Функциональный preprocessText выполняет следующие шаги:

  1. Маркируйте текст с помощью tokenizedDocument.

  2. Lemmatize слова с помощью normalizeWords.

  3. Сотрите пунктуацию с помощью erasePunctuation.

  4. Удалите список слов остановки (такой как "и", и) использование removeStopWords.

  5. Удалите слова с 2 или меньшим количеством символов с помощью removeShortWords.

  6. Удалите слова с 15 или больше символами с помощью removeLongWords.

function documents = preprocessText(textData)

% Tokenize the text.
documents = tokenizedDocument(textData);

% Lemmatize the words.
documents = addPartOfSpeechDetails(documents);
documents = normalizeWords(documents,'Style','lemma');

% Erase punctuation.
documents = erasePunctuation(documents);

% Remove a list of stop words.
documents = removeStopWords(documents);

% Remove words with 2 or fewer characters, and words with 15 or greater
% characters.
documents = removeShortWords(documents,2);
documents = removeLongWords(documents,15);

end

Смотрите также

| | | | | | | | | | | | |

Похожие темы