exponenta event banner

Классификация текста Multilabel с использованием глубокого обучения

В этом примере показано, как классифицировать текстовые данные с несколькими независимыми метками.

Для задач классификации, где может быть несколько независимых меток для каждого наблюдения - например, теги в научной статье - можно обучить модель глубокого обучения для прогнозирования вероятностей для каждого независимого класса. Чтобы сеть могла изучать цели многоуровневой классификации, можно оптимизировать потери каждого класса независимо, используя двоичные потери перекрестной энтропии.

Этот пример определяет модель глубокого обучения, которая классифицирует предметные области, учитывая тезисы математических работ, собранных с помощью arXiv API [1]. Модель состоит из встраивания слов и GRU, операции максимального объединения, полностью соединенных и сигмоидных операций.

Для измерения производительности многоуровневой классификации можно использовать маркировку F-score [2]. Оценка F оценивает многоуровневую классификацию, фокусируясь на классификации по тексту с частичными совпадениями. Мера представляет собой нормализованную долю совпадающих меток от общего числа истинных и прогнозируемых меток.

В этом примере определяется следующая модель:

  • Вложение слова, которое отображает последовательность слов в последовательность числовых векторов.

  • Операция GRU, которая распознает зависимости между векторами внедрения.

  • Операция максимального объединения, которая уменьшает последовательность векторов признаков до одного вектора признаков.

  • Полностью подключенный слой, который сопоставляет элементы с двоичными выходами.

  • Сигмоидальная операция для изучения двоичных потерь перекрестной энтропии между выходами и целевыми метками.

На этой диаграмме показан фрагмент текста, распространяющийся через архитектуру модели и выводящий вектор вероятностей. Вероятности независимы, поэтому их не нужно суммировать.

Импорт текстовых данных

Импорт набора аннотаций и меток категорий из математических документов с помощью API arXiv. Укажите количество записей для импорта с помощью importSize переменная.

importSize = 50000;

Создание URL-адреса для запроса записей с набором "math" и префикс метаданных "arXiv".

url = "https://export.arxiv.org/oai2?verb=ListRecords" + ...
    "&set=math" + ...
    "&metadataPrefix=arXiv";

Извлеките абстрактный текст, метки категорий и маркер возобновления, возвращенный URL-адресом запроса, с помощью parseArXivRecords функция, которая присоединена к этому примеру в качестве вспомогательного файла. Чтобы получить доступ к этому файлу, откройте этот пример в реальном времени. Следует отметить, что API arXiv ограничен по скорости и требует ожидания между несколькими запросами.

[textData,labelsAll,resumptionToken] = parseArXivRecords(url);

Итеративно импортируйте больше порций записей до тех пор, пока не будет достигнута требуемая сумма или не будет больше записей. Чтобы продолжить импорт записей из того места, где вы остановились, используйте маркер возобновления из предыдущего результата в URL-адресе запроса. Чтобы придерживаться ограничений скорости, установленных API arXiv, добавьте задержку в 20 секунд перед каждым запросом, используя pause функция.

while numel(textData) < importSize
    
    if resumptionToken == ""
        break
    end
    
    url = "https://export.arxiv.org/oai2?verb=ListRecords" + ...
        "&resumptionToken=" + resumptionToken;
    
    pause(20)
    [textDataNew,labelsNew,resumptionToken] = parseArXivRecords(url);
    
    textData = [textData; textDataNew];
    labelsAll = [labelsAll; labelsNew];
end

Предварительная обработка текстовых данных

Маркировка и предварительная обработка текстовых данных с помощью preprocessText функция, перечисленная в конце примера.

