Классификация текстовых данных с помощью пользовательского цикла обучения

В этом примере показано, как классифицировать текстовые данные с помощью сети глубокого обучения с двунаправленной длинной краткосрочной памятью (BiLSTM) с помощью пользовательского цикла обучения.

При обучении нейронной сети для глубокого обучения с помощью trainNetwork функция, если trainingOptions не предоставляет необходимые опции (для примера, пользовательское расписание скорости обучения), тогда можно задать свой собственный пользовательский цикл обучения с помощью автоматической дифференциации. Для примера, показывающего, как классифицировать текстовые данные с помощью trainNetwork функция, см. Классификация текстовых данных с использованием глубокого обучения (Deep Learning Toolbox).

Этот пример обучает сеть классифицировать текстовые данные с основанным на времени расписанием скорости обучения с распадом: для каждой итерации решатель использует скорость обучения, заданную как ρt=ρ01+kt, где t - число итерации, ρ0 является начальной скоростью обучения, и k является распадом.

Импорт данных

Импортируйте данные заводских отчетов. Эти данные содержат маркированные текстовые описания заводских событий. Чтобы импортировать текстовые данные как строки, задайте тип текста, который будет 'string'.

filename = "factoryReports.csv";
data = readtable(filename,'TextType','string');
head(data)
ans=8×5 table
                                 Description                                       Category          Urgency          Resolution         Cost 
    _____________________________________________________________________    ____________________    ________    ____________________    _____

    "Items are occasionally getting stuck in the scanner spools."            "Mechanical Failure"    "Medium"    "Readjust Machine"         45
    "Loud rattling and banging sounds are coming from assembler pistons."    "Mechanical Failure"    "Medium"    "Readjust Machine"         35
    "There are cuts to the power when starting the plant."                   "Electronic Failure"    "High"      "Full Replacement"      16200
    "Fried capacitors in the assembler."                                     "Electronic Failure"    "High"      "Replace Components"      352
    "Mixer tripped the fuses."                                               "Electronic Failure"    "Low"       "Add to Watch List"        55
    "Burst pipe in the constructing agent is spraying coolant."              "Leak"                  "High"      "Replace Components"      371
    "A fuse is blown in the mixer."                                          "Electronic Failure"    "Low"       "Replace Components"      441
    "Things continue to tumble off of the belt."                             "Mechanical Failure"    "Low"       "Readjust Machine"         38

Цель этого примера состоит в том, чтобы классифицировать события по меткам в Category столбец. Чтобы разделить данные на классы, преобразуйте эти метки в категориальные.

data.Category = categorical(data.Category);

Просмотр распределения классов в данных с помощью гистограммы.

figure
histogram(data.Category);
xlabel("Class")
ylabel("Frequency")
title("Class Distribution")

Следующий шаг - разбить его на наборы для обучения и валидации. Разделите данные на обучающий раздел и удерживаемый раздел для валидации и проверки. Задайте процент удержания 20%.

cvp = cvpartition(data.Category,'Holdout',0.2);
dataTrain = data(training(cvp),:);
dataValidation = data(test(cvp),:);

Извлеките текстовые данные и метки из секционированных таблиц.

textDataTrain = dataTrain.Description;
textDataValidation = dataValidation.Description;
YTrain = dataTrain.Category;
YValidation = dataValidation.Category;

Чтобы проверить, правильно ли вы импортировали данные, визуализируйте обучающие текстовые данные с помощью облака слов.

figure
wordcloud(textDataTrain);
title("Training Data")

Просмотрите количество классов.

classes = categories(YTrain);
numClasses = numel(classes)
numClasses = 4

Предварительная обработка текстовых данных

Создайте функцию, которая токенизирует и предварительно обрабатывает текстовые данные. Функция preprocessText, перечисленный в конце примера, выполняет следующие шаги:

  1. Токенизация текста с помощью tokenizedDocument.

  2. Преобразуйте текст в нижний регистр с помощью lower.

  3. Удалите пунктуацию с помощью erasePunctuation.

Предварительно обработайте обучающие данные и данные валидации с помощью preprocessText функция.

documentsTrain = preprocessText(textDataTrain);
documentsValidation = preprocessText(textDataValidation);

Просмотрите первые несколько предварительно обработанных обучающих документов.

documentsTrain(1:5)
ans = 
  5×1 tokenizedDocument:

     9 tokens: items are occasionally getting stuck in the scanner spools
    10 tokens: loud rattling and banging sounds are coming from assembler pistons
     5 tokens: fried capacitors in the assembler
     4 tokens: mixer tripped the fuses
     9 tokens: burst pipe in the constructing agent is spraying coolant

Создайте один datastore, который содержит как документы, так и метки путем создания arrayDatastore объекты, затем объединение их с помощью combine функция.

dsDocumentsTrain = arrayDatastore(documentsTrain,'OutputType','cell');
dsYTrain = arrayDatastore(YTrain,'OutputType','cell');
dsTrain = combine(dsDocumentsTrain,dsYTrain);

Создайте datastore для данных валидации с помощью тех же шагов.

dsDocumentsValidation = arrayDatastore(documentsValidation,'OutputType','cell');
dsYValidation = arrayDatastore(YValidation,'OutputType','cell');
dsValidation = combine(dsDocumentsValidation,dsYValidation);

Создайте кодировку Word

Для ввода документов в сеть BiLSTM используйте кодировку слов, чтобы преобразовать документы в последовательности числовых индексов.

Чтобы создать кодировку слов, используйте wordEncoding функция.

enc = wordEncoding(documentsTrain)
enc = 
  wordEncoding with properties:

      NumWords: 421
    Vocabulary: [1×421 string]

Определение сети

Определите сетевую архитектуру BiLSTM. Чтобы ввести данные последовательности в сеть, включите входной слой последовательности и установите размер входа равным 1. Затем включает в себя слой встраивания слов размерности 25 и такое же количество слов, как и в кодировании слов. Затем включите слой BiLSTM и установите количество скрытых модулей 40. Чтобы использовать слой BiLSTM для задачи классификации от последовательности до метки, установите режим выхода равным 'last'. Наконец, добавьте полносвязный слой с таким же размером, как и количество классов, и слой softmax.

inputSize = 1;
embeddingDimension = 25;
numHiddenUnits = 40;

numWords = enc.NumWords;

layers = [
    sequenceInputLayer(inputSize,'Name','in')
    wordEmbeddingLayer(embeddingDimension,numWords,'Name','emb')
    bilstmLayer(numHiddenUnits,'OutputMode','last','Name','bilstm')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','sm')]
layers = 
  5×1 Layer array with layers:

     1   'in'       Sequence Input         Sequence input with 1 dimensions
     2   'emb'      Word Embedding Layer   Word embedding layer with 25 dimensions and 421 unique words
     3   'bilstm'   BiLSTM                 BiLSTM with 40 hidden units
     4   'fc'       Fully Connected        4 fully connected layer
     5   'sm'       Softmax                softmax

Преобразуйте массив слоев в график слоев и создайте dlnetwork объект.

lgraph = layerGraph(layers);
dlnet = dlnetwork(lgraph)
dlnet = 
  dlnetwork with properties:

         Layers: [5×1 nnet.cnn.layer.Layer]
    Connections: [4×2 table]
     Learnables: [6×3 table]
          State: [2×3 table]
     InputNames: {'in'}
    OutputNames: {'sm'}

Задайте функцию градиентов модели

Создайте функцию modelGradients, перечисленный в конце примера, который принимает dlnetwork объект, мини-пакет входных данных с соответствующими метками и возвращает градиенты потерь относительно настраиваемых параметров в сети и соответствующих потерь.

Настройка опций обучения

Обучайте на 30 эпох с мини-партией размером 16.

numEpochs = 30;
miniBatchSize = 16;

Задайте опции для оптимизации Adam. Задайте начальную скорость обучения 0,001 с распадом 0,01, коэффициент градиентного распада 0,9 и квадратный коэффициент градиентного распада 0,999.

initialLearnRate = 0.001;
decay = 0.01;
gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.999;

Обучите модель

Обучите модель с помощью пользовательского цикла обучения.

Инициализируйте график процесса обучения.

figure
lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);

lineLossValidation = animatedline( ...
    'LineStyle','--', ...
    'Marker','o', ...
    'MarkerFaceColor','black');

ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on

Инициализируйте параметры для Adam.

trailingAvg = [];
trailingAvgSq = [];

Создайте minibatchqueue объект, который обрабатывает и управляет мини-пакетами данных. Для каждого мини-пакета:

  • Используйте пользовательскую функцию мини-пакетной предварительной обработки preprocessMiniBatch (определено в конце этого примера) для преобразования документов в последовательности и одноразового кодирования меток. Чтобы передать кодировку слова в мини-пакет, создайте анонимную функцию, которая принимает два входов.

  • Форматируйте предикторы с метками размерности 'BTC' (пакет, время, канал). The minibatchqueue объект по умолчанию преобразует данные в dlarray объекты с базовым типом single.

  • Обучите на графическом процессоре, если он доступен. The minibatchqueue объект по умолчанию преобразует каждый выход в gpuArray при наличии графический процессор. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах см. раздел.

mbq = minibatchqueue(dsTrain, ...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn', @(X,Y) preprocessMiniBatch(X,Y,enc), ...
    'MiniBatchFormat',{'BTC',''});

Создайте minibatchqueue объект для данных валидации, используя те же опции, а также задающий, чтобы вернуть частичные мини-пакеты.

mbqValidation = minibatchqueue(dsValidation, ...
    'MiniBatchSize',miniBatchSize, ...
    'MiniBatchFcn', @(X,Y) preprocessMiniBatch(X,Y,enc), ...
    'MiniBatchFormat',{'BTC',''}, ...
    'PartialMiniBatch','return');

Обучите сеть. Для каждой эпохи перетасуйте данные и закольцовывайте по мини-пакетам данных. В конце каждой итерации отобразите процесс обучения. В конце каждой эпохи проверьте сеть с помощью данных валидации.

Для каждого мини-пакета:

  • Преобразуйте документы в последовательности целых чисел и закодируйте метки с одним контактом.

  • Преобразуйте данные в dlarray объекты с базовым типом одинарные и задают метки размерностей 'BTC' (пакет, время, канал).

  • Для обучения графический процессор преобразуйте в gpuArray объекты.

  • Оцените градиенты модели, состояние и потери с помощью dlfeval и modelGradients функционировать и обновлять состояние сети.

  • Определите скорость обучения для основанного на времени расписания скорости обучения с распадом.

  • Обновляйте параметры сети с помощью adamupdate функция.

  • Обновите график обучения.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Shuffle data.
    shuffle(mbq);
    
    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;
        
        % Read mini-batch of data.
        [dlX, dlY] = next(mbq);
        
        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelGradients function.
        [gradients,loss] = dlfeval(@modelGradients,dlnet,dlX,dlY);
        
        % Determine learning rate for time-based decay learning rate schedule.
        learnRate = initialLearnRate/(1 + decay*iteration);
        
        % Update the network parameters using the Adam optimizer.
        [dlnet,trailingAvg,trailingAvgSq] = adamupdate(dlnet, gradients, ...
            trailingAvg, trailingAvgSq, iteration, learnRate, ...
            gradientDecayFactor, squaredGradientDecayFactor);
        
        % Display the training progress.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        addpoints(lineLossTrain,iteration,loss)
        title("Epoch: " + epoch + ", Elapsed: " + string(D))
        drawnow
        
        % Validate network.
        if iteration == 1 || ~hasdata(mbq)
            % Validation predictions.
            [~,lossValidation] = modelPredictions(dlnet,mbqValidation,classes);
            
            % Update plot.
            addpoints(lineLossValidation,iteration,lossValidation)
            drawnow
        end
    end
end

Экспериментальная модель

Протестируйте классификационную точность модели путем сравнения предсказаний на наборе валидации с истинными метками.

Классифицируйте данные валидации с помощью modelPredictions функции, перечисленной в конце примера.

dlYPred = modelPredictions(dlnet,mbqValidation,classes);
YPred = onehotdecode(dlYPred,classes,1)';

Оцените точность классификации.

accuracy = mean(YPred == YValidation)
accuracy = 0.9167

Предсказание с использованием новых данных

Классифицируйте тип события трех новых отчетов. Создайте строковые массивы, содержащий новые отчеты.

reportsNew = [
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];

Предварительно обработайте текстовые данные, используя шаги предварительной обработки в качестве обучающих документов.

documentsNew = preprocessText(reportsNew);
dsNew = arrayDatastore(documentsNew,'OutputType','cell');

Создайте minibatchqueue объект, который обрабатывает и управляет мини-пакетами данных. Для каждого мини-пакета:

  • Используйте пользовательскую функцию мини-пакетной предварительной обработки preprocessMiniBatchPredictors (определено в конце этого примера) для преобразования документов в последовательности. Эта функция предварительной обработки не требует данных о метке. Чтобы передать кодировку слова в мини-пакет, создайте анонимную функцию, которая принимает только один вход.

  • Форматируйте предикторы с метками размерности 'BTC' (пакет, время, канал). The minibatchqueue объект по умолчанию преобразует данные в dlarray объекты с базовым типом single.

  • Чтобы сделать предсказания для всех наблюдений, верните любые частичные мини-пакеты.

mbqNew = minibatchqueue(dsNew, ...
    'MiniBatchSize',miniBatchSize, ...
    'MiniBatchFcn',@(X) preprocessMiniBatchPredictors(X,enc), ...
    'MiniBatchFormat','BTC', ...
    'PartialMiniBatch','return');

Классифицировать текстовые данные можно используя modelPredictions функция, перечисленная в конце примера, и найти классы с самыми высокими счетами.

dlYPred = modelPredictions(dlnet,mbqNew,classes);
YPred = onehotdecode(dlYPred,classes,1)'
YPred = 3×1 categorical
     Leak 
     Electronic Failure 
     Mechanical Failure 

Функция предварительной обработки текста

Функция preprocessText выполняет следующие шаги:

  1. Токенизация текста с помощью tokenizedDocument.

  2. Преобразуйте текст в нижний регистр с помощью lower.

  3. Удалите пунктуацию с помощью erasePunctuation.

function documents = preprocessText(textData)

% Tokenize the text.
documents = tokenizedDocument(textData);

% Convert to lowercase.
documents = lower(documents);

% Erase punctuation.
documents = erasePunctuation(documents);

end

Функция мини-пакетной предварительной обработки

The preprocessMiniBatch функция преобразует мини-пакет документов в последовательности целых чисел и данные метки с одним горячим кодом.

function [X, Y] = preprocessMiniBatch(documentsCell,labelsCell,enc)

% Preprocess predictors.
X = preprocessMiniBatchPredictors(documentsCell,enc);

% Extract labels from cell and concatenate.
Y = cat(1,labelsCell{1:end});

% One-hot encode labels.
Y = onehotencode(Y,2);

% Transpose the encoded labels to match the network output.
Y = Y';

end

Функция предварительной обработки мини-пакетных предикторов

The preprocessMiniBatchPredictors функция преобразует мини-пакет документов в последовательности целых чисел.

function X = preprocessMiniBatchPredictors(documentsCell,enc)

% Extract documents from cell and concatenate.
documents = cat(4,documentsCell{1:end});

% Convert documents to sequences of integers.
X = doc2sequence(enc,documents);
X = cat(1,X{:});

end

Функция градиентов модели

The modelGradients функция принимает dlnetwork dlnet объектамини-пакет входных данных dlX с соответствующими целевыми метками T и возвращает градиенты потерь относительно настраиваемых параметров в dlnetи потеря. Чтобы вычислить градиенты автоматически, используйте dlgradient функция.

function [gradients,loss] = modelGradients(dlnet,dlX,T)

dlYPred = forward(dlnet,dlX);

loss = crossentropy(dlYPred,T);
gradients = dlgradient(loss,dlnet.Learnables);

loss = double(gather(extractdata(loss)));

end

Функция предсказаний модели

The modelPredictions функция принимает dlnetwork dlnet объекта, мини-очередь пакетов и выводит предсказания модели путем итерации по мини-пакетам в очереди. Чтобы оценить данные валидации, эта функция опционально вычисляет потерю при задании мини-очереди пакетов с двумя выходами.

function [dlYPred,loss] = modelPredictions(dlnet,mbq,classes)

% Initialize predictions.
numClasses = numel(classes);
outputCast = mbq.OutputCast{1};
dlYPred = dlarray(zeros(numClasses,0,outputCast),'CB');

% Reset mini-batch queue.
reset(mbq);

% For mini-batch queues with two ouputs, also compute the loss.
if mbq.NumOutputs == 1
    
    % Loop over mini-batches.
    while hasdata(mbq)
        
        % Make predictions.
        dlX = next(mbq);
        dlY = predict(dlnet,dlX);
        dlYPred = [dlYPred dlY];
    end
    
else
    % Initialize loss.
    numObservations = 0;
    loss = 0;
    
    % Loop over mini-batches.
    while hasdata(mbq)
        
        % Make predictions.
        [dlX,dlT] = next(mbq);
        dlY = predict(dlnet,dlX);
        dlYPred = [dlYPred dlY];
        
        % Calculate unnormalized loss.
        miniBatchSize = size(dlX,2);
        loss = loss + miniBatchSize * crossentropy(dlY, dlT);
        
        % Count observations.
        numObservations = numObservations + miniBatchSize;
    end
    
    % Normalize loss.
    loss = loss / numObservations;
    
    % Convert to double.
    loss = double(gather(extractdata(loss)));
end

end

См. также

| | | | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox)

Похожие темы