exponenta event banner

Сравнение решателей LDA

В этом примере показано, как сравнивать латентные решатели распределения Дирихле (LDA), сравнивая доброту подгонки и время, затрачиваемое на подгонку модели.

Импорт текстовых данных

Импорт набора аннотаций и меток категорий из математических документов с помощью API arXiv. Укажите количество записей для импорта с помощью importSize переменная.

importSize = 50000;

Создание URL-адреса для запроса записей с набором "math" и префикс метаданных "arXiv".

url = "https://export.arxiv.org/oai2?verb=ListRecords" + ...
    "&set=math" + ...
    "&metadataPrefix=arXiv";

Извлеките абстрактный текст и маркер возобновления, возвращенный URL-адресом запроса, с помощью parseArXivRecords функция, которая присоединена к этому примеру в качестве вспомогательного файла. Чтобы получить доступ к этому файлу, откройте этот пример в реальном времени. Следует отметить, что API arXiv ограничен по скорости и требует ожидания между несколькими запросами.

[textData,~,resumptionToken] = parseArXivRecords(url);

Итеративно импортируйте больше порций записей до тех пор, пока не будет достигнута требуемая сумма или не будет больше записей. Чтобы продолжить импорт записей из того места, где вы остановились, используйте маркер возобновления из предыдущего результата в URL-адресе запроса. Чтобы придерживаться ограничений скорости, установленных API arXiv, добавьте задержку в 20 секунд перед каждым запросом, используя pause функция.

while numel(textData) < importSize
    
    if resumptionToken == ""
        break
    end
    
    url = "https://export.arxiv.org/oai2?verb=ListRecords" + ...
        "&resumptionToken=" + resumptionToken;
    
    pause(20)
    [textDataNew,labelsNew,resumptionToken] = parseArXivRecords(url);
    
    textData = [textData; textDataNew];
end

Предварительная обработка текстовых данных

Отложите 10% документов случайным образом для проверки.

numDocuments = numel(textData);
cvp = cvpartition(numDocuments,'HoldOut',0.1);
textDataTrain = textData(training(cvp));
textDataValidation = textData(test(cvp));

Маркировка и предварительная обработка текстовых данных с помощью функции preprocessText который указан в конце этого примера.

documentsTrain = preprocessText(textDataTrain);
documentsValidation = preprocessText(textDataValidation);

Создайте модель пакета слов из учебных документов. Удалите слова, которые не появляются более двух раз. Удалите все документы, не содержащие слов.

bag = bagOfWords(documentsTrain);
bag = removeInfrequentWords(bag,2);
bag = removeEmptyDocuments(bag);

Для данных проверки создайте модель пакета слов из документов проверки. Удалять какие-либо слова из данных проверки не требуется, поскольку любые слова, которые не появляются в установленных моделях LDA, автоматически игнорируются.

validationData = bagOfWords(documentsValidation);

Подгонка и сравнение моделей

Для каждого из решателей LDA поместите модель с 40 темами. Для различения решателей при выводе результатов на печать на одних и тех же осях задайте различные свойства линий для каждого решателя.

numTopics = 40;
solvers = ["cgs" "avb" "cvb0" "savb"];
lineSpecs = ["+-" "*-" "x-" "o-"];

Поместите модель LDA с использованием каждого решателя. Для каждого решателя укажите начальную концентрацию тематики 1, чтобы проверить модель один раз за передачу данных и не соответствовать параметру концентрации тематики. Использование данных в FitInfo свойство подогнанных моделей LDA, постройте график недоумения проверки и прошедшего времени.

Стохастический решатель по умолчанию использует размер мини-пакета 1000 и проверяет модель каждые 10 итераций. Для этого решателя, чтобы проверить модель один раз за передачу данных, установите частоту проверки равной ceil(numObservations/1000), где numObservations - количество документов в данных обучения. Для других решателей установите частоту проверки в 1.

Для итераций, в которых стохастический решатель не оценивает недоумение проверки, стохастический решатель сообщает NaN в FitInfo собственность. Чтобы отобразить недоумение проверки, удалите NaNs из сообщаемых значений.

numObservations = bag.NumDocuments;

figure
for i = 1:numel(solvers)
    solver = solvers(i);
    lineSpec = lineSpecs(i);

    if solver == "savb"
        numIterationsPerDataPass = ceil(numObservations/1000);
    else
        numIterationsPerDataPass = 1;
    end

    mdl = fitlda(bag,numTopics, ...
        'Solver',solver, ...
        'InitialTopicConcentration',1, ...
        'FitTopicConcentration',false, ...
        'ValidationData',validationData, ...
        'ValidationFrequency',numIterationsPerDataPass, ...
        'Verbose',0);

    history = mdl.FitInfo.History;

    timeElapsed = history.TimeSinceStart;

    validationPerplexity = history.ValidationPerplexity;

    % Remove NaNs.
    idx = isnan(validationPerplexity);
    timeElapsed(idx) = [];
    validationPerplexity(idx) = [];

    plot(timeElapsed,validationPerplexity,lineSpec)
    hold on
end

hold off
xlabel("Time Elapsed (s)")
ylabel("Validation Perplexity")
ylim([0 inf])
legend(solvers)

Для стохастического решателя существует только одна точка данных. Это происходит потому, что этот решатель проходит через входные данные один раз. Чтобы указать больше проходов данных, используйте 'DataPassLimit' вариант. Для решателей партий ("cgs", "avb", и "cvb0"), чтобы указать количество итераций, используемых для соответствия моделям, используйте 'IterationLimit' вариант.

Более низкое недоумение проверки предполагает лучшую подгонку. Обычно решатели "savb" и "cgs" быстро сходятся до хорошей посадки. Решающее устройство "cvb0" может сойтись лучше, но может потребоваться гораздо больше времени, чтобы сойтись.

Для FitInfo свойство, fitlda функция оценивает недоумение проверки из вероятностей документа при максимальных оценках правдоподобия вероятностей тематики каждого документа. Это обычно быстрее вычисляется, но может быть менее точным, чем другие методы. Либо вычислите недоумение проверки с помощью logp функция. Эта функция вычисляет более точные значения, но может занять больше времени. Пример, показывающий, как вычислить недоумение с помощью logpсм. раздел Расчет вероятностей журнала документов из матрицы подсчета слов.

Функция предварительной обработки

Функция preprocessText выполняет следующие шаги:

  1. Маркировка текста с помощью tokenizedDocument.

  2. Лемматизировать слова с помощью normalizeWords.

  3. Стереть пунктуацию с помощью erasePunctuation.

  4. Удалите список стоп-слов (например, «and», «of» и «the»), используя removeStopWords.

  5. Удаление слов, содержащих не более 2 символов removeShortWords.

  6. Удаление слов с 15 или более символами с помощью removeLongWords.

function documents = preprocessText(textData)

% Tokenize the text.
documents = tokenizedDocument(textData);

% Lemmatize the words.
documents = addPartOfSpeechDetails(documents);
documents = normalizeWords(documents,'Style','lemma');

% Erase punctuation.
documents = erasePunctuation(documents);

% Remove a list of stop words.
documents = removeStopWords(documents);

% Remove words with 2 or fewer characters, and words with 15 or greater
% characters.
documents = removeShortWords(documents,2);
documents = removeLongWords(documents,15);

end

См. также

| | | | | | | | | | | | |

Связанные темы