Мультимаркируйте Text Classification Using Deep Learning

В этом примере показано, как классифицировать текстовые данные, которые имеют несколько независимых меток.

Для задач классификации, где может быть несколько независимых меток для каждого наблюдения — например, наклеивает научную статью — можно обучить модель глубокого обучения предсказывать вероятности для каждого независимого класса. Чтобы позволить сети изучить цели классификации мультиметок, можно оптимизировать потерю каждого класса независимо с помощью бинарной потери перекрестной энтропии.

Этот пример задает модель глубокого обучения, которая классифицирует предметные области, учитывая краткие обзоры математических бумаг, собранных с помощью arXiv API [1]. Модель состоит из встраивания слова и ГРУ, макс. объединяя операцию, полностью соединенные, и сигмоидальные операции.

Чтобы измерить уровень классификации мультиметок, можно использовать F-счет маркировки [2]. F-счет маркировки оценивает классификацию мультиметок путем фокусировки на классификации на текст с частичными соответствиями. Мерой является нормированная пропорция соответствия с метками против общего количества истины и предсказанных меток.

Этот пример задает следующую модель:

  • Слово, встраивающее, который сопоставляет последовательность слов к последовательности числовых векторов.

  • Операция ГРУ, которая изучает зависимости между векторами встраивания.

  • Макс. операция объединения, которая уменьшает последовательность характеристических векторов к одному характеристическому вектору.

  • Полносвязный слой, который сопоставляет функции с двоичными выходами.

  • Сигмоидальная операция для изучения бинарной потери перекрестной энтропии между выходными параметрами и целевыми метками.

Эта схема показывает часть текстового распространения через архитектуру модели и вывода вектора вероятностей. Вероятности независимы, таким образом, они не должны суммировать одному.

Импортируйте текстовые данные

Импортируйте набор кратких обзоров и подписей категорий из математических бумаг с помощью arXiv API. Задайте количество записей, чтобы импортировать использование importSize переменная. Обратите внимание на то, что arXiv API является уровнем, ограниченным запросом 1 000 статей за один раз, и требует ожидания между запросами.

importSize = 50000;

Импортируйте первый набор записей.

url = "https://export.arxiv.org/oai2?verb=ListRecords" + ...
    "&set=math" + ...
    "&metadataPrefix=arXiv";
options = weboptions('Timeout',160);
code = webread(url,options);

Проанализируйте возвращенное содержание XML и создайте массив htmlTree объекты, содержащие информацию записи.

tree = htmlTree(code);
subtrees = findElement(tree,"record");
numel(subtrees)

Итеративно импортируйте больше фрагментов записей, пока необходимое количество не достигнуто, или больше нет записей. Чтобы продолжить импортировать записи из того, где вы кончили, используйте resumptionToken припишите от предыдущего результата. Чтобы придерживаться ограничений скорости, наложенных arXiv API, добавьте задержку 20 секунд перед каждым запросом с помощью pause функция.

while numel(subtrees) < importSize
    subtreeResumption = findElement(tree,"resumptionToken");
    
    if isempty(subtreeResumption)
        break
    end
    
    resumptionToken = extractHTMLText(subtreeResumption);
    
    url = "https://export.arxiv.org/oai2?verb=ListRecords" + ...
        "&resumptionToken=" + resumptionToken;
    
    pause(20)
    code = webread(url,options);
    
    tree = htmlTree(code);
    
    subtrees = [subtrees; findElement(tree,"record")];
end

Извлеките и предварительно обработайте текстовые данные

Извлеките краткие обзоры и метки от проанализированных деревьев HTML.

Найдите "<abstract>" и "<categories>" элементы с помощью findElement функция.

subtreeAbstract = htmlTree("");
subtreeCategory = htmlTree("");

for i = 1:numel(subtrees)
    subtreeAbstract(i) = findElement(subtrees(i),"abstract");
    subtreeCategory(i) = findElement(subtrees(i),"categories");
end

Извлеките текстовые данные из поддеревьев, содержащих краткие обзоры с помощью extractHTMLText функция.

textData = extractHTMLText(subtreeAbstract);

