Многоуровневая классификация текста с использованием глубокого обучения

В этом примере показано, как классифицировать текстовые данные, которые имеют несколько независимых меток.

Для задач классификации, где для каждого наблюдения может быть несколько независимых меток - например, тегов в научной статье - можно обучить модель глубокого обучения для предсказания вероятностей для каждого независимого класса. Чтобы позволить сети обучаться многоуровневым целям классификации, можно оптимизировать потери каждого класса независимо с помощью двоичных потерь перекрестной энтропии.

Этот пример задает модель глубокого обучения, которая классифицирует предметные области, учитывая тезисы математических работ, собранных с использованием arXiv API [1]. Модель состоит из слов embedding и GRU, max операции объединения, полносвязных и сигмоидных операций.

Чтобы измерить эффективность классификации мультиметки, можно использовать маркировку F-балл [2]. Маркировка F-балла оценивает многоуровневую классификацию, фокусируясь на классификации по тексту с частичными совпадениями. Мера является нормированной пропорцией совпадающих меток к общему количеству истинных и предсказанных меток.

Этот пример задает следующую модель:

  • Вложение слова, которое преобразует последовательность слов в последовательность числовых векторов.

  • Операция GRU, которая учит зависимости между векторами встраивания.

  • Операция максимального объединения, которая сокращает последовательность векторов функций до одного вектора функций.

  • Полносвязный слой, которая преобразует функции в двоичные выходные параметры.

  • Сигмоидная операция для изучения двоичной потери перекрестной энтропии между выходами и целевыми метками.

Эта схема показывает фрагмент текста, распространяющийся через архитектуру модели и выводящий вектор вероятностей. Вероятности независимы, поэтому они не должны суммироваться с единицей.

Импорт текстовых данных

Импортируйте набор тезисов и меток категорий из математических документов с помощью arXiv API. Укажите количество записей для импорта с помощью importSize переменная.

importSize = 50000;

Создайте URL-адрес, который запрашивает записи с установленными "math" и префикс метаданных "arXiv".

url = "https://export.arxiv.org/oai2?verb=ListRecords" + ...
    "&set=math" + ...
    "&metadataPrefix=arXiv";

Извлечение абстрактного текста, меток категорий и лексемы возобновления, возвращенного URL-адресом запроса, с помощью parseArXivRecords функция, которая присоединена к этому примеру как вспомогательный файл. Чтобы получить доступ к этому файлу, откройте этот пример как live скрипт. Обратите внимание, что API arXiv ограничен по скорости и требует ожидания между несколькими запросами.

[textData,labelsAll,resumptionToken] = parseArXivRecords(url);

Итерационно импортируйте больше фрагменты записей до достижения необходимой суммы или больше нет записей. Чтобы продолжить импорт записей из того места, где вы остановились, используйте лексему возобновления из предыдущего результата в URL-адресе запроса. Чтобы соответствовать пределам скорости, установленным API arXiv, добавьте задержку в 20 секунд перед каждым запросом с помощью pause функция.

while numel(textData) < importSize
    
    if resumptionToken == ""
        break
    end
    
    url = "https://export.arxiv.org/oai2?verb=ListRecords" + ...
        "&resumptionToken=" + resumptionToken;
    
    pause(20)
    [textDataNew,labelsNew,resumptionToken] = parseArXivRecords(url);
    
    textData = [textData; textDataNew];
    labelsAll = [labelsAll; labelsNew];
end

Предварительная обработка текстовых данных

Токенизация и предварительная обработка текстовых данных с помощью preprocessText функции, перечисленной в конце примера.

