Классифицируйте текстовые данные Используя пользовательский учебный цикл

В этом примере показано, как классифицировать текстовые данные с помощью глубокого обучения двунаправленная длинная краткосрочная сеть (BiLSTM) памяти с пользовательским учебным циклом.

Когда обучение нейронная сеть для глубокого обучения с помощью trainNetwork функция, если trainingOptions не предоставляет возможности, в которых вы нуждаетесь (например, пользовательское расписание скорости обучения), затем можно задать собственный учебный цикл с помощью автоматического дифференцирования. Для примера, показывающего, как классифицировать текстовые данные с помощью trainNetwork функционируйте, смотрите, Классифицируют текстовые Данные Используя Глубокое обучение.

Этот пример обучает сеть, чтобы классифицировать текстовые данные с основанным на времени расписанием скорости обучения затухания: для каждой итерации решатель использует скорость обучения, данную ρt=ρ01+kt, где t является номером итерации, ρ0 начальная скорость обучения, и k является затуханием.

Импортируйте данные

Импортируйте данные об отчетах фабрики. Эти данные содержат помеченные текстовые описания событий фабрики. Чтобы импортировать текстовые данные как строки, задайте тип текста, чтобы быть 'string'.

filename = "factoryReports.csv";
data = readtable(filename,'TextType','string');
head(data)
ans=8×5 table
                                 Description                                       Category          Urgency          Resolution         Cost 
    _____________________________________________________________________    ____________________    ________    ____________________    _____

    "Items are occasionally getting stuck in the scanner spools."            "Mechanical Failure"    "Medium"    "Readjust Machine"         45
    "Loud rattling and banging sounds are coming from assembler pistons."    "Mechanical Failure"    "Medium"    "Readjust Machine"         35
    "There are cuts to the power when starting the plant."                   "Electronic Failure"    "High"      "Full Replacement"      16200
    "Fried capacitors in the assembler."                                     "Electronic Failure"    "High"      "Replace Components"      352
    "Mixer tripped the fuses."                                               "Electronic Failure"    "Low"       "Add to Watch List"        55
    "Burst pipe in the constructing agent is spraying coolant."              "Leak"                  "High"      "Replace Components"      371
    "A fuse is blown in the mixer."                                          "Electronic Failure"    "Low"       "Replace Components"      441
    "Things continue to tumble off of the belt."                             "Mechanical Failure"    "Low"       "Readjust Machine"         38

Цель этого примера состоит в том, чтобы классифицировать события меткой в Category столбец. Чтобы разделить данные на классы, преобразуйте эти метки в категориальный.

data.Category = categorical(data.Category);

Просмотрите распределение классов в данных с помощью гистограммы.

figure
histogram(data.Category);
xlabel("Class")
ylabel("Frequency")
title("Class Distribution")

Следующий шаг должен разделить его в наборы для обучения и валидации. Разделите данные в учебный раздел и протянутый раздел для валидации и тестирования. Задайте процент затяжки, чтобы быть 20%.

cvp = cvpartition(data.Category,'Holdout',0.2);
dataTrain = data(training(cvp),:);
dataValidation = data(test(cvp),:);

Извлеките текстовые данные и метки из разделенных таблиц.

textDataTrain = dataTrain.Description;
textDataValidation = dataValidation.Description;
YTrain = dataTrain.Category;
YValidation = dataValidation.Category;

Чтобы проверять, что вы импортировали данные правильно, визуализируйте учебные текстовые данные с помощью облака слова.

figure
wordcloud(textDataTrain);
title("Training Data")

Просмотрите количество классов.

classes = categories(YTrain);
numClasses = numel(classes)
numClasses = 4

Предварительно обработайте текстовые данные

Создайте функцию, которая маркирует и предварительно обрабатывает текстовые данные. Функциональный preprocessText, перечисленный в конце примера, выполняет эти шаги:

  1. Маркируйте текст с помощью tokenizedDocument.

  2. Преобразуйте текст в нижний регистр с помощью lower.

  3. Сотрите пунктуацию с помощью erasePunctuation.

Предварительно обработайте обучающие данные и данные о валидации с помощью preprocessText функция.

documentsTrain = preprocessText(textDataTrain);
documentsValidation = preprocessText(textDataValidation);

Просмотрите первые несколько предварительно обработанных учебных материалов.

documentsTrain(1:5)
ans = 
  5×1 tokenizedDocument:

     9 tokens: items are occasionally getting stuck in the scanner spools
    10 tokens: loud rattling and banging sounds are coming from assembler pistons
     5 tokens: fried capacitors in the assembler
     4 tokens: mixer tripped the fuses
     9 tokens: burst pipe in the constructing agent is spraying coolant