documentsAll = preprocessText(textData);
documentsAll(1:5)
ans = 
  5×1 tokenizedDocument:

    72 tokens: describe new algorithm $(k,\ell)$ pebble game color obtain characterization family $(k,\ell)$ sparse graph algorithmic solution family problem concerning tree decomposition graph special instance sparse graph appear rigidity theory receive increase attention recent year particular colored pebble generalize strengthen previous result lee streinu give new proof tuttenashwilliams characterization arboricity present new decomposition certify sparsity base $(k,\ell)$ pebble game color work expose connection pebble game algorithm previous sparse graph algorithm gabow gabow westermann hendrickson
    22 tokens: show determinant stirling cycle number count unlabeled acyclic singlesource automaton proof involve bijection automaton certain marked lattice path signreversing involution evaluate determinant
    18 tokens: paper show compute $\lambda_{\alpha}$ norm alpha dyadic grid result consequence description hardy space $h^p(r^n)$ term dyadic special atom
    62 tokens: partial cube isometric subgraphs hypercubes structure graph define mean semicubes djokovi winklers relation play important role theory partial cube structure employ paper characterize bipartite graph partial cube arbitrary dimension new characterization establish new proof know result give operation cartesian product paste expansion contraction process utilize paper construct new partial cube old particular isometric lattice dimension finite partial cube obtain mean operation calculate
    29 tokens: paper present algorithm compute hecke eigensystems hilbertsiegel cusp form real quadratic field narrow class number give illustrative example quadratic field $\q(\sqrt{5})$ example identify hilbertsiegel eigenforms possible lift hilbert eigenforms

Удалить метки, не принадлежащие "math" набор.

for i = 1:numel(labelsAll)
    labelsAll{i} = labelsAll{i}(startsWith(labelsAll{i},"math."));
end

Визуализация некоторых классов в облаке слов. Найдите документы, соответствующие:

  • Тезисы, помеченные «Combinatorics» и не помеченные "Statistics Theory"

  • Тезисы, помеченные как «Теория статистики» и не помеченные как «Теория статистики» "Combinatorics"

  • Тезисы, помеченные обоими тегами "Combinatorics" и "Statistics Theory"

Поиск индексов документов для каждой из групп с помощью ismember функция.

idxCO = cellfun(@(lbls) ismember("math.CO",lbls) && ~ismember("math.ST",lbls),labelsAll);
idxST = cellfun(@(lbls) ismember("math.ST",lbls) && ~ismember("math.CO",lbls),labelsAll);
idxCOST = cellfun(@(lbls) ismember("math.CO",lbls) && ismember("math.ST",lbls),labelsAll);

Визуализация документов для каждой группы в облаке слов.

figure
subplot(1,3,1)
wordcloud(documentsAll(idxCO));
title("Combinatorics")

subplot(1,3,2)
wordcloud(documentsAll(idxST));
title("Statistics Theory")

subplot(1,3,3)
wordcloud(documentsAll(idxCOST));
title("Both")

Просмотр количества классов.

classNames = unique(cat(1,labelsAll{:}));
numClasses = numel(classNames)
numClasses = 32

Визуализация количества меток для каждого документа с помощью гистограммы.

labelCounts = cellfun(@numel, labelsAll);
figure
histogram(labelCounts)
xlabel("Number of Labels")
ylabel("Frequency")
title("Label Counts")

Подготовка текстовых данных для глубокого обучения

Разбиение данных на разделы обучения и проверки с помощью cvpartition функция. Удерживайте 10% данных для проверки, установив 'HoldOut' опция 0,1.

cvp = cvpartition(numel(documentsAll),'HoldOut',0.1);
documentsTrain = documentsAll(training(cvp));
documentsValidation = documentsAll(test(cvp));

labelsTrain = labelsAll(training(cvp));
labelsValidation = labelsAll(test(cvp));

Создайте объект кодирования слов, который кодирует учебные документы как последовательности индексов слов. Укажите словарный запас 5000 слов, установив 'Order' опция для 'frequency', и 'MaxNumWords' опция 5000.

enc = wordEncoding(documentsTrain,'Order','frequency','MaxNumWords',5000)
enc = 
  wordEncoding with properties:

      NumWords: 5000
    Vocabulary: [1×5000 string]