Маркируйте и предварительно обработайте текстовые данные с помощью preprocessText функция, перечисленная в конце примера.

documentsAll = preprocessText(textData);
documentsAll(1:5)
ans = 
  5×1 tokenizedDocument:

    72 tokens: describe new algorithm $(k,\ell)$ pebble game color obtain characterization family $(k,\ell)$ sparse graph algorithmic solution family problem concern tree decomposition graph special instance sparse graph appear rigidity theory receive increase attention recent year particular colored pebble generalize strengthen previous result lee streinu give new proof tuttenashwilliams characterization arboricity present new decomposition certify sparsity base $(k,\ell)$ pebble game color work expose connection pebble game algorithm previous sparse graph algorithm gabow gabow westermann hendrickson
    22 tokens: show determinant stirling cycle number count unlabeled acyclic singlesource automaton proof involve bijection automaton certain marked lattice path signreversing involution evaluate determinant
    18 tokens: "paper" "show" "compute" "$\lambda_{\alpha}$" "norm" "$\alpha\ge 0$" "dyadic" "grid" "result" "consequence" "description" "hardy" "space" "$h^p(r^n)$" "term" "dyadic" "special" "atom"
    62 tokens: partial cube isometric subgraphs hypercubes structure graph define mean semicubes djokovi winklers relation play important role theory partial cube structure employ paper characterize bipartite graph partial cube arbitrary dimension new characterization establish new proof know result give operation cartesian product pasting expansion contraction process utilized paper construct new partial cube old particular isometric lattice dimension finite partial cube obtain mean operation calculate
    29 tokens: paper present algorithm compute hecke eigensystems hilbertsiegel cusp form real quadratic field narrow class number give illustrative example quadratic field $\q(\sqrt{5})$ example identify hilbertsiegel eigenforms possible lift hilbert eigenforms

Извлеките метки из поддеревьев, содержащих метки.

strLabels = extractHTMLText(subtreeCategory);
labelsAll = arrayfun(@split,strLabels,'UniformOutput',false);

Удалите метки, которые не принадлежат "math" набор.

for i = 1:numel(labelsAll)
    labelsAll{i} = labelsAll{i}(startsWith(labelsAll{i},"math."));
end

Визуализируйте некоторые классы, одним словом, облако. Найдите документы, соответствующие следующему:

  • Краткие обзоры помечены с "Комбинаторикой" и не теговые с "Statistics Theory"

  • Краткие обзоры, помеченные с "Теорией Статистики" и не теговый с "Combinatorics"

  • Краткие обзоры помечены с обоими "Combinatorics" и "Statistics Theory"

Найдите индексы документа для каждой из групп, использующих ismember функция.

idxCO = cellfun(@(lbls) ismember("math.CO",lbls) && ~ismember("math.ST",lbls),labelsAll);
idxST = cellfun(@(lbls) ismember("math.ST",lbls) && ~ismember("math.CO",lbls),labelsAll);
idxCOST = cellfun(@(lbls) ismember("math.CO",lbls) && ismember("math.ST",lbls),labelsAll);

Визуализируйте документы для каждой группы, одним словом, облако.

figure
subplot(1,3,1)
wordcloud(documentsAll(idxCO));
title("Combinatorics")

subplot(1,3,2)
wordcloud(documentsAll(idxST));
title("Statistics Theory")

subplot(1,3,3)
wordcloud(documentsAll(idxCOST));
title("Both")

Просмотрите количество классов.

classNames = unique(cat(1,labelsAll{:}));
numClasses = numel(classNames)
numClasses = 32

Визуализируйте количество меток на документ с помощью гистограммы.

labelCounts = cellfun(@numel, labelsAll);
figure
histogram(labelCounts)
xlabel("Number of Labels")
ylabel("Frequency")
title("Label Counts")

Подготовьте текстовые данные к глубокому обучению

Разделите данные в разделы обучения и валидации с помощью cvpartition функция. Протяните 10% данных для валидации путем установки 'HoldOut' опция к 0,1.

cvp = cvpartition(numel(documentsAll),'HoldOut',0.1);
documentsTrain = documentsAll(training(cvp));
documentsValidation = documentsAll(test(cvp));