documentsAll = preprocessText(textData);
documentsAll(1:5)
ans = 
  5×1 tokenizedDocument:

    72 tokens: describe new algorithm $(k,\ell)$ pebble game color obtain characterization family $(k,\ell)$ sparse graph algorithmic solution family problem concerning tree decomposition graph special instance sparse graph appear rigidity theory receive increase attention recent year particular colored pebble generalize strengthen previous result lee streinu give new proof tuttenashwilliams characterization arboricity present new decomposition certify sparsity base $(k,\ell)$ pebble game color work expose connection pebble game algorithm previous sparse graph algorithm gabow gabow westermann hendrickson
    22 tokens: show determinant stirling cycle number count unlabeled acyclic singlesource automaton proof involve bijection automaton certain marked lattice path signreversing involution evaluate determinant
    18 tokens: paper show compute $\lambda_{\alpha}$ norm alpha dyadic grid result consequence description hardy space $h^p(r^n)$ term dyadic special atom
    62 tokens: partial cube isometric subgraphs hypercubes structure graph define mean semicubes djokovi winklers relation play important role theory partial cube structure employ paper characterize bipartite graph partial cube arbitrary dimension new characterization establish new proof know result give operation cartesian product paste expansion contraction process utilize paper construct new partial cube old particular isometric lattice dimension finite partial cube obtain mean operation calculate
    29 tokens: paper present algorithm compute hecke eigensystems hilbertsiegel cusp form real quadratic field narrow class number give illustrative example quadratic field $\q(\sqrt{5})$ example identify hilbertsiegel eigenforms possible lift hilbert eigenforms

Удалите метки, которые не относятся к "math" установите.

for i = 1:numel(labelsAll)
    labelsAll{i} = labelsAll{i}(startsWith(labelsAll{i},"math."));
end

Визуализация некоторых классов в облаке слов. Найдите документы, соответствующие следующим:

  • Тезисы с тегами «Combinatorics» и не с тегами "Statistics Theory"

  • Тезисы с тегами «Statistics Theory» и не с тегами "Combinatorics"

  • Тезисы с обоими "Combinatorics" и "Statistics Theory"

Найдите индексы документов для каждой из групп, использующих ismember функция.

idxCO = cellfun(@(lbls) ismember("math.CO",lbls) && ~ismember("math.ST",lbls),labelsAll);
idxST = cellfun(@(lbls) ismember("math.ST",lbls) && ~ismember("math.CO",lbls),labelsAll);
idxCOST = cellfun(@(lbls) ismember("math.CO",lbls) && ismember("math.ST",lbls),labelsAll);

Визуализация документов для каждой группы в облаке слов.

figure
subplot(1,3,1)
wordcloud(documentsAll(idxCO));
title("Combinatorics")

subplot(1,3,2)
wordcloud(documentsAll(idxST));
title("Statistics Theory")

subplot(1,3,3)
wordcloud(documentsAll(idxCOST));
title("Both")

Просмотрите количество классов.

classNames = unique(cat(1,labelsAll{:}));
numClasses = numel(classNames)
numClasses = 32

Визуализируйте количество меток по документам с помощью гистограммы.

labelCounts = cellfun(@numel, labelsAll);
figure
histogram(labelCounts)
xlabel("Number of Labels")
ylabel("Frequency")
title("Label Counts")

Подготовка текстовых данных для глубокого обучения

Разделите данные на разделы для обучения и валидации с помощью cvpartition функция. Задержите 10% данных для валидации путем установки 'HoldOut' опция 0.1.

cvp = cvpartition(numel(documentsAll),'HoldOut',0.1);
documentsTrain = documentsAll(training(cvp));
documentsValidation = documentsAll(test(cvp));

labelsTrain = labelsAll(training(cvp));
labelsValidation = labelsAll(test(cvp));

Создайте объект кодирования слов, который кодирует обучающие документы как последовательности индексов слов. Укажите словарь из 5000 слов путем установки 'Order' опция для 'frequency', и 'MaxNumWords' опция до 5000.

enc = wordEncoding(documentsTrain,'Order','frequency','MaxNumWords',5000)
enc = 
  wordEncoding with properties:

      NumWords: 5000
    Vocabulary: [1×5000 string]