Создайте один datastore, который содержит и документы и метки путем создания arrayDatastore объекты, затем комбинируя их использующий combine функция.

dsDocumentsTrain = arrayDatastore(documentsTrain,'OutputType','cell');
dsYTrain = arrayDatastore(YTrain,'OutputType','cell');
dsTrain = combine(dsDocumentsTrain,dsYTrain);

Создайте datastore для данных о валидации с помощью тех же шагов.

dsDocumentsValidation = arrayDatastore(documentsValidation,'OutputType','cell');
dsYValidation = arrayDatastore(YValidation,'OutputType','cell');
dsValidation = combine(dsDocumentsValidation,dsYValidation);

Создайте Word Encoding

Чтобы ввести документы в сеть BiLSTM, используйте кодирование слова, чтобы преобразовать документы в последовательности числовых индексов.

Чтобы создать кодирование слова, используйте wordEncoding функция.

enc = wordEncoding(documentsTrain)
enc = 
  wordEncoding with properties:

      NumWords: 421
    Vocabulary: [1×421 string]

Сеть Define

Задайте архитектуру сети BiLSTM. Чтобы ввести данные о последовательности в сеть, включайте входной слой последовательности и установите входной размер на 1. Затем включайте слой встраивания слова размерности 25 и то же количество слов как кодирование слова. Затем включайте слой BiLSTM и определите номер скрытых модулей к 40. Чтобы использовать слой BiLSTM для проблемы классификации последовательностей к метке, установите режим вывода на 'last'. Наконец, добавьте полносвязный слой с тем же размером как количество классов и softmax слой.

inputSize = 1;
embeddingDimension = 25;
numHiddenUnits = 40;

numWords = enc.NumWords;

layers = [
    sequenceInputLayer(inputSize,'Name','in')
    wordEmbeddingLayer(embeddingDimension,numWords,'Name','emb')
    bilstmLayer(numHiddenUnits,'OutputMode','last','Name','bilstm')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','sm')]
layers = 
  5×1 Layer array with layers:

     1   'in'       Sequence Input         Sequence input with 1 dimensions
     2   'emb'      Word Embedding Layer   Word embedding layer with 25 dimensions and 421 unique words
     3   'bilstm'   BiLSTM                 BiLSTM with 40 hidden units
     4   'fc'       Fully Connected        4 fully connected layer
     5   'sm'       Softmax                softmax

Преобразуйте массив слоя в график слоев и создайте dlnetwork объект.

lgraph = layerGraph(layers);
dlnet = dlnetwork(lgraph)
dlnet = 
  dlnetwork with properties:

         Layers: [5×1 nnet.cnn.layer.Layer]
    Connections: [4×2 table]
     Learnables: [6×3 table]
          State: [2×3 table]
     InputNames: {'in'}
    OutputNames: {'sm'}

Функция градиентов модели Define

Создайте функциональный modelGradients, перечисленный в конце примера, который берет dlnetwork объект, мини-пакет входных данных с соответствующими метками, и возвращают градиенты потери относительно настраиваемых параметров в сети и соответствующей потери.

Задайте опции обучения

Обучайтесь в течение 30 эпох с мини-пакетным размером 16.

numEpochs = 30;
miniBatchSize = 16;

Задайте опции для оптимизации Адама. Укажите, что начальная буква изучает уровень 0,001 с затуханием 0,01, фактор затухания градиента 0.9 и фактор затухания градиента в квадрате 0.999.

initialLearnRate = 0.001;
decay = 0.01;
gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.999;

Обучите модель

Обучите модель с помощью пользовательского учебного цикла.

Инициализируйте график процесса обучения.

figure
lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);

lineLossValidation = animatedline( ...
    'LineStyle','--', ...
    'Marker','o', ...
    'MarkerFaceColor','black');

ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on

Инициализируйте параметры для Адама.

trailingAvg = [];
trailingAvgSq = [];

Создайте minibatchqueue возразите, что процессы и управляют мини-пакетами данных. Для каждого мини-пакета:

  • Используйте пользовательский мини-пакет, предварительно обрабатывающий функциональный preprocessMiniBatch (заданный в конце этого примера), чтобы преобразовать документы последовательностям и одногорячий кодируют метки. Чтобы передать кодирование слова мини-пакету, создайте анонимную функцию, которая берет два входных параметров.

  • Формат предикторы с размерностью маркирует 'BTC' (пакет, время, канал). minibatchqueue объект, по умолчанию, преобразует данные в dlarray объекты с базовым типом single.

  • Обучайтесь на графическом процессоре, если вы доступны. minibatchqueue объект, по умолчанию, преобразует каждый выход в gpuArray если графический процессор доступен. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).

mbq = minibatchqueue(dsTrain, ...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn', @(X,Y) preprocessMiniBatch(X,Y,enc), ...
    'MiniBatchFormat',{'BTC',''});

Создайте minibatchqueue объект для данных о валидации с помощью тех же опций и также задает, чтобы возвратить частичные мини-пакеты.

mbqValidation = minibatchqueue(dsValidation, ...
    'MiniBatchSize',miniBatchSize, ...
    'MiniBatchFcn', @(X,Y) preprocessMiniBatch(X,Y,enc), ...
    'MiniBatchFormat',{'BTC',''}, ...
    'PartialMiniBatch','return');

Обучите сеть. В течение каждой эпохи переставьте данные и цикл по мини-пакетам данных. В конце каждой итерации отобразите прогресс обучения. В конце каждой эпохи проверьте сеть с помощью данных о валидации.

Для каждого мини-пакета:

  • Преобразуйте документы последовательностям целых чисел, и одногорячий кодируют метки.

  • Преобразуйте данные в dlarray объекты с базовым одним типом и указывают, что размерность маркирует 'BTC' (пакет, время, канал).

  • Для обучения графического процессора преобразуйте в gpuArray объекты.

  • Оцените градиенты модели, состояние и потерю с помощью dlfeval и modelGradients функционируйте и обновите сетевое состояние.

  • Определите скорость обучения для основанного на времени расписания скорости обучения затухания.

  • Обновите сетевые параметры с помощью adamupdate функция.

  • Обновите учебный график.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Shuffle data.
    shuffle(mbq);
    
    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;
        
        % Read mini-batch of data.
        [dlX, dlY] = next(mbq);
        
        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelGradients function.
        [gradients,loss] = dlfeval(@modelGradients,dlnet,dlX,dlY);
        
        % Determine learning rate for time-based decay learning rate schedule.
        learnRate = initialLearnRate/(1 + decay*iteration);
        
        % Update the network parameters using the Adam optimizer.
        [dlnet,trailingAvg,trailingAvgSq] = adamupdate(dlnet, gradients, ...
            trailingAvg, trailingAvgSq, iteration, learnRate, ...
            gradientDecayFactor, squaredGradientDecayFactor);
        
        % Display the training progress.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        addpoints(lineLossTrain,iteration,loss)
        title("Epoch: " + epoch + ", Elapsed: " + string(D))
        drawnow
        
        % Validate network.
        if iteration == 1 || ~hasdata(mbq)
            % Validation predictions.
            [~,lossValidation] = modelPredictions(dlnet,mbqValidation,classes);
            
            % Update plot.
            addpoints(lineLossValidation,iteration,lossValidation)
            drawnow
        end
    end
end

Тестовая модель

Протестируйте точность классификации модели путем сравнения предсказаний на наборе валидации с истинными метками.

Классифицируйте данные о валидации с помощью modelPredictions функция, перечисленная в конце примера.

dlYPred = modelPredictions(dlnet,mbqValidation,classes);
YPred = onehotdecode(dlYPred,classes,1)';

Оцените точность классификации.

accuracy = mean(YPred == YValidation)
accuracy = 0.9167

Предскажите Используя новые данные

Классифицируйте тип события трех новых отчетов. Создайте массив строк, содержащий новые отчеты.

reportsNew = [
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];

Предварительно обработайте текстовые данные с помощью шагов предварительной обработки в качестве учебных материалов.

documentsNew = preprocessText(reportsNew);
dsNew = arrayDatastore(documentsNew,'OutputType','cell');

Создайте minibatchqueue возразите, что процессы и управляют мини-пакетами данных. Для каждого мини-пакета:

  • Используйте пользовательский мини-пакет, предварительно обрабатывающий функциональный preprocessMiniBatchPredictors (заданный в конце этого примера), чтобы преобразовать документы последовательностям. Эта функция предварительной обработки не требует данных о метке. Чтобы передать кодирование слова мини-пакету, создайте анонимную функцию, которая берет вход того только.

  • Формат предикторы с размерностью маркирует 'BTC' (пакет, время, канал). minibatchqueue объект, по умолчанию, преобразует данные в dlarray объекты с базовым типом single.

  • Чтобы сделать предсказания для всех наблюдений, возвратите любые частичные мини-пакеты.

mbqNew = minibatchqueue(dsNew, ...
    'MiniBatchSize',miniBatchSize, ...
    'MiniBatchFcn',@(X) preprocessMiniBatchPredictors(X,enc), ...
    'MiniBatchFormat','BTC', ...
    'PartialMiniBatch','return');

Классифицируйте текстовые данные с помощью modelPredictions функция, перечисленная в конце примера и, находит классы с самыми высокими баллами.

dlYPred = modelPredictions(dlnet,mbqNew,classes);
YPred = onehotdecode(dlYPred,classes,1)'
YPred = 3×1 categorical
     Leak 
     Electronic Failure 
     Mechanical Failure 

Текст, предварительно обрабатывающий функцию

Функциональный preprocessText выполняет эти шаги:

  1. Маркируйте текст с помощью tokenizedDocument.

  2. Преобразуйте текст в нижний регистр с помощью lower.

  3. Сотрите пунктуацию с помощью erasePunctuation.

function documents = preprocessText(textData)

% Tokenize the text.
documents = tokenizedDocument(textData);

% Convert to lowercase.
documents = lower(documents);

% Erase punctuation.
documents = erasePunctuation(documents);

end

Функция предварительной обработки мини-пакета

preprocessMiniBatch функция преобразует мини-пакет документов последовательностям целых чисел, и одногорячий кодирует данные о метке.

function [X, Y] = preprocessMiniBatch(documentsCell,labelsCell,enc)

% Preprocess predictors.
X = preprocessMiniBatchPredictors(documentsCell,enc);

% Extract labels from cell and concatenate.
Y = cat(1,labelsCell{1:end});

% One-hot encode labels.
Y = onehotencode(Y,2);

% Transpose the encoded labels to match the network output.
Y = Y';

end

Мини-пакетные предикторы, предварительно обрабатывающие функцию

preprocessMiniBatchPredictors функция преобразует мини-пакет документов последовательностям целых чисел.

function X = preprocessMiniBatchPredictors(documentsCell,enc)

% Extract documents from cell and concatenate.
documents = cat(4,documentsCell{1:end});

% Convert documents to sequences of integers.
X = doc2sequence(enc,documents);
X = cat(1,X{:});

end

Функция градиентов модели

modelGradients функционируйте берет dlnetwork объект dlnet, мини-пакет входных данных dlX с соответствующей целью маркирует T и возвращает градиенты потери относительно настраиваемых параметров в dlnet, и потеря. Чтобы вычислить градиенты автоматически, используйте dlgradient функция.

function [gradients,loss] = modelGradients(dlnet,dlX,T)

dlYPred = forward(dlnet,dlX);

loss = crossentropy(dlYPred,T);
gradients = dlgradient(loss,dlnet.Learnables);

loss = double(gather(extractdata(loss)));

end

Функция предсказаний модели

modelPredictions функционируйте берет dlnetwork объект dlnet, мини-пакетная очередь и выходные параметры предсказания модели путем итерации по мини-пакетам в очереди. Чтобы оценить данные о валидации, эта функция опционально вычисляет потерю, когда дали мини-пакетная очередь с двумя выходными параметрами.

function [dlYPred,loss] = modelPredictions(dlnet,mbq,classes)

% Initialize predictions.
numClasses = numel(classes);
outputCast = mbq.OutputCast{1};
dlYPred = dlarray(zeros(numClasses,0,outputCast),'CB');

% Reset mini-batch queue.
reset(mbq);

% For mini-batch queues with two ouputs, also compute the loss.
if mbq.NumOutputs == 1
    
    % Loop over mini-batches.
    while hasdata(mbq)
        
        % Make predictions.
        dlX = next(mbq);
        dlY = predict(dlnet,dlX);
        dlYPred = [dlYPred dlY];
    end
    
else
    % Initialize loss.
    numObservations = 0;
    loss = 0;
    
    % Loop over mini-batches.
    while hasdata(mbq)
        
        % Make predictions.
        [dlX,dlT] = next(mbq);
        dlY = predict(dlnet,dlX);
        dlYPred = [dlYPred dlY];
        
        % Calculate unnormalized loss.
        miniBatchSize = size(dlX,2);
        loss = loss + miniBatchSize * crossentropy(dlY, dlT);
        
        % Count observations.
        numObservations = numObservations + miniBatchSize;
    end
    
    % Normalize loss.
    loss = loss / numObservations;
    
    % Convert to double.
    loss = double(gather(extractdata(loss)));
end

end

Смотрите также

(Text Analytics Toolbox) | (Text Analytics Toolbox) | | (Text Analytics Toolbox) | | (Text Analytics Toolbox) | | |

Похожие темы