Сравнение решателей LDA

Этот пример показывает, как сравнить латентные решатели распределения Дирихле (LDA), сравнивая качество подгонки и время, необходимое для подгонки модели.

Импорт текстовых данных

Импортируйте набор тезисов и меток категорий из математических документов с помощью arXiv API. Укажите количество записей для импорта с помощью importSize переменная.

importSize = 50000;

Создайте URL-адрес, который запрашивает записи с установленными "math" и префикс метаданных "arXiv".

url = "https://export.arxiv.org/oai2?verb=ListRecords" + ...
    "&set=math" + ...
    "&metadataPrefix=arXiv";

Извлечение абстрактного текста и лексемы возобновления, возвращенных URL-адресом запроса, с помощью parseArXivRecords функция, которая присоединена к этому примеру как вспомогательный файл. Чтобы получить доступ к этому файлу, откройте этот пример как live скрипт. Обратите внимание, что API arXiv ограничен по скорости и требует ожидания между несколькими запросами.

[textData,~,resumptionToken] = parseArXivRecords(url);

Итерационно импортируйте больше фрагменты записей до достижения необходимой суммы или больше нет записей. Чтобы продолжить импорт записей из того места, где вы остановились, используйте лексему возобновления из предыдущего результата в URL-адресе запроса. Чтобы соответствовать пределам скорости, установленным API arXiv, добавьте задержку в 20 секунд перед каждым запросом с помощью pause функция.

while numel(textData) < importSize
    
    if resumptionToken == ""
        break
    end
    
    url = "https://export.arxiv.org/oai2?verb=ListRecords" + ...
        "&resumptionToken=" + resumptionToken;
    
    pause(20)
    [textDataNew,labelsNew,resumptionToken] = parseArXivRecords(url);
    
    textData = [textData; textDataNew];
end

Предварительная обработка текстовых данных

Выделите 10% документов случайным образом для валидации.

numDocuments = numel(textData);
cvp = cvpartition(numDocuments,'HoldOut',0.1);
textDataTrain = textData(training(cvp));
textDataValidation = textData(test(cvp));

Токенизация и предварительная обработка текстовых данных с помощью функции preprocessText который перечислен в конце этого примера.

documentsTrain = preprocessText(textDataTrain);
documentsValidation = preprocessText(textDataValidation);

Создайте модель мешка слов из обучающих документов. Удалите слова, которые не появляются более двух раз в общей сложности. Удалите все документы, не содержащие слов.

bag = bagOfWords(documentsTrain);
bag = removeInfrequentWords(bag,2);
bag = removeEmptyDocuments(bag);

Для данных валидации создайте модель мешка слов из документов валидации. Вам не нужно удалять какие-либо слова из данных валидаитона, потому что любые слова, которые не появляются в подобранных моделях LDA, автоматически игнорируются.

validationData = bagOfWords(documentsValidation);

Подгонка и сравнение моделей

Для каждого из решателей LDA подбирайте модель с 40 темами. Чтобы различать решатели при построении графиков результатов на тех же осях, задайте различные свойства линий для каждого решателя.

numTopics = 40;
solvers = ["cgs" "avb" "cvb0" "savb"];
lineSpecs = ["+-" "*-" "x-" "o-"];

Подгонка модели LDA с помощью каждого решателя. Для каждого решателя укажите начальную концентрацию темы 1, чтобы подтвердить модель один раз за проход данных и не соответствовать параметру концентрации темы. Использование данных в FitInfo свойство подобранных моделей LDA, постройте график недоумения валидации и истекшего времени.

Решатель стохастики по умолчанию использует мини-пакет размером 1000 и проверяет модель каждые 10 итераций. Для этого решателя, чтобы подтвердить модель один раз на проход данных, установите частоту валидации равной ceil(numObservations/1000), где numObservations количество документов в обучающих данных. Для других решателей установите частоту валидации равной 1.

Для итераций, которые стохастический решатель не оценивает недоумение валидации, стохастический решатель сообщает NaN в FitInfo свойство. Чтобы построить график недоумения валидации, удалите NaNs из сообщенных значений.

numObservations = bag.NumDocuments;

figure
for i = 1:numel(solvers)
    solver = solvers(i);
    lineSpec = lineSpecs(i);

    if solver == "savb"
        numIterationsPerDataPass = ceil(numObservations/1000);
    else
        numIterationsPerDataPass = 1;
    end

    mdl = fitlda(bag,numTopics, ...
        'Solver',solver, ...
        'InitialTopicConcentration',1, ...
        'FitTopicConcentration',false, ...
        'ValidationData',validationData, ...
        'ValidationFrequency',numIterationsPerDataPass, ...
        'Verbose',0);

    history = mdl.FitInfo.History;

    timeElapsed = history.TimeSinceStart;

    validationPerplexity = history.ValidationPerplexity;

    % Remove NaNs.
    idx = isnan(validationPerplexity);
    timeElapsed(idx) = [];
    validationPerplexity(idx) = [];

    plot(timeElapsed,validationPerplexity,lineSpec)
    hold on
end

hold off
xlabel("Time Elapsed (s)")
ylabel("Validation Perplexity")
ylim([0 inf])
legend(solvers)

Для стохастического решателя существует только одна точка данных. Это потому, что этот решатель проходит через входные данные один раз. Чтобы задать больше проходов данных, используйте 'DataPassLimit' опция. Для решателей пакета ("cgs", "avb", и "cvb0"), чтобы указать количество итераций, используемых для подгонки моделей, используйте 'IterationLimit' опция.

Меньшее недоумение валидации предполагает лучшую подгонку. Обычно решатели "savb" и "cgs" быстро сходиться к хорошей подгонке. Решатель "cvb0" может сходиться к лучшей подгонке, но может потребоваться гораздо больше времени, чтобы сойтись.

Для FitInfo свойство, fitlda функция оценивает недоумение валидации из вероятностей документа в максимальных оценках правдоподобия вероятностей темы по документу. Обычно это быстрее вычислить, но может быть менее точным, чем другие методы. Кроме того, вычислите недоумение валидации, используя logp функция. Эта функция вычисляет более точные значения, но может занять больше времени. Для примера, показывающего, как вычислить недоумение с помощью logp, см. «Вычисление логарифмических вероятностей документа из матрицы count слов».

Функция предварительной обработки

Функция preprocessText выполняет следующие шаги:

  1. Токенизация текста с помощью tokenizedDocument.

  2. Лемматизируйте слова, используя normalizeWords.

  3. Удалите пунктуацию с помощью erasePunctuation.

  4. Удалите список стоповых слов (таких как «and», «of», и «the») с помощью removeStopWords.

  5. Удалите слова с 2 или меньшим количеством символов, используя removeShortWords.

  6. Удалите слова с 15 или более символами, используя removeLongWords.

function documents = preprocessText(textData)

% Tokenize the text.
documents = tokenizedDocument(textData);

% Lemmatize the words.
documents = addPartOfSpeechDetails(documents);
documents = normalizeWords(documents,'Style','lemma');

% Erase punctuation.
documents = erasePunctuation(documents);

% Remove a list of stop words.
documents = removeStopWords(documents);

% Remove words with 2 or fewer characters, and words with 15 or greater
% characters.
documents = removeShortWords(documents,2);
documents = removeLongWords(documents,15);

end

См. также

| | | | | | | | | | | | |

Похожие темы