Для улучшения обучения используйте следующие методы:

  1. При обучении обрезайте документы до длины, которая уменьшает количество используемого заполнения и не отбрасывает слишком много данных.

  2. Обучите на одну эпоху с документами, отсортированными по длине в порядке возрастания, затем перетащите данные каждую эпоху. Этот метод известен как сортаград.

Чтобы выбрать длину последовательности для усечения, визуализируйте длины документа в гистограмме и выберите значение, которое захватывает большую часть данных.

documentLengths = doclength(documentsTrain);

figure
histogram(documentLengths)
xlabel("Document Length")
ylabel("Frequency")
title("Document Lengths")

Большинство обучающих документов имеют менее 175 лексемы. Используйте 175 лексемы в качестве целевой длины для усечения и заполнения.

maxSequenceLength = 175;

Чтобы использовать метод сортировки, отсортируйте документы по длине в порядке возрастания.

[~,idx] = sort(documentLengths);
documentsTrain = documentsTrain(idx);
labelsTrain = labelsTrain(idx);

Определите и инициализируйте параметры модели

Определите параметры для каждой из операций и включите их в struct. Используйте формат parameters.OperationName.ParameterName, где parameters - struct, O perationName - имя операции (для примера "fc"), и ParameterName - имя параметра (для примера, "Weights").

Создайте struct parameters содержащие параметры модели. Инициализируйте смещение с нулями. Используйте для операций следующие инициализаторы веса:

  • Для встраивания инициализируйте веса со случайными нормальными значениями.

  • Для операции GRU инициализируйте веса с помощью initializeGlorot функции, перечисленной в конце примера.

  • Для операции полного подключения инициализируйте веса с помощью initializeGaussian функции, перечисленной в конце примера.

embeddingDimension = 300;
numHiddenUnits = 250;
inputSize = enc.NumWords + 1;

parameters = struct;
parameters.emb.Weights = dlarray(randn([embeddingDimension inputSize]));

parameters.gru.InputWeights = dlarray(initializeGlorot(3*numHiddenUnits,embeddingDimension));
parameters.gru.RecurrentWeights = dlarray(initializeGlorot(3*numHiddenUnits,numHiddenUnits));
parameters.gru.Bias = dlarray(zeros(3*numHiddenUnits,1,'single'));

parameters.fc.Weights = dlarray(initializeGaussian([numClasses,numHiddenUnits]));
parameters.fc.Bias = dlarray(zeros(numClasses,1,'single'));

Просмотрите parameters struct.

parameters
parameters = struct with fields:
    emb: [1×1 struct]
    gru: [1×1 struct]
     fc: [1×1 struct]

Просмотрите параметры для операции GRU.

parameters.gru
ans = struct with fields:
        InputWeights: [750×300 dlarray]
    RecurrentWeights: [750×250 dlarray]
                Bias: [750×1 dlarray]

Задайте функцию модели

Создайте функцию model, перечисленный в конце примера, который вычисляет выходы модели глубокого обучения, описанной ранее. Функция model принимает за вход входные данные dlX и параметры модели parameters. Сеть выводит предсказания для меток.

Задайте функцию градиентов модели

Создайте функцию modelGradients, перечисленный в конце примера, который принимает за вход мини-пакет входных данных dlX и соответствующие цели T содержит метки и возвращает градиенты потерь относительно настраиваемых параметров, соответствующих потерь и выходов сети.

Настройка опций обучения

Train на 5 эпох с мини-партией размером 256.

numEpochs = 5;
miniBatchSize = 256;

Обучите с помощью оптимизатора Адама со скоростью обучения 0,01 и задайте коэффициенты градиента и квадратного градиента распада 0,5 и 0,999 соответственно.

learnRate = 0.01;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;

Обрезка градиентов с порогом 1 с помощью L2 норма усечения градиента.

gradientThreshold = 1;

Визуализируйте процесс обучения на графике.

plots = "training-progress";

Чтобы преобразовать вектор вероятностей в метки, используйте метки с вероятностями, превышающими заданный порог. Задайте порог метки 0,5.