labelsTrain = labelsAll(training(cvp));
labelsValidation = labelsAll(test(cvp));

Создайте объект кодирования слова, который кодирует учебные материалы как последовательности словарей. Задайте словарь этих 5 000 слов путем установки 'Order' опция к 'frequency', и 'MaxNumWords' опция к 5 000.

enc = wordEncoding(documentsTrain,'Order','frequency','MaxNumWords',5000)
enc = 
  wordEncoding with properties:

      NumWords: 5000
    Vocabulary: [1×5000 string]

Чтобы улучшить обучение, используйте следующие методы:

  1. Когда обучение, усеченное документы длине, которая уменьшает объем дополнения используемого и не делает, действительно отбрасывает слишком много данных.

  2. Обучайтесь в течение одной эпохи с документами, отсортированными по длине в порядке возрастания, затем переставьте данные каждая эпоха. Этот метод известен sortagrad.

Чтобы выбрать длину последовательности для усечения, визуализируйте длины документа в гистограмме и выберите значение, которое собирает большинство данных.

documentLengths = doclength(documentsTrain);

figure
histogram(documentLengths)
xlabel("Document Length")
ylabel("Frequency")
title("Document Lengths")

Большинство учебных материалов имеет меньше чем 175 лексем. Используйте 175 лексем в качестве целевой длины для усечения и дополнения.

maxSequenceLength = 175;

Чтобы использовать sortagrad метод, отсортируйте документы по длине в порядке возрастания.

[~,idx] = sort(documentLengths);
documentsTrain = documentsTrain(idx);
labelsTrain = labelsTrain(idx);

Задайте и инициализируйте параметры модели

Задайте параметры для каждой из операций и включайте их в struct. Используйте формат parameters.OperationName.ParameterName, где parameters struct, OperationName имя операции (например, "fc"), и ParameterName имя параметра (например, "Weights").

Создайте struct parameters содержа параметры модели. Инициализируйте смещение нулями. Используйте следующие инициализаторы веса в операциях:

  • Для встраивания инициализируйте веса случайными нормальными значениями.

  • Для операции ГРУ инициализируйте веса с помощью initializeGlorot функция, перечисленная в конце примера.

  • Для полностью операции connect, инициализируйте веса с помощью initializeGaussian функция, перечисленная в конце примера.

embeddingDimension = 300;
numHiddenUnits = 250;
inputSize = enc.NumWords + 1;

parameters = struct;
parameters.emb.Weights = dlarray(randn([embeddingDimension inputSize]));

parameters.gru.InputWeights = dlarray(initializeGlorot(3*numHiddenUnits,embeddingDimension));
parameters.gru.RecurrentWeights = dlarray(initializeGlorot(3*numHiddenUnits,numHiddenUnits));
parameters.gru.Bias = dlarray(zeros(3*numHiddenUnits,1,'single'));

parameters.fc.Weights = dlarray(initializeGaussian([numClasses,numHiddenUnits]));
parameters.fc.Bias = dlarray(zeros(numClasses,1,'single'));

Просмотрите parameters struct ().

parameters
parameters = struct with fields:
    emb: [1×1 struct]
    gru: [1×1 struct]
     fc: [1×1 struct]

Просмотрите параметры для операции ГРУ.

parameters.gru
ans = struct with fields:
        InputWeights: [750×300 dlarray]
    RecurrentWeights: [750×250 dlarray]
                Bias: [750×1 dlarray]

Функция модели Define

Создайте функциональный model, перечисленный в конце примера, который вычисляет выходные параметры модели глубокого обучения, описанной ранее. Функциональный model берет в качестве входа входные данные dlX и параметры модели parameters. Сетевые выходные параметры предсказания для меток.

Функция градиентов модели Define

Создайте функциональный modelGradients, перечисленный в конце примера, который берет в качестве входа мини-пакет входных данных dlX и соответствующие цели T содержание меток, и возвращает градиенты потери относительно настраиваемых параметров, соответствующей потери и сетевых выходных параметров.

Задайте опции обучения

Обучайтесь в течение 5 эпох с мини-пакетным размером 256.

numEpochs = 5;
miniBatchSize = 256;

Обучите использование оптимизатора Адама, со скоростью обучения 0,01, и задайте затухание градиента и факторы затухания градиента в квадрате 0,5 и 0.999, соответственно.

learnRate = 0.01;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;

Отсеките градиенты с порогом 1 использования L2 усечение градиента нормы.

gradientThreshold = 1;

Визуализируйте процесс обучения в графике.

plots = "training-progress";

Чтобы преобразовать вектор вероятностей к меткам, используйте метки с вероятностями выше, чем заданный порог. Задайте порог метки 0,5.

labelThreshold = 0.5;

Подтвердите сеть каждая эпоха.

numObservationsTrain = numel(documentsTrain);
numIterationsPerEpoch = floor(numObservationsTrain/miniBatchSize);
validationFrequency = numIterationsPerEpoch;

Обучайтесь на графическом процессоре, если вы доступны. Это требует Parallel Computing Toolbox™. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA®, графический процессор с вычисляет возможность 3.0 или выше.

executionEnvironment = "auto";

Обучите модель

Обучите модель с помощью пользовательского учебного цикла.

В течение каждой эпохи, цикла по мини-пакетам данных. В конце каждой эпохи переставьте данные. В конце каждой итерации обновите график процесса обучения.

Для каждого мини-пакета:

  • Преобразуйте документы последовательностям словарей и преобразуйте метки в фиктивные переменные.

  • Преобразуйте последовательности в dlarray объекты с базовым одним типом и указывают, что размерность маркирует 'BCT' (пакет, канал, время).

  • Для обучения графического процессора преобразуйте в gpuArray объекты.

  • Оцените градиенты модели и потерю с помощью dlfeval и modelGradients функция.

  • Отсеките градиенты.

  • Обновите сетевые параметры с помощью adamupdate функция.

  • При необходимости подтвердите сеть с помощью modelPredictions функция, перечисленная в конце примера.

  • Обновите учебный график.

Инициализируйте график процесса обучения.

if plots == "training-progress"
    figure
    
    % Labeling F-Score.
    subplot(2,1,1)
    lineFScoreTrain = animatedline('Color',[0 0.447 0.741]);
    lineFScoreValidation = animatedline( ...
        'LineStyle','--', ...
        'Marker','o', ...
        'MarkerFaceColor','black');
    ylim([0 1])
    xlabel("Iteration")
    ylabel("Labeling F-Score")
    grid on
    
    % Loss.
    subplot(2,1,2)
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    lineLossValidation = animatedline( ...
        'LineStyle','--', ...
        'Marker','o', ...
        'MarkerFaceColor','black');
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end

Инициализируйте параметры для оптимизатора Адама.

trailingAvg = [];
trailingAvgSq = [];

Подготовьте данные о валидации. Создайте одногорячую закодированную матрицу, где ненулевые записи соответствуют меткам каждого наблюдения.

numObservationsValidation = numel(documentsValidation);
TValidation = zeros(numClasses, numObservationsValidation, 'single');
for i = 1:numObservationsValidation
    [~,idx] = ismember(labelsValidation{i},classNames);
    TValidation(idx,i) = 1;
end

Обучите модель.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        
        % Read mini-batch of data and convert the labels to dummy
        % variables.
        documents = documentsTrain(idx);
        labels = labelsTrain(idx);
        
        % Convert documents to sequences.
        len = min(maxSequenceLength,max(doclength(documents)));
        X = doc2sequence(enc,documents, ...
            'PaddingValue',inputSize, ...
            'Length',len);
        X = cat(1,X{:});
        
        % Dummify labels.
        T = zeros(numClasses, miniBatchSize, 'single');
        for j = 1:miniBatchSize
            [~,idx2] = ismember(labels{j},classNames);
            T(idx2,j) = 1;
        end
        
        % Convert mini-batch of data to dlarray.
        dlX = dlarray(X,'BTC');
        
        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end
        
        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelGradients function.
        [gradients,loss,dlYPred] = dlfeval(@modelGradients, dlX, T, parameters);
        
        % Gradient clipping.
        gradients = dlupdate(@(g) thresholdL2Norm(g, gradientThreshold),gradients);
        
        % Update the network parameters using the Adam optimizer.
        [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
            trailingAvg,trailingAvgSq,iteration,learnRate,gradientDecayFactor,squaredGradientDecayFactor);
        
        % Display the training progress.
        if plots == "training-progress"
            subplot(2,1,1)
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            
            % Loss.
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            
            % Labeling F-score.
            YPred = extractdata(dlYPred) > labelThreshold;
            score = labelingFScore(YPred,T);
            addpoints(lineFScoreTrain,iteration,double(gather(score)))
            
            drawnow
            
            % Display validation metrics.
            if iteration == 1 || mod(iteration,validationFrequency) == 0
                dlYPredValidation = modelPredictions(parameters,enc,documentsValidation,miniBatchSize,maxSequenceLength);
                
                % Loss.
                lossValidation = crossentropy(dlYPredValidation,TValidation, ...
                    'TargetCategories','independent', ...
                    'DataFormat','CB');
                addpoints(lineLossValidation,iteration,double(gather(extractdata(lossValidation))))
                
                % Labeling F-score.
                YPredValidation = extractdata(dlYPredValidation) > labelThreshold;
                score = labelingFScore(YPredValidation,TValidation);
                addpoints(lineFScoreValidation,iteration,double(gather(score)))
                
                drawnow
            end
        end
    end
    
    % Shuffle data.
    idx = randperm(numObservationsTrain);
    documentsTrain = documentsTrain(idx);
    labelsTrain = labelsTrain(idx);
end

Тестовая модель

Чтобы сделать предсказания на новом наборе данных, используйте modelPredictions функция, перечисленная в конце примера. modelPredictions функционируйте берет в качестве входа параметры модели, кодирование слова и массив маркируемых документов, и выводит предсказания модели, соответствующие заданному мини-пакетному размеру и максимальной длине последовательности.

dlYPredValidation = modelPredictions(parameters,enc,documentsValidation,miniBatchSize,maxSequenceLength);

Чтобы преобразовать сетевые выходные параметры в массив меток, найдите метки с баллами выше, чем заданный порог метки.

YPredValidation = extractdata(dlYPredValidation) > labelThreshold;

Чтобы оценить производительность, вычислите F-счет маркировки с помощью labelingFScore функция, перечисленная в конце примера. F-счет маркировки оценивает классификацию мультиметок путем фокусировки на классификации на текст с частичными соответствиями.

score = labelingFScore(YPredValidation,TValidation)
score = single
    0.5852

Просмотрите эффект порога маркировки на F-счете маркировки путем попытки области значений значений для порога и сравнения результатов.

thr = linspace(0,1,10);
score = zeros(size(thr));
for i = 1:numel(thr)
    YPredValidationThr = extractdata(dlYPredValidation) >= thr(i);
    score(i) = labelingFScore(YPredValidationThr,TValidation);
end

figure
plot(thr,score)
xline(labelThreshold,'r--');
xlabel("Threshold")
ylabel("Labeling F-Score")
title("Effect of Labeling Threshold")

Визуализируйте предсказания

Чтобы визуализировать правильные предсказания классификатора, вычислите количества истинных положительных сторон. Положительная истина является экземпляром классификатора, правильно предсказывая конкретный класс для наблюдения.

Y = YPredValidation;
T = TValidation;

numTruePositives = sum(T & Y,2);

numObservationsPerClass = sum(T,2);
truePositiveRates = numTruePositives ./ numObservationsPerClass;

Визуализируйте количества истинных положительных сторон для каждого класса в гистограмме.

figure
[~,idx] = sort(truePositiveRates,'descend');
histogram('Categories',classNames(idx),'BinCounts',truePositiveRates(idx))
xlabel("Category")
ylabel("True Positive Rate")
title("True Positive Rates")

Визуализируйте экземпляры, где классификатор предсказывает неправильно путем показа распределения истинных положительных сторон, ложных положительных сторон и ложных отрицательных сторон. Положительная ложь является экземпляром классификатора, присваивающего конкретный неправильный класс наблюдению. Ложное отрицание является экземпляром классификатора, не удающегося присваивать конкретный правильный класс наблюдению.

Создайте матрицу беспорядка показ истинных положительных, ложных положительных, и ложных отрицательных количеств:

  • Для каждого класса отобразите истинные положительные количества на диагонали.

  • Для каждой пары классов (i, j), отображают количество экземпляров лжи, положительной для j, когда экземпляр является также ложным отрицанием поскольку i.

Таким образом, матрица беспорядка с элементами, данными:

TPFNij={numTruePositives(i),if i=jnumFalsePositives(j|i   ложное отрицание),if ijИстинные положительные, ложные отрицательные уровни

Вычислите ложные отрицательные стороны и ложные положительные стороны.

falseNegatives = T & ~Y;
falsePositives = ~T & Y;

Вычислите недиагональные элементы.

falseNegatives = permute(falseNegatives,[3 2 1]);
numConditionalFalsePositives = sum(falseNegatives & falsePositives, 2);
numConditionalFalsePositives = squeeze(numConditionalFalsePositives);

tpfnMatrix = numConditionalFalsePositives;

Установите диагональные элементы на истинные положительные количества.

idxDiagonal = 1:numClasses+1:numClasses^2;
tpfnMatrix(idxDiagonal) = numTruePositives;

Визуализируйте истинные положительные и ложные положительные количества в матрице беспорядка использование confusionchart функционируйте и отсортируйте матрицу, таким образом, что элементы на диагонали в порядке убывания.

figure
cm = confusionchart(tpfnMatrix,classNames);
sortClasses(cm,"descending-diagonal");
title("True Positives, False Positives")

Чтобы просмотреть матрицу более подробно, откройте этот пример как live скрипт и откройте фигуру в новом окне.

Предварительно обработайте текстовую функцию

preprocessText функция маркирует и предварительно обрабатывает входные текстовые данные с помощью следующих шагов:

  1. Маркируйте текст с помощью tokenizedDocument функция. Извлеките математические уравнения как одну лексему с помощью 'RegularExpressions' опция путем определения регулярного выражения "\$.*?\$", который получает текст, появляющийся между двумя символами "$".

  2. Сотрите пунктуацию с помощью erasePunctuation функция.

  3. Преобразуйте текст в нижний регистр с помощью lower функция.

  4. Удалите слова остановки с помощью removeStopWords функция.

  5. Lemmatize текст с помощью normalizeWords функция с 'Style' набор опции к 'lemma'.

function documents = preprocessText(textData)

% Tokenize the text.
regularExpressions = table;
regularExpressions.Pattern = "\$.*?\$";
regularExpressions.Type = "equation";

documents = tokenizedDocument(textData,'RegularExpressions',regularExpressions);

% Erase punctuation.
documents = erasePunctuation(documents);

% Convert to lowercase.
documents = lower(documents);

% Lemmatize.
documents = addPartOfSpeechDetails(documents);
documents = normalizeWords(documents,'Style','Lemma');

% Remove stop words.
documents = removeStopWords(documents);

% Remove short words.
documents = removeShortWords(documents,2);

end

Функция модели

Функциональный model берет в качестве входа входные данные dlX и параметры модели parameters, и возвращает предсказания для меток.

function dlY = model(dlX,parameters)

% Embedding
weights = parameters.emb.Weights;
dlX = embedding(dlX, weights);

% GRU
inputWeights = parameters.gru.InputWeights;
recurrentWeights = parameters.gru.RecurrentWeights;
bias = parameters.gru.Bias;

numHiddenUnits = size(inputWeights,1)/3;
hiddenState = dlarray(zeros([numHiddenUnits 1]));

dlY = gru(dlX, hiddenState, inputWeights, recurrentWeights, bias,'DataFormat','CBT');

% Max pooling along time dimension
dlY = max(dlY,[],3);

% Fully connect
weights = parameters.fc.Weights;
bias = parameters.fc.Bias;
dlY = fullyconnect(dlY,weights,bias,'DataFormat','CB');

% Sigmoid
dlY = sigmoid(dlY);

end

Функция градиентов модели

modelGradients функционируйте берет в качестве входа мини-пакет входных данных dlX с соответствующими целями T содержание меток и возвращает градиенты потери относительно настраиваемых параметров, соответствующей потери и сетевых выходных параметров.

function [gradients,loss,dlYPred] = modelGradients(dlX,T,parameters)

dlYPred = model(dlX,parameters);

loss = crossentropy(dlYPred,T,'TargetCategories','independent','DataFormat','CB');

gradients = dlgradient(loss,parameters);

end

Функция предсказаний модели

modelPredictions функционируйте берет в качестве входа параметры модели, кодирование слова, массив маркируемых документов, мини-пакетного размера и максимальной длины последовательности, и возвращает предсказания модели путем итерации по мини-пакетам заданного размера.

function dlYPred = modelPredictions(parameters,enc,documents,miniBatchSize,maxSequenceLength)

inputSize = enc.NumWords + 1;

numObservations = numel(documents);
numIterations = ceil(numObservations / miniBatchSize);

numFeatures = size(parameters.fc.Weights,1);
dlYPred = zeros(numFeatures,numObservations,'like',parameters.fc.Weights);

for i = 1:numIterations
    
    idx = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations);
    
    len = min(maxSequenceLength,max(doclength(documents(idx))));
    X = doc2sequence(enc,documents(idx), ...
        'PaddingValue',inputSize, ...
        'Length',len);
    X = cat(1,X{:});
    
    dlX = dlarray(X,'BTC');
    
    dlYPred(:,idx) = model(dlX,parameters);
end

end

Маркировка функции F-счета

Функция F-счета маркировки [2] оценивает классификацию мультиметок путем фокусировки на классификации на текст с частичными соответствиями. Мерой является нормированная пропорция соответствия с метками против общего количества истинных и предсказанных меток, данных

1Nn=1N(2c=1CYncTncc=1C(Ync+Tnc)),Маркировка F-счета

где N и C соответствуют количеству наблюдений и классов, соответственно, и Y и T соответствуют предсказаниям и целям, соответственно.

function score = labelingFScore(Y,T)

numObservations = size(T,2);

scores = (2 * sum(Y .* T)) ./ sum(Y + T);
score = sum(scores) / numObservations;

end

Функция инициализации весов Glorot

initializeGlorot функция генерирует массив весов согласно инициализации Glorot.

function weights = initializeGlorot(numOut, numIn)

varWeights = sqrt( 6 / (numIn + numOut) );
weights = varWeights * (2 * rand([numOut, numIn], 'single') - 1);

end

Гауссова функция инициализации весов

initializeGaussian функциональные демонстрационные веса от Распределения Гаусса со средним значением 0 и стандартным отклонением 0.01.

function parameter = initializeGaussian(sz)

parameter = randn(sz,'single') .* 0.01;

end

Встраивание функции

embedding функционируйте сопоставляет числовые индексы с соответствующим вектором, данным входными весами.

function Z = embedding(X, weights)
% Reshape inputs into a vector.
[N, T] = size(X, 2:3);
X = reshape(X, N*T, 1);

% Index into embedding matrix.
Z = weights(:, X);

% Reshape outputs by separating batch and sequence dimensions.
Z = reshape(Z, [], N, T);
end

L2 Функция усечения градиента нормы

thresholdL2Norm функционируйте масштабирует входные градиенты так, чтобы их L2 значения нормы равняются заданному порогу градиента когда L2 значение нормы градиента настраиваемого параметра больше, чем заданный порог.

function gradients = thresholdL2Norm(gradients,gradientThreshold)

gradientNorm = sqrt(sum(gradients(:).^2));
if gradientNorm > gradientThreshold
    gradients = gradients * (gradientThreshold / gradientNorm);
end

end

Ссылки

  1. arXiv. "arXiv API". Полученный доступ 15 января 2020. https://arxiv.org/help/api

  2. Соколова, Марина, и Гай Лэпэйлм. "Анализ Sytematic Критериев качества работы для Задач Классификации". Обработка информации & управление 45, № 4 (2009): 427–437.

Смотрите также

| | | | | | | |

Похожие темы