Для улучшения обучения используйте следующие методы:

  1. При обучении усечение документов до такой длины, которая уменьшает количество используемых дополнений и не отбрасывает слишком много данных.

  2. Тренируйтесь в течение одной эпохи с документами, отсортированными по длине в порядке возрастания, затем тасуйте данные каждую эпоху. Эта техника известна как сортаград.

Чтобы выбрать длину последовательности для усечения, визуализируйте длины документа в гистограмме и выберите значение, которое захватывает большую часть данных.

documentLengths = doclength(documentsTrain);

figure
histogram(documentLengths)
xlabel("Document Length")
ylabel("Frequency")
title("Document Lengths")

Большинство учебных документов имеют менее 175 маркеров. Используйте 175 токенов в качестве целевой длины для усечения и заполнения.

maxSequenceLength = 175;

Для использования метода sortagrad сортируйте документы по длине в порядке возрастания.

[~,idx] = sort(documentLengths);
documentsTrain = documentsTrain(idx);
labelsTrain = labelsTrain(idx);

Определение и инициализация параметров модели

Определите параметры для каждой операции и включите их в структуру. Использовать формат parameters.OperationName.ParameterName, где parameters - структура, OperationName - имя операции (например, "fc"), и ParameterName - имя параметра (например, "Weights").

Создание структуры parameters содержащий параметры модели. Инициализируйте смещение нулями. Используйте следующие инициализаторы веса для операций:

  • Для внедрения инициализируйте веса случайными нормальными значениями.

  • Для операции GRU инициализируйте веса с помощью initializeGlorot функция, перечисленная в конце примера.

  • Для операции полного подключения инициализируйте веса с помощью initializeGaussian функция, перечисленная в конце примера.

embeddingDimension = 300;
numHiddenUnits = 250;
inputSize = enc.NumWords + 1;

parameters = struct;
parameters.emb.Weights = dlarray(randn([embeddingDimension inputSize]));

parameters.gru.InputWeights = dlarray(initializeGlorot(3*numHiddenUnits,embeddingDimension));
parameters.gru.RecurrentWeights = dlarray(initializeGlorot(3*numHiddenUnits,numHiddenUnits));
parameters.gru.Bias = dlarray(zeros(3*numHiddenUnits,1,'single'));

parameters.fc.Weights = dlarray(initializeGaussian([numClasses,numHiddenUnits]));
parameters.fc.Bias = dlarray(zeros(numClasses,1,'single'));

Просмотр parameters структура.

parameters
parameters = struct with fields:
    emb: [1×1 struct]
    gru: [1×1 struct]
     fc: [1×1 struct]

Просмотрите параметры для операции GRU.

parameters.gru
ans = struct with fields:
        InputWeights: [750×300 dlarray]
    RecurrentWeights: [750×250 dlarray]
                Bias: [750×1 dlarray]

Определение функции модели

Создание функции model, перечисленных в конце примера, который вычисляет выходные данные модели глубокого обучения, описанной ранее. Функция model принимает в качестве входных входные данные dlX и параметры модели parameters. Сеть выводит прогнозы для меток.

Определение функции градиентов модели

Создание функции modelGradients, перечисленных в конце примера, который принимает в качестве входных данных мини-пакет входных данных dlX и соответствующие целевые показатели T содержит метки и возвращает градиенты потерь относительно обучаемых параметров, соответствующих потерь и сетевых выходов.

Укажите параметры обучения

Поезд на 5 эпох с размером мини-партии 256.

numEpochs = 5;
miniBatchSize = 256;

Используйте оптимизатор Адама со скоростью обучения 0,01 и задайте коэффициенты градиентного распада и квадрата градиентного распада 0,5 и 0,999 соответственно.

learnRate = 0.01;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;

Подрезка градиентов с порогом 1 с помощью отсечения градиента L2 норме.

gradientThreshold = 1;

Визуализация хода обучения на графике.

plots = "training-progress";

Чтобы преобразовать вектор вероятностей в метки, используйте метки с вероятностями выше заданного порога. Укажите порог метки 0,5.

labelThreshold = 0.5;

Проверяйте сеть каждую эпоху.

numObservationsTrain = numel(documentsTrain);
numIterationsPerEpoch = floor(numObservationsTrain/miniBatchSize);
validationFrequency = numIterationsPerEpoch;

Обучение на GPU, если он доступен. Для этого требуется Toolbox™ параллельных вычислений. Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе.

executionEnvironment = "auto";

Модель поезда

Обучение модели с помощью пользовательского цикла обучения.

Для каждой эпохи закольцовывайте мини-пакеты данных. В конце каждой эпохи перетасуйте данные. В конце каждой итерации обновите график хода обучения.

Для каждой мини-партии:

  • Преобразование документов в последовательности индексов слов и преобразование меток в фиктивные переменные.

  • Преобразование последовательностей в dlarray объекты с одним базовым типом и указать метки размеров 'BCT' (партия, канал, время).

  • Для обучения GPU, конвертировать в gpuArray объекты.

  • Оценка градиентов и потерь модели с помощью dlfeval и modelGradients функция.

  • Подрезать градиенты.

  • Обновление параметров сети с помощью adamupdate функция.

  • При необходимости проверьте сеть с помощью modelPredictions функция, перечисленная в конце примера.

  • Обновите график обучения.

Инициализируйте график хода обучения.

if plots == "training-progress"
    figure
    
    % Labeling F-Score.
    subplot(2,1,1)
    lineFScoreTrain = animatedline('Color',[0 0.447 0.741]);
    lineFScoreValidation = animatedline( ...
        'LineStyle','--', ...
        'Marker','o', ...
        'MarkerFaceColor','black');
    ylim([0 1])
    xlabel("Iteration")
    ylabel("Labeling F-Score")
    grid on
    
    % Loss.
    subplot(2,1,2)
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    lineLossValidation = animatedline( ...
        'LineStyle','--', ...
        'Marker','o', ...
        'MarkerFaceColor','black');
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end

Инициализация параметров оптимизатора Adam.

trailingAvg = [];
trailingAvgSq = [];

Подготовьте данные проверки. Создайте однокодированную закодированную матрицу, в которой ненулевые записи соответствуют меткам каждого наблюдения.

numObservationsValidation = numel(documentsValidation);
TValidation = zeros(numClasses, numObservationsValidation, 'single');
for i = 1:numObservationsValidation
    [~,idx] = ismember(labelsValidation{i},classNames);
    TValidation(idx,i) = 1;
end

Тренируйте модель.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        
        % Read mini-batch of data and convert the labels to dummy
        % variables.
        documents = documentsTrain(idx);
        labels = labelsTrain(idx);
        
        % Convert documents to sequences.
        len = min(maxSequenceLength,max(doclength(documents)));
        X = doc2sequence(enc,documents, ...
            'PaddingValue',inputSize, ...
            'Length',len);
        X = cat(1,X{:});
        
        % Dummify labels.
        T = zeros(numClasses, miniBatchSize, 'single');
        for j = 1:miniBatchSize
            [~,idx2] = ismember(labels{j},classNames);
            T(idx2,j) = 1;
        end
        
        % Convert mini-batch of data to dlarray.
        dlX = dlarray(X,'BTC');
        
        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end
        
        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelGradients function.
        [gradients,loss,dlYPred] = dlfeval(@modelGradients, dlX, T, parameters);
        
        % Gradient clipping.
        gradients = dlupdate(@(g) thresholdL2Norm(g, gradientThreshold),gradients);
        
        % Update the network parameters using the Adam optimizer.
        [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
            trailingAvg,trailingAvgSq,iteration,learnRate,gradientDecayFactor,squaredGradientDecayFactor);
        
        % Display the training progress.
        if plots == "training-progress"
            subplot(2,1,1)
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            
            % Loss.
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            
            % Labeling F-score.
            YPred = extractdata(dlYPred) > labelThreshold;
            score = labelingFScore(YPred,T);
            addpoints(lineFScoreTrain,iteration,double(gather(score)))
            
            drawnow
            
            % Display validation metrics.
            if iteration == 1 || mod(iteration,validationFrequency) == 0
                dlYPredValidation = modelPredictions(parameters,enc,documentsValidation,miniBatchSize,maxSequenceLength);
                
                % Loss.
                lossValidation = crossentropy(dlYPredValidation,TValidation, ...
                    'TargetCategories','independent', ...
                    'DataFormat','CB');
                addpoints(lineLossValidation,iteration,double(gather(extractdata(lossValidation))))
                
                % Labeling F-score.
                YPredValidation = extractdata(dlYPredValidation) > labelThreshold;
                score = labelingFScore(YPredValidation,TValidation);
                addpoints(lineFScoreValidation,iteration,double(gather(score)))
                
                drawnow
            end
        end
    end
    
    % Shuffle data.
    idx = randperm(numObservationsTrain);
    documentsTrain = documentsTrain(idx);
    labelsTrain = labelsTrain(idx);
end

Тестовая модель

Для прогнозирования нового набора данных используйте modelPredictions функция, перечисленная в конце примера. modelPredictions функция принимает в качестве входных параметров модели, кодирование слов и массив маркированных документов и выводит предсказания модели, соответствующие заданному размеру мини-пакета и максимальной длине последовательности.

dlYPredValidation = modelPredictions(parameters,enc,documentsValidation,miniBatchSize,maxSequenceLength);

Чтобы преобразовать сетевые выходы в массив меток, найдите метки с баллами выше указанного порога.

YPredValidation = extractdata(dlYPredValidation) > labelThreshold;

Чтобы оценить производительность, вычислите F-балл маркировки с помощью labelingFScore функция, перечисленная в конце примера. Оценка F оценивает многоуровневую классификацию, фокусируясь на классификации по тексту с частичными совпадениями.

score = labelingFScore(YPredValidation,TValidation)
score = single
    0.5663

Просмотрите влияние порога маркировки на F-балл маркировки, попробовав диапазон значений для порога и сравнив результаты.

thr = linspace(0,1,10);
score = zeros(size(thr));
for i = 1:numel(thr)
    YPredValidationThr = extractdata(dlYPredValidation) >= thr(i);
    score(i) = labelingFScore(YPredValidationThr,TValidation);
end

figure
plot(thr,score)
xline(labelThreshold,'r--');
xlabel("Threshold")
ylabel("Labeling F-Score")
title("Effect of Labeling Threshold")

Визуализация прогнозов

Чтобы визуализировать правильные прогнозы классификатора, вычислите количество истинных положительных результатов. Истинно положительным является экземпляр классификатора, правильно предсказывающий конкретный класс для наблюдения.

Y = YPredValidation;
T = TValidation;

numTruePositives = sum(T & Y,2);

numObservationsPerClass = sum(T,2);
truePositiveRates = numTruePositives ./ numObservationsPerClass;

Визуализация количества истинных положительных результатов для каждого класса в гистограмме.

figure
[~,idx] = sort(truePositiveRates,'descend');
histogram('Categories',classNames(idx),'BinCounts',truePositiveRates(idx))
xlabel("Category")
ylabel("True Positive Rate")
title("True Positive Rates")

Визуализируйте экземпляры, где классификатор предсказывает неверно, показывая распределение истинных срабатываний, ложных срабатываний и ложных отрицаний. Ложноположительный - это экземпляр классификатора, присваивающий определенный неверный класс наблюдению. Ложноотрицательный - это экземпляр классификатора, которому не удалось назначить определенный правильный класс для наблюдения.

Создайте матрицу путаницы, показывающую истинные положительные, ложноположительные и ложноотрицательные счетчики:

  • Для каждого класса отображать истинные положительные подсчеты по диагонали.

  • Для каждой пары классов (i, j) отображает количество экземпляров ложноположительного для j, когда экземпляр также является ложноотрицательным для i.

То есть матрица путаницы с элементами, заданными:

TPFNij = {numStartPositives (i) , если i =  jnumPositionPositives ( j 'i - ложноотрицательный),  если i≠jTrue положительный, ложноотрицательный уровень

Вычислите ложные негативы и ложные положительные результаты.

falseNegatives = T & ~Y;
falsePositives = ~T & Y;

Вычислите внедиагональные элементы.

falseNegatives = permute(falseNegatives,[3 2 1]);
numConditionalFalsePositives = sum(falseNegatives & falsePositives, 2);
numConditionalFalsePositives = squeeze(numConditionalFalsePositives);

tpfnMatrix = numConditionalFalsePositives;

Установите для диагональных элементов истинные положительные значения.

idxDiagonal = 1:numClasses+1:numClasses^2;
tpfnMatrix(idxDiagonal) = numTruePositives;

Визуализация истинных положительных и ложных положительных счетчиков в матрице путаницы с помощью confusionchart и сортировать матрицу так, чтобы элементы на диагонали были в порядке убывания.

figure
cm = confusionchart(tpfnMatrix,classNames);
sortClasses(cm,"descending-diagonal");
title("True Positives, False Positives")

Для более подробного просмотра матрицы откройте этот пример в виде сценария в реальном времени и откройте рисунок в новом окне.

Предварительная обработка текстовой функции

preprocessText выполняет токенизацию и предварительную обработку входных текстовых данных с помощью следующих шагов:

  1. Маркировка текста с помощью tokenizedDocument функция. Извлеките математические уравнения как один маркер с помощью 'RegularExpressions' путем указания регулярного выражения "\$.*?\$", который фиксирует текст, появляющийся между двумя символами «$».

  2. Стереть пунктуацию с помощью erasePunctuation функция.

  3. Преобразование текста в нижний регистр с помощью lower функция.

  4. Удалите стоп-слова с помощью removeStopWords функция.

  5. Лемматизировать текст с помощью normalizeWords функции с помощью 'Style' параметр имеет значение 'lemma'.

function documents = preprocessText(textData)

% Tokenize the text.
regularExpressions = table;
regularExpressions.Pattern = "\$.*?\$";
regularExpressions.Type = "equation";

documents = tokenizedDocument(textData,'RegularExpressions',regularExpressions);

% Erase punctuation.
documents = erasePunctuation(documents);

% Convert to lowercase.
documents = lower(documents);

% Lemmatize.
documents = addPartOfSpeechDetails(documents);
documents = normalizeWords(documents,'Style','Lemma');

% Remove stop words.
documents = removeStopWords(documents);

% Remove short words.
documents = removeShortWords(documents,2);

end

Функция модели

Функция model принимает в качестве входных входные данные dlX и параметры модели parametersи возвращает прогнозы для меток.

function dlY = model(dlX,parameters)

% Embedding
weights = parameters.emb.Weights;
dlX = embedding(dlX, weights);

% GRU
inputWeights = parameters.gru.InputWeights;
recurrentWeights = parameters.gru.RecurrentWeights;
bias = parameters.gru.Bias;

numHiddenUnits = size(inputWeights,1)/3;
hiddenState = dlarray(zeros([numHiddenUnits 1]));

dlY = gru(dlX, hiddenState, inputWeights, recurrentWeights, bias,'DataFormat','CBT');

% Max pooling along time dimension
dlY = max(dlY,[],3);

% Fully connect
weights = parameters.fc.Weights;
bias = parameters.fc.Bias;
dlY = fullyconnect(dlY,weights,bias,'DataFormat','CB');

% Sigmoid
dlY = sigmoid(dlY);

end

Функция градиентов модели

modelGradients функция принимает в качестве входных данных мини-пакет входных данных dlX с соответствующими целями T содержит метки и возвращает градиенты потерь относительно обучаемых параметров, соответствующих потерь и сетевых выходов.

function [gradients,loss,dlYPred] = modelGradients(dlX,T,parameters)

dlYPred = model(dlX,parameters);

loss = crossentropy(dlYPred,T,'TargetCategories','independent','DataFormat','CB');

gradients = dlgradient(loss,parameters);

end

Функция прогнозирования модели

modelPredictions функция принимает за входные параметры модели, кодировку слов, массив маркированных документов, размер мини-пакета и максимальную длину последовательности и возвращает предсказания модели путем итерации по мини-пакетам указанного размера.

function dlYPred = modelPredictions(parameters,enc,documents,miniBatchSize,maxSequenceLength)

inputSize = enc.NumWords + 1;

numObservations = numel(documents);
numIterations = ceil(numObservations / miniBatchSize);

numFeatures = size(parameters.fc.Weights,1);
dlYPred = zeros(numFeatures,numObservations,'like',parameters.fc.Weights);

for i = 1:numIterations
    
    idx = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations);
    
    len = min(maxSequenceLength,max(doclength(documents(idx))));
    X = doc2sequence(enc,documents(idx), ...
        'PaddingValue',inputSize, ...
        'Length',len);
    X = cat(1,X{:});
    
    dlX = dlarray(X,'BTC');
    
    dlYPred(:,idx) = model(dlX,parameters);
end

end

Маркировка функции F-Score

Функция маркировки F-score [2] оценивает многоуровневую классификацию, фокусируясь на классификации по тексту с частичными совпадениями. Мера представляет собой нормализованную долю совпадающих меток от общего числа истинных и прогнозируемых меток, заданных

1N∑n=1N (2∑c=1CYncTnc∑c=1C (Ync + Tnc)), маркировка F-Score

где N и C соответствуют количеству наблюдений и классов соответственно, а Y и T соответствуют предсказаниям и целям соответственно.

function score = labelingFScore(Y,T)

numObservations = size(T,2);

scores = (2 * sum(Y .* T)) ./ sum(Y + T);
score = sum(scores) / numObservations;

end

Функция инициализации весов Glorot

initializeGlorot генерирует массив весов в соответствии с инициализацией Glorot.

function weights = initializeGlorot(numOut, numIn)

varWeights = sqrt( 6 / (numIn + numOut) );
weights = varWeights * (2 * rand([numOut, numIn], 'single') - 1);

end

Функция инициализации весов Гаусса

initializeGaussian веса выборок функций из гауссова распределения со средним значением 0 и стандартным отклонением 0,01.

function parameter = initializeGaussian(sz)

parameter = randn(sz,'single') .* 0.01;

end

Функция встраивания

embedding функция отображает числовые индексы в соответствующий вектор, задаваемый входными весами.

function Z = embedding(X, weights)
% Reshape inputs into a vector.
[N, T] = size(X, 2:3);
X = reshape(X, N*T, 1);

% Index into embedding matrix.
Z = weights(:, X);

% Reshape outputs by separating batch and sequence dimensions.
Z = reshape(Z, [], N, T);
end

L2 Функция отсечения градиента нормы

thresholdL2Norm функция масштабирует входные градиенты так, чтобы их значения L2 нормы были равны заданному порогу градиента, когда значение L2 нормы градиента обучаемого параметра больше заданного порога.

function gradients = thresholdL2Norm(gradients,gradientThreshold)

gradientNorm = sqrt(sum(gradients(:).^2));
if gradientNorm > gradientThreshold
    gradients = gradients * (gradientThreshold / gradientNorm);
end

end

Ссылки

  1. arXiv. «arXiv API». Доступ состоялся 15 января 2020 года. https://arxiv.org/help/api

  2. Соколова, Марина и Гай Лапальм. «Системный анализ показателей эффективности для классификационных задач». Обработка информации и управление 45, № 4 (2009): 427-437.

См. также

| | | (инструментарий для глубокого обучения) | (инструментарий для глубокого обучения) | (инструментарий для глубокого обучения) | (инструментарий для глубокого обучения) | (инструментарий для глубокого обучения) | (инструментарий для глубокого обучения)

Связанные темы