labelThreshold = 0.5;

Проверяйте сеть каждую эпоху.

numObservationsTrain = numel(documentsTrain);
numIterationsPerEpoch = floor(numObservationsTrain/miniBatchSize);
validationFrequency = numIterationsPerEpoch;

Обучите на графическом процессоре, если он доступен. Для этого требуется Parallel Computing Toolbox™. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах см. раздел.

executionEnvironment = "auto";

Обучите модель

Обучите модель с помощью пользовательского цикла обучения.

Для каждой эпохи закольцовывайте мини-пакеты данных. В конце каждой эпохи перетасовывайте данные. В конце каждой итерации обновляйте график процесса обучения.

Для каждого мини-пакета:

  • Преобразуйте документы в последовательности индексов слов и преобразуйте метки в фиктивные переменные.

  • Преобразуйте последовательности в dlarray объекты с базовым типом одинарные и задают метки размерностей 'BCT' (пакет, канал, время).

  • Для обучения графический процессор преобразуйте в gpuArray объекты.

  • Оцените градиенты модели и потери с помощью dlfeval и modelGradients функция.

  • Обрезка градиентов.

  • Обновляйте параметры сети с помощью adamupdate функция.

  • При необходимости проверьте сеть с помощью modelPredictions функции, перечисленной в конце примера.

  • Обновите график обучения.

Инициализируйте график процесса обучения.

if plots == "training-progress"
    figure
    
    % Labeling F-Score.
    subplot(2,1,1)
    lineFScoreTrain = animatedline('Color',[0 0.447 0.741]);
    lineFScoreValidation = animatedline( ...
        'LineStyle','--', ...
        'Marker','o', ...
        'MarkerFaceColor','black');
    ylim([0 1])
    xlabel("Iteration")
    ylabel("Labeling F-Score")
    grid on
    
    % Loss.
    subplot(2,1,2)
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    lineLossValidation = animatedline( ...
        'LineStyle','--', ...
        'Marker','o', ...
        'MarkerFaceColor','black');
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end

Инициализируйте параметры для оптимизатора Адама.

trailingAvg = [];
trailingAvgSq = [];

Подготовьте данные валидации. Создайте матрицу с одним горячим кодированием, где ненулевые значения соответствуют меткам каждого наблюдения.

numObservationsValidation = numel(documentsValidation);
TValidation = zeros(numClasses, numObservationsValidation, 'single');
for i = 1:numObservationsValidation
    [~,idx] = ismember(labelsValidation{i},classNames);
    TValidation(idx,i) = 1;
end

Обучите модель.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        
        % Read mini-batch of data and convert the labels to dummy
        % variables.
        documents = documentsTrain(idx);
        labels = labelsTrain(idx);
        
        % Convert documents to sequences.
        len = min(maxSequenceLength,max(doclength(documents)));
        X = doc2sequence(enc,documents, ...
            'PaddingValue',inputSize, ...
            'Length',len);
        X = cat(1,X{:});
        
        % Dummify labels.
        T = zeros(numClasses, miniBatchSize, 'single');
        for j = 1:miniBatchSize
            [~,idx2] = ismember(labels{j},classNames);
            T(idx2,j) = 1;
        end
        
        % Convert mini-batch of data to dlarray.
        dlX = dlarray(X,'BTC');
        
        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end
        
        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelGradients function.
        [gradients,loss,dlYPred] = dlfeval(@modelGradients, dlX, T, parameters);
        
        % Gradient clipping.
        gradients = dlupdate(@(g) thresholdL2Norm(g, gradientThreshold),gradients);
        
        % Update the network parameters using the Adam optimizer.
        [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
            trailingAvg,trailingAvgSq,iteration,learnRate,gradientDecayFactor,squaredGradientDecayFactor);
        
        % Display the training progress.
        if plots == "training-progress"
            subplot(2,1,1)
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            
            % Loss.
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            
            % Labeling F-score.
            YPred = extractdata(dlYPred) > labelThreshold;
            score = labelingFScore(YPred,T);
            addpoints(lineFScoreTrain,iteration,double(gather(score)))
            
            drawnow
            
            % Display validation metrics.
            if iteration == 1 || mod(iteration,validationFrequency) == 0
                dlYPredValidation = modelPredictions(parameters,enc,documentsValidation,miniBatchSize,maxSequenceLength);
                
                % Loss.
                lossValidation = crossentropy(dlYPredValidation,TValidation, ...
                    'TargetCategories','independent', ...
                    'DataFormat','CB');
                addpoints(lineLossValidation,iteration,double(gather(extractdata(lossValidation))))
                
                % Labeling F-score.
                YPredValidation = extractdata(dlYPredValidation) > labelThreshold;
                score = labelingFScore(YPredValidation,TValidation);
                addpoints(lineFScoreValidation,iteration,double(gather(score)))
                
                drawnow
            end
        end
    end
    
    % Shuffle data.
    idx = randperm(numObservationsTrain);
    documentsTrain = documentsTrain(idx);
    labelsTrain = labelsTrain(idx);
end

Экспериментальная модель

Чтобы делать предсказания на новом наборе данных, используйте modelPredictions функции, перечисленной в конце примера. The modelPredictions функция принимает за вход параметры модели, кодировку слов и массив токенизированных документов и выводит предсказания модели, соответствующие заданному размеру мини-пакета и максимальной длине последовательности.

dlYPredValidation = modelPredictions(parameters,enc,documentsValidation,miniBatchSize,maxSequenceLength);

Чтобы преобразовать выходы сети в массив меток, найдите метки с счетами выше заданного порога метки.

YPredValidation = extractdata(dlYPredValidation) > labelThreshold;

Чтобы оценить эффективность, вычислите счет F маркировки с помощью labelingFScore функции, перечисленной в конце примера. Маркировка F-балла оценивает многоуровневую классификацию, фокусируясь на классификации по тексту с частичными совпадениями.

score = labelingFScore(YPredValidation,TValidation)
score = single
    0.5663

Просмотрите эффект порога маркировки на счет F маркировки, попробовав область значений значений для порога и сравнив результаты.

thr = linspace(0,1,10);
score = zeros(size(thr));
for i = 1:numel(thr)
    YPredValidationThr = extractdata(dlYPredValidation) >= thr(i);
    score(i) = labelingFScore(YPredValidationThr,TValidation);
end

figure
plot(thr,score)
xline(labelThreshold,'r--');
xlabel("Threshold")
ylabel("Labeling F-Score")
title("Effect of Labeling Threshold")

Визуализация предсказаний

Чтобы визуализировать правильные предсказания классификатора, вычислите количество истинных срабатываний. Истинный позитив является образцом классификатора, правильно предсказывающим конкретный класс для наблюдения.

Y = YPredValidation;
T = TValidation;

numTruePositives = sum(T & Y,2);

numObservationsPerClass = sum(T,2);
truePositiveRates = numTruePositives ./ numObservationsPerClass;

Визуализируйте числа истинных положительных результатов для каждого класса в гистограмме.

figure
[~,idx] = sort(truePositiveRates,'descend');
histogram('Categories',classNames(idx),'BinCounts',truePositiveRates(idx))
xlabel("Category")
ylabel("True Positive Rate")
title("True Positive Rates")

Визуализируйте образцы, где классификатор неправильно предсказывает, показав распределение истинных срабатываний, ложных срабатываний и ложных срабатываний. Ложное положительное - это образец классификатора, присваивающий наблюдению конкретный неправильный класс. Ложный минус - это образец классификатора, не присваивающий наблюдению конкретный правильный класс.

Создайте матрицу неточностей, показывающую истинные положительные, ложноположительные и ложноотрицательные отсчеты:

  • Для каждого класса отобразите истинные положительные счетчики на диагонали.

  • Для каждой пары классов (i, j) отображает количество образцов ложного положительного для j, когда образец также является ложным отрицательным для i.

То есть матрица неточностей с элементами, заданными:

TPFNij={numTruePositives(i),if i=jnumFalsePositives(j|i является  ложноотрицательным ),if ijИстинный положительный, ложноотрицательные скорости

Вычислите ложные срабатывания и ложные срабатывания.

falseNegatives = T & ~Y;
falsePositives = ~T & Y;

Вычислите недиагональные элементы.

falseNegatives = permute(falseNegatives,[3 2 1]);
numConditionalFalsePositives = sum(falseNegatives & falsePositives, 2);
numConditionalFalsePositives = squeeze(numConditionalFalsePositives);

tpfnMatrix = numConditionalFalsePositives;

Установите диагональные элементы на истинные положительные счетчики.

idxDiagonal = 1:numClasses+1:numClasses^2;
tpfnMatrix(idxDiagonal) = numTruePositives;

Визуализируйте истинные положительные и ложноположительные счетчики в матрице неточностей с помощью confusionchart и отсортируйте матрицу так, чтобы элементы диагонали были в порядке убывания.

figure
cm = confusionchart(tpfnMatrix,classNames);
sortClasses(cm,"descending-diagonal");
title("True Positives, False Positives")

Чтобы просмотреть матрицу более подробно, откройте этот пример как live скрипт и откройте рисунок в новом окне.

Функция предварительной обработки текста

The preprocessText функция токенизирует и предварительно обрабатывает входные текстовые данные с помощью следующих шагов:

  1. Токенизация текста с помощью tokenizedDocument функция. Извлечь математические уравнения как одна лексема с помощью 'RegularExpressions' опция путем определения регулярного выражения "\$.*?\$", который захватывает текст, появляющийся между двумя символами «$».

  2. Удалите пунктуацию с помощью erasePunctuation функция.

  3. Преобразуйте текст в нижний регистр с помощью lower функция.

  4. Удалите стоповые слова с помощью removeStopWords функция.

  5. Лемматизируйте текст с помощью normalizeWords функция со 'Style' значение опции установлено в 'lemma'.

function documents = preprocessText(textData)

% Tokenize the text.
regularExpressions = table;
regularExpressions.Pattern = "\$.*?\$";
regularExpressions.Type = "equation";

documents = tokenizedDocument(textData,'RegularExpressions',regularExpressions);

% Erase punctuation.
documents = erasePunctuation(documents);

% Convert to lowercase.
documents = lower(documents);

% Lemmatize.
documents = addPartOfSpeechDetails(documents);
documents = normalizeWords(documents,'Style','Lemma');

% Remove stop words.
documents = removeStopWords(documents);

% Remove short words.
documents = removeShortWords(documents,2);

end

Моделируйте функцию

Функция model принимает за вход входные данные dlX и параметры модели parameters, и возвращает предсказания для меток.

function dlY = model(dlX,parameters)

% Embedding
weights = parameters.emb.Weights;
dlX = embedding(dlX, weights);

% GRU
inputWeights = parameters.gru.InputWeights;
recurrentWeights = parameters.gru.RecurrentWeights;
bias = parameters.gru.Bias;

numHiddenUnits = size(inputWeights,1)/3;
hiddenState = dlarray(zeros([numHiddenUnits 1]));

dlY = gru(dlX, hiddenState, inputWeights, recurrentWeights, bias,'DataFormat','CBT');

% Max pooling along time dimension
dlY = max(dlY,[],3);

% Fully connect
weights = parameters.fc.Weights;
bias = parameters.fc.Bias;
dlY = fullyconnect(dlY,weights,bias,'DataFormat','CB');

% Sigmoid
dlY = sigmoid(dlY);

end

Функция градиентов модели

The modelGradients функция принимает за вход мини-пакет входных данных dlX с соответствующими целями T содержит метки и возвращает градиенты потерь относительно настраиваемых параметров, соответствующих потерь и выходов сети.

function [gradients,loss,dlYPred] = modelGradients(dlX,T,parameters)

dlYPred = model(dlX,parameters);

loss = crossentropy(dlYPred,T,'TargetCategories','independent','DataFormat','CB');

gradients = dlgradient(loss,parameters);

end

Функция предсказаний модели

The modelPredictions функция принимает за вход параметры модели, кодировку слов, массив токенизированных документов, размер мини-пакета и максимальную длину последовательности и возвращает предсказания модели путем итерации по мини-пакетам заданного размера.

function dlYPred = modelPredictions(parameters,enc,documents,miniBatchSize,maxSequenceLength)

inputSize = enc.NumWords + 1;

numObservations = numel(documents);
numIterations = ceil(numObservations / miniBatchSize);

numFeatures = size(parameters.fc.Weights,1);
dlYPred = zeros(numFeatures,numObservations,'like',parameters.fc.Weights);

for i = 1:numIterations
    
    idx = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations);
    
    len = min(maxSequenceLength,max(doclength(documents(idx))));
    X = doc2sequence(enc,documents(idx), ...
        'PaddingValue',inputSize, ...
        'Length',len);
    X = cat(1,X{:});
    
    dlX = dlarray(X,'BTC');
    
    dlYPred(:,idx) = model(dlX,parameters);
end

end

Маркировка функции F-Score

Функция [2] маркировки F-балла оценивает многоуровневую классификацию, фокусируясь на классификации по тексту с частичными совпадениями. Мера является нормированной пропорцией совпадающих меток к общему количеству истинных и предсказанных меток, заданных как

1Nn=1N(2c=1CYncTncc=1C(Ync+Tnc)),Маркировка F-балла

где N и C соответствуют количеству наблюдений и классов, соответственно, и Y и T соответствуют предсказаниям и целям, соответственно.

function score = labelingFScore(Y,T)

numObservations = size(T,2);

scores = (2 * sum(Y .* T)) ./ sum(Y + T);
score = sum(scores) / numObservations;

end

Функция инициализации весов Глорота

The initializeGlorot функция генерирует массив весов согласно инициализации Glorot.

function weights = initializeGlorot(numOut, numIn)

varWeights = sqrt( 6 / (numIn + numOut) );
weights = varWeights * (2 * rand([numOut, numIn], 'single') - 1);

end

Функция инициализации Гауссовых весов

The initializeGaussian функции отбирают веса из Гауссова распределения со средним 0 и стандартным отклонением 0,01.

function parameter = initializeGaussian(sz)

parameter = randn(sz,'single') .* 0.01;

end

Функция встраивания

The embedding функция преобразует числовые индексы в соответствующий вектор, заданный входными весами.

function Z = embedding(X, weights)
% Reshape inputs into a vector.
[N, T] = size(X, 2:3);
X = reshape(X, N*T, 1);

% Index into embedding matrix.
Z = weights(:, X);

% Reshape outputs by separating batch and sequence dimensions.
Z = reshape(Z, [], N, T);
end

L2 Нормальная функция Усечения градиента

The thresholdL2Norm функция масштабирует входные градиенты так, чтобы их L2 норма равна заданному порогу градиента, когда L2 норма значения градиента настраиваемого параметра больше заданного порога.

function gradients = thresholdL2Norm(gradients,gradientThreshold)

gradientNorm = sqrt(sum(gradients(:).^2));
if gradientNorm > gradientThreshold
    gradients = gradients * (gradientThreshold / gradientNorm);
end

end

Ссылки

  1. arXiv. «arXiv API». Доступ к 15 января 2020 года. https://arxiv.org/help/api

  2. Соколова, Марина и Гай Лапальме. «Синтематический анализ показателей эффективности для задач классификации». Обработка информации и управление 45, № 4 (2009): 427-437.

См. также

| | | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox)

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте