В этом примере показано, как классифицировать текстовые данные, которые имеют несколько независимых меток.
Для задач классификации, где может быть несколько независимых меток для каждого наблюдения — например, наклеивает научную статью — можно обучить модель глубокого обучения предсказывать вероятности для каждого независимого класса. Чтобы позволить сети изучить цели классификации мультиметок, можно оптимизировать потерю каждого класса независимо с помощью бинарной потери перекрестной энтропии.
Этот пример задает модель глубокого обучения, которая классифицирует предметные области, учитывая краткие обзоры математических бумаг, собранных с помощью arXiv API [1]. Модель состоит из встраивания слова и ГРУ, макс. объединяя операцию, полностью соединенные, и сигмоидальные операции.
Чтобы измерить уровень классификации мультиметок, можно использовать F-счет маркировки [2]. F-счет маркировки оценивает классификацию мультиметок путем фокусировки на классификации на текст с частичными соответствиями. Мерой является нормированная пропорция соответствия с метками против общего количества истины и предсказанных меток.
Этот пример задает следующую модель:
Слово, встраивающее, который сопоставляет последовательность слов к последовательности числовых векторов.
Операция ГРУ, которая изучает зависимости между векторами встраивания.
Макс. операция объединения, которая уменьшает последовательность характеристических векторов к одному характеристическому вектору.
Полносвязный слой, который сопоставляет функции с двоичными выходами.
Сигмоидальная операция для изучения бинарной потери перекрестной энтропии между выходными параметрами и целевыми метками.
Эта схема показывает часть текстового распространения через архитектуру модели и вывода вектора вероятностей. Вероятности независимы, таким образом, они не должны суммировать одному.
Импортируйте набор кратких обзоров и подписей категорий из математических бумаг с помощью arXiv API. Задайте количество записей, чтобы импортировать использование importSize
переменная. Обратите внимание на то, что arXiv API является уровнем, ограниченным запросом 1 000 статей за один раз, и требует ожидания между запросами.
importSize = 50000;
Импортируйте первый набор записей.
url = "https://export.arxiv.org/oai2?verb=ListRecords" + ... "&set=math" + ... "&metadataPrefix=arXiv"; options = weboptions('Timeout',160); code = webread(url,options);
Проанализируйте возвращенное содержание XML и создайте массив htmlTree
объекты, содержащие информацию записи.
tree = htmlTree(code);
subtrees = findElement(tree,"record");
numel(subtrees)
Итеративно импортируйте больше фрагментов записей, пока необходимое количество не достигнуто, или больше нет записей. Чтобы продолжить импортировать записи из того, где вы кончили, используйте resumptionToken
припишите от предыдущего результата. Чтобы придерживаться ограничений скорости, наложенных arXiv API, добавьте задержку 20 секунд перед каждым запросом с помощью pause
функция.
while numel(subtrees) < importSize subtreeResumption = findElement(tree,"resumptionToken"); if isempty(subtreeResumption) break end resumptionToken = extractHTMLText(subtreeResumption); url = "https://export.arxiv.org/oai2?verb=ListRecords" + ... "&resumptionToken=" + resumptionToken; pause(20) code = webread(url,options); tree = htmlTree(code); subtrees = [subtrees; findElement(tree,"record")]; end
Извлеките краткие обзоры и метки от проанализированных деревьев HTML.
Найдите "<abstract>"
и "<categories>"
элементы с помощью findElement
функция.
subtreeAbstract = htmlTree(""); subtreeCategory = htmlTree(""); for i = 1:numel(subtrees) subtreeAbstract(i) = findElement(subtrees(i),"abstract"); subtreeCategory(i) = findElement(subtrees(i),"categories"); end
Извлеките текстовые данные из поддеревьев, содержащих краткие обзоры с помощью extractHTMLText
функция.
textData = extractHTMLText(subtreeAbstract);
Маркируйте и предварительно обработайте текстовые данные с помощью preprocessText
функция, перечисленная в конце примера.
documentsAll = preprocessText(textData); documentsAll(1:5)
ans = 5×1 tokenizedDocument: 72 tokens: describe new algorithm $(k,\ell)$ pebble game color obtain characterization family $(k,\ell)$ sparse graph algorithmic solution family problem concern tree decomposition graph special instance sparse graph appear rigidity theory receive increase attention recent year particular colored pebble generalize strengthen previous result lee streinu give new proof tuttenashwilliams characterization arboricity present new decomposition certify sparsity base $(k,\ell)$ pebble game color work expose connection pebble game algorithm previous sparse graph algorithm gabow gabow westermann hendrickson 22 tokens: show determinant stirling cycle number count unlabeled acyclic singlesource automaton proof involve bijection automaton certain marked lattice path signreversing involution evaluate determinant 18 tokens: "paper" "show" "compute" "$\lambda_{\alpha}$" "norm" "$\alpha\ge 0$" "dyadic" "grid" "result" "consequence" "description" "hardy" "space" "$h^p(r^n)$" "term" "dyadic" "special" "atom" 62 tokens: partial cube isometric subgraphs hypercubes structure graph define mean semicubes djokovi winklers relation play important role theory partial cube structure employ paper characterize bipartite graph partial cube arbitrary dimension new characterization establish new proof know result give operation cartesian product pasting expansion contraction process utilized paper construct new partial cube old particular isometric lattice dimension finite partial cube obtain mean operation calculate 29 tokens: paper present algorithm compute hecke eigensystems hilbertsiegel cusp form real quadratic field narrow class number give illustrative example quadratic field $\q(\sqrt{5})$ example identify hilbertsiegel eigenforms possible lift hilbert eigenforms
Извлеките метки из поддеревьев, содержащих метки.
strLabels = extractHTMLText(subtreeCategory);
labelsAll = arrayfun(@split,strLabels,'UniformOutput',false);
Удалите метки, которые не принадлежат "math"
набор.
for i = 1:numel(labelsAll) labelsAll{i} = labelsAll{i}(startsWith(labelsAll{i},"math.")); end
Визуализируйте некоторые классы, одним словом, облако. Найдите документы, соответствующие следующему:
Краткие обзоры помечены с "Комбинаторикой" и не теговые с "Statistics Theory"
Краткие обзоры, помеченные с "Теорией Статистики" и не теговый с "Combinatorics"
Краткие обзоры помечены с обоими "Combinatorics"
и "Statistics Theory"
Найдите индексы документа для каждой из групп, использующих ismember
функция.
idxCO = cellfun(@(lbls) ismember("math.CO",lbls) && ~ismember("math.ST",lbls),labelsAll); idxST = cellfun(@(lbls) ismember("math.ST",lbls) && ~ismember("math.CO",lbls),labelsAll); idxCOST = cellfun(@(lbls) ismember("math.CO",lbls) && ismember("math.ST",lbls),labelsAll);
Визуализируйте документы для каждой группы, одним словом, облако.
figure subplot(1,3,1) wordcloud(documentsAll(idxCO)); title("Combinatorics") subplot(1,3,2) wordcloud(documentsAll(idxST)); title("Statistics Theory") subplot(1,3,3) wordcloud(documentsAll(idxCOST)); title("Both")
Просмотрите количество классов.
classNames = unique(cat(1,labelsAll{:})); numClasses = numel(classNames)
numClasses = 32
Визуализируйте количество меток на документ с помощью гистограммы.
labelCounts = cellfun(@numel, labelsAll); figure histogram(labelCounts) xlabel("Number of Labels") ylabel("Frequency") title("Label Counts")
Разделите данные в разделы обучения и валидации с помощью cvpartition
функция. Протяните 10% данных для валидации путем установки 'HoldOut'
опция к 0,1.
cvp = cvpartition(numel(documentsAll),'HoldOut',0.1);
documentsTrain = documentsAll(training(cvp));
documentsValidation = documentsAll(test(cvp));
labelsTrain = labelsAll(training(cvp));
labelsValidation = labelsAll(test(cvp));
Создайте объект кодирования слова, который кодирует учебные материалы как последовательности словарей. Задайте словарь этих 5 000 слов путем установки 'Order'
опция к 'frequency'
, и 'MaxNumWords'
опция к 5 000.
enc = wordEncoding(documentsTrain,'Order','frequency','MaxNumWords',5000)
enc = wordEncoding with properties: NumWords: 5000 Vocabulary: [1×5000 string]
Чтобы улучшить обучение, используйте следующие методы:
Когда обучение, усеченное документы длине, которая уменьшает объем дополнения используемого и не делает, действительно отбрасывает слишком много данных.
Обучайтесь в течение одной эпохи с документами, отсортированными по длине в порядке возрастания, затем переставьте данные каждая эпоха. Этот метод известен sortagrad.
Чтобы выбрать длину последовательности для усечения, визуализируйте длины документа в гистограмме и выберите значение, которое собирает большинство данных.
documentLengths = doclength(documentsTrain); figure histogram(documentLengths) xlabel("Document Length") ylabel("Frequency") title("Document Lengths")
Большинство учебных материалов имеет меньше чем 175 лексем. Используйте 175 лексем в качестве целевой длины для усечения и дополнения.
maxSequenceLength = 175;
Чтобы использовать sortagrad метод, отсортируйте документы по длине в порядке возрастания.
[~,idx] = sort(documentLengths); documentsTrain = documentsTrain(idx); labelsTrain = labelsTrain(idx);
Задайте параметры для каждой из операций и включайте их в struct. Используйте формат parameters.OperationName.ParameterName
, где parameters
struct, OperationName
имя операции (например, "fc"
), и ParameterName
имя параметра (например, "Weights"
).
Создайте struct parameters
содержа параметры модели. Инициализируйте смещение нулями. Используйте следующие инициализаторы веса в операциях:
Для встраивания инициализируйте веса случайными нормальными значениями.
Для операции ГРУ инициализируйте веса с помощью initializeGlorot
функция, перечисленная в конце примера.
Для полностью операции connect, инициализируйте веса с помощью initializeGaussian
функция, перечисленная в конце примера.
embeddingDimension = 300; numHiddenUnits = 250; inputSize = enc.NumWords + 1; parameters = struct; parameters.emb.Weights = dlarray(randn([embeddingDimension inputSize])); parameters.gru.InputWeights = dlarray(initializeGlorot(3*numHiddenUnits,embeddingDimension)); parameters.gru.RecurrentWeights = dlarray(initializeGlorot(3*numHiddenUnits,numHiddenUnits)); parameters.gru.Bias = dlarray(zeros(3*numHiddenUnits,1,'single')); parameters.fc.Weights = dlarray(initializeGaussian([numClasses,numHiddenUnits])); parameters.fc.Bias = dlarray(zeros(numClasses,1,'single'));
Просмотрите parameters
struct ().
parameters
parameters = struct with fields:
emb: [1×1 struct]
gru: [1×1 struct]
fc: [1×1 struct]
Просмотрите параметры для операции ГРУ.
parameters.gru
ans = struct with fields:
InputWeights: [750×300 dlarray]
RecurrentWeights: [750×250 dlarray]
Bias: [750×1 dlarray]
Создайте функциональный model
, перечисленный в конце примера, который вычисляет выходные параметры модели глубокого обучения, описанной ранее. Функциональный model
берет в качестве входа входные данные dlX
и параметры модели parameters
. Сетевые выходные параметры предсказания для меток.
Создайте функциональный modelGradients
, перечисленный в конце примера, который берет в качестве входа мини-пакет входных данных dlX
и соответствующие цели T
содержание меток, и возвращает градиенты потери относительно настраиваемых параметров, соответствующей потери и сетевых выходных параметров.
Обучайтесь в течение 5 эпох с мини-пакетным размером 256.
numEpochs = 5; miniBatchSize = 256;
Обучите использование оптимизатора Адама, со скоростью обучения 0,01, и задайте затухание градиента и факторы затухания градиента в квадрате 0,5 и 0.999, соответственно.
learnRate = 0.01; gradientDecayFactor = 0.5; squaredGradientDecayFactor = 0.999;
Отсеките градиенты с порогом 1 использования усечение градиента нормы.
gradientThreshold = 1;
Визуализируйте процесс обучения в графике.
plots = "training-progress";
Чтобы преобразовать вектор вероятностей к меткам, используйте метки с вероятностями выше, чем заданный порог. Задайте порог метки 0,5.
labelThreshold = 0.5;
Подтвердите сеть каждая эпоха.
numObservationsTrain = numel(documentsTrain); numIterationsPerEpoch = floor(numObservationsTrain/miniBatchSize); validationFrequency = numIterationsPerEpoch;
Обучайтесь на графическом процессоре, если вы доступны. Это требует Parallel Computing Toolbox™. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA®, графический процессор с вычисляет возможность 3.0 или выше.
executionEnvironment = "auto";
Обучите модель с помощью пользовательского учебного цикла.
В течение каждой эпохи, цикла по мини-пакетам данных. В конце каждой эпохи переставьте данные. В конце каждой итерации обновите график процесса обучения.
Для каждого мини-пакета:
Преобразуйте документы последовательностям словарей и преобразуйте метки в фиктивные переменные.
Преобразуйте последовательности в dlarray
объекты с базовым одним типом и указывают, что размерность маркирует 'BCT'
(пакет, канал, время).
Для обучения графического процессора преобразуйте в gpuArray
объекты.
Оцените градиенты модели и потерю с помощью dlfeval
и modelGradients
функция.
Отсеките градиенты.
Обновите сетевые параметры с помощью adamupdate
функция.
При необходимости подтвердите сеть с помощью modelPredictions
функция, перечисленная в конце примера.
Обновите учебный график.
Инициализируйте график процесса обучения.
if plots == "training-progress" figure % Labeling F-Score. subplot(2,1,1) lineFScoreTrain = animatedline('Color',[0 0.447 0.741]); lineFScoreValidation = animatedline( ... 'LineStyle','--', ... 'Marker','o', ... 'MarkerFaceColor','black'); ylim([0 1]) xlabel("Iteration") ylabel("Labeling F-Score") grid on % Loss. subplot(2,1,2) lineLossTrain = animatedline('Color',[0.85 0.325 0.098]); lineLossValidation = animatedline( ... 'LineStyle','--', ... 'Marker','o', ... 'MarkerFaceColor','black'); ylim([0 inf]) xlabel("Iteration") ylabel("Loss") grid on end
Инициализируйте параметры для оптимизатора Адама.
trailingAvg = []; trailingAvgSq = [];
Подготовьте данные о валидации. Создайте одногорячую закодированную матрицу, где ненулевые записи соответствуют меткам каждого наблюдения.
numObservationsValidation = numel(documentsValidation); TValidation = zeros(numClasses, numObservationsValidation, 'single'); for i = 1:numObservationsValidation [~,idx] = ismember(labelsValidation{i},classNames); TValidation(idx,i) = 1; end
Обучите модель.
iteration = 0; start = tic; % Loop over epochs. for epoch = 1:numEpochs % Loop over mini-batches. for i = 1:numIterationsPerEpoch iteration = iteration + 1; idx = (i-1)*miniBatchSize+1:i*miniBatchSize; % Read mini-batch of data and convert the labels to dummy % variables. documents = documentsTrain(idx); labels = labelsTrain(idx); % Convert documents to sequences. len = min(maxSequenceLength,max(doclength(documents))); X = doc2sequence(enc,documents, ... 'PaddingValue',inputSize, ... 'Length',len); X = cat(1,X{:}); % Dummify labels. T = zeros(numClasses, miniBatchSize, 'single'); for j = 1:miniBatchSize [~,idx2] = ismember(labels{j},classNames); T(idx2,j) = 1; end % Convert mini-batch of data to dlarray. dlX = dlarray(X,'BTC'); % If training on a GPU, then convert data to gpuArray. if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlX = gpuArray(dlX); end % Evaluate the model gradients, state, and loss using dlfeval and the % modelGradients function. [gradients,loss,dlYPred] = dlfeval(@modelGradients, dlX, T, parameters); % Gradient clipping. gradients = dlupdate(@(g) thresholdL2Norm(g, gradientThreshold),gradients); % Update the network parameters using the Adam optimizer. [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ... trailingAvg,trailingAvgSq,iteration,learnRate,gradientDecayFactor,squaredGradientDecayFactor); % Display the training progress. if plots == "training-progress" subplot(2,1,1) D = duration(0,0,toc(start),'Format','hh:mm:ss'); title("Epoch: " + epoch + ", Elapsed: " + string(D)) % Loss. addpoints(lineLossTrain,iteration,double(gather(extractdata(loss)))) % Labeling F-score. YPred = extractdata(dlYPred) > labelThreshold; score = labelingFScore(YPred,T); addpoints(lineFScoreTrain,iteration,double(gather(score))) drawnow % Display validation metrics. if iteration == 1 || mod(iteration,validationFrequency) == 0 dlYPredValidation = modelPredictions(parameters,enc,documentsValidation,miniBatchSize,maxSequenceLength); % Loss. lossValidation = crossentropy(dlYPredValidation,TValidation, ... 'TargetCategories','independent', ... 'DataFormat','CB'); addpoints(lineLossValidation,iteration,double(gather(extractdata(lossValidation)))) % Labeling F-score. YPredValidation = extractdata(dlYPredValidation) > labelThreshold; score = labelingFScore(YPredValidation,TValidation); addpoints(lineFScoreValidation,iteration,double(gather(score))) drawnow end end end % Shuffle data. idx = randperm(numObservationsTrain); documentsTrain = documentsTrain(idx); labelsTrain = labelsTrain(idx); end
Чтобы сделать предсказания на новом наборе данных, используйте modelPredictions
функция, перечисленная в конце примера. modelPredictions
функционируйте берет в качестве входа параметры модели, кодирование слова и массив маркируемых документов, и выводит предсказания модели, соответствующие заданному мини-пакетному размеру и максимальной длине последовательности.
dlYPredValidation = modelPredictions(parameters,enc,documentsValidation,miniBatchSize,maxSequenceLength);
Чтобы преобразовать сетевые выходные параметры в массив меток, найдите метки с баллами выше, чем заданный порог метки.
YPredValidation = extractdata(dlYPredValidation) > labelThreshold;
Чтобы оценить производительность, вычислите F-счет маркировки с помощью labelingFScore
функция, перечисленная в конце примера. F-счет маркировки оценивает классификацию мультиметок путем фокусировки на классификации на текст с частичными соответствиями.
score = labelingFScore(YPredValidation,TValidation)
score = single
0.5852
Просмотрите эффект порога маркировки на F-счете маркировки путем попытки области значений значений для порога и сравнения результатов.
thr = linspace(0,1,10); score = zeros(size(thr)); for i = 1:numel(thr) YPredValidationThr = extractdata(dlYPredValidation) >= thr(i); score(i) = labelingFScore(YPredValidationThr,TValidation); end figure plot(thr,score) xline(labelThreshold,'r--'); xlabel("Threshold") ylabel("Labeling F-Score") title("Effect of Labeling Threshold")
Чтобы визуализировать правильные предсказания классификатора, вычислите количества истинных положительных сторон. Положительная истина является экземпляром классификатора, правильно предсказывая конкретный класс для наблюдения.
Y = YPredValidation; T = TValidation; numTruePositives = sum(T & Y,2); numObservationsPerClass = sum(T,2); truePositiveRates = numTruePositives ./ numObservationsPerClass;
Визуализируйте количества истинных положительных сторон для каждого класса в гистограмме.
figure [~,idx] = sort(truePositiveRates,'descend'); histogram('Categories',classNames(idx),'BinCounts',truePositiveRates(idx)) xlabel("Category") ylabel("True Positive Rate") title("True Positive Rates")
Визуализируйте экземпляры, где классификатор предсказывает неправильно путем показа распределения истинных положительных сторон, ложных положительных сторон и ложных отрицательных сторон. Положительная ложь является экземпляром классификатора, присваивающего конкретный неправильный класс наблюдению. Ложное отрицание является экземпляром классификатора, не удающегося присваивать конкретный правильный класс наблюдению.
Создайте матрицу беспорядка показ истинных положительных, ложных положительных, и ложных отрицательных количеств:
Для каждого класса отобразите истинные положительные количества на диагонали.
Для каждой пары классов (i, j), отображают количество экземпляров лжи, положительной для j, когда экземпляр является также ложным отрицанием поскольку i.
Таким образом, матрица беспорядка с элементами, данными:
Вычислите ложные отрицательные стороны и ложные положительные стороны.
falseNegatives = T & ~Y; falsePositives = ~T & Y;
Вычислите недиагональные элементы.
falseNegatives = permute(falseNegatives,[3 2 1]); numConditionalFalsePositives = sum(falseNegatives & falsePositives, 2); numConditionalFalsePositives = squeeze(numConditionalFalsePositives); tpfnMatrix = numConditionalFalsePositives;
Установите диагональные элементы на истинные положительные количества.
idxDiagonal = 1:numClasses+1:numClasses^2; tpfnMatrix(idxDiagonal) = numTruePositives;
Визуализируйте истинные положительные и ложные положительные количества в матрице беспорядка использование confusionchart
функционируйте и отсортируйте матрицу, таким образом, что элементы на диагонали в порядке убывания.
figure cm = confusionchart(tpfnMatrix,classNames); sortClasses(cm,"descending-diagonal"); title("True Positives, False Positives")
Чтобы просмотреть матрицу более подробно, откройте этот пример как live скрипт и откройте фигуру в новом окне.
preprocessText
функция маркирует и предварительно обрабатывает входные текстовые данные с помощью следующих шагов:
Маркируйте текст с помощью tokenizedDocument
функция. Извлеките математические уравнения как одну лексему с помощью 'RegularExpressions'
опция путем определения регулярного выражения "\$.*?\$"
, который получает текст, появляющийся между двумя символами "$".
Сотрите пунктуацию с помощью erasePunctuation
функция.
Преобразуйте текст в нижний регистр с помощью lower
функция.
Удалите слова остановки с помощью removeStopWords
функция.
Lemmatize текст с помощью normalizeWords
функция с 'Style'
набор опции к 'lemma'
.
function documents = preprocessText(textData) % Tokenize the text. regularExpressions = table; regularExpressions.Pattern = "\$.*?\$"; regularExpressions.Type = "equation"; documents = tokenizedDocument(textData,'RegularExpressions',regularExpressions); % Erase punctuation. documents = erasePunctuation(documents); % Convert to lowercase. documents = lower(documents); % Lemmatize. documents = addPartOfSpeechDetails(documents); documents = normalizeWords(documents,'Style','Lemma'); % Remove stop words. documents = removeStopWords(documents); % Remove short words. documents = removeShortWords(documents,2); end
Функциональный model
берет в качестве входа входные данные dlX
и параметры модели parameters
, и возвращает предсказания для меток.
function dlY = model(dlX,parameters) % Embedding weights = parameters.emb.Weights; dlX = embedding(dlX, weights); % GRU inputWeights = parameters.gru.InputWeights; recurrentWeights = parameters.gru.RecurrentWeights; bias = parameters.gru.Bias; numHiddenUnits = size(inputWeights,1)/3; hiddenState = dlarray(zeros([numHiddenUnits 1])); dlY = gru(dlX, hiddenState, inputWeights, recurrentWeights, bias,'DataFormat','CBT'); % Max pooling along time dimension dlY = max(dlY,[],3); % Fully connect weights = parameters.fc.Weights; bias = parameters.fc.Bias; dlY = fullyconnect(dlY,weights,bias,'DataFormat','CB'); % Sigmoid dlY = sigmoid(dlY); end
modelGradients
функционируйте берет в качестве входа мини-пакет входных данных dlX
с соответствующими целями T
содержание меток и возвращает градиенты потери относительно настраиваемых параметров, соответствующей потери и сетевых выходных параметров.
function [gradients,loss,dlYPred] = modelGradients(dlX,T,parameters) dlYPred = model(dlX,parameters); loss = crossentropy(dlYPred,T,'TargetCategories','independent','DataFormat','CB'); gradients = dlgradient(loss,parameters); end
modelPredictions
функционируйте берет в качестве входа параметры модели, кодирование слова, массив маркируемых документов, мини-пакетного размера и максимальной длины последовательности, и возвращает предсказания модели путем итерации по мини-пакетам заданного размера.
function dlYPred = modelPredictions(parameters,enc,documents,miniBatchSize,maxSequenceLength) inputSize = enc.NumWords + 1; numObservations = numel(documents); numIterations = ceil(numObservations / miniBatchSize); numFeatures = size(parameters.fc.Weights,1); dlYPred = zeros(numFeatures,numObservations,'like',parameters.fc.Weights); for i = 1:numIterations idx = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations); len = min(maxSequenceLength,max(doclength(documents(idx)))); X = doc2sequence(enc,documents(idx), ... 'PaddingValue',inputSize, ... 'Length',len); X = cat(1,X{:}); dlX = dlarray(X,'BTC'); dlYPred(:,idx) = model(dlX,parameters); end end
Функция F-счета маркировки [2] оценивает классификацию мультиметок путем фокусировки на классификации на текст с частичными соответствиями. Мерой является нормированная пропорция соответствия с метками против общего количества истинных и предсказанных меток, данных
где N и C соответствуют количеству наблюдений и классов, соответственно, и Y и T соответствуют предсказаниям и целям, соответственно.
function score = labelingFScore(Y,T) numObservations = size(T,2); scores = (2 * sum(Y .* T)) ./ sum(Y + T); score = sum(scores) / numObservations; end
initializeGlorot
функция генерирует массив весов согласно инициализации Glorot.
function weights = initializeGlorot(numOut, numIn) varWeights = sqrt( 6 / (numIn + numOut) ); weights = varWeights * (2 * rand([numOut, numIn], 'single') - 1); end
initializeGaussian
функциональные демонстрационные веса от Распределения Гаусса со средним значением 0 и стандартным отклонением 0.01.
function parameter = initializeGaussian(sz) parameter = randn(sz,'single') .* 0.01; end
embedding
функционируйте сопоставляет числовые индексы с соответствующим вектором, данным входными весами.
function Z = embedding(X, weights) % Reshape inputs into a vector. [N, T] = size(X, 2:3); X = reshape(X, N*T, 1); % Index into embedding matrix. Z = weights(:, X); % Reshape outputs by separating batch and sequence dimensions. Z = reshape(Z, [], N, T); end
thresholdL2Norm
функционируйте масштабирует входные градиенты так, чтобы их значения нормы равняются заданному порогу градиента когда значение нормы градиента настраиваемого параметра больше, чем заданный порог.
function gradients = thresholdL2Norm(gradients,gradientThreshold) gradientNorm = sqrt(sum(gradients(:).^2)); if gradientNorm > gradientThreshold gradients = gradients * (gradientThreshold / gradientNorm); end end
arXiv. "arXiv API". Полученный доступ 15 января 2020. https://arxiv.org/help/api
Соколова, Марина, и Гай Лэпэйлм. "Анализ Sytematic Критериев качества работы для Задач Классификации". Обработка информации & управление 45, № 4 (2009): 427–437.
adamupdate
| dlarray
| dlfeval
| dlgradient
| dlupdate
| doc2sequence
| extractHTMLText
| fullyconnect
| gru
| htmlTree
| tokenizedDocument
| wordEncoding