В этом примере показано, как классифицировать текстовые данные с помощью глубокого обучения двунаправленная длинная краткосрочная сеть (BiLSTM) памяти с пользовательским учебным циклом.
Когда обучение нейронная сеть для глубокого обучения с помощью trainNetwork
функция, если trainingOptions
не предоставляет возможности, в которых вы нуждаетесь (например, пользовательское расписание скорости обучения), затем можно задать собственный учебный цикл с помощью автоматического дифференцирования. Для примера, показывающего, как классифицировать текстовые данные с помощью trainNetwork
функционируйте, смотрите, Классифицируют текстовые Данные Используя Глубокое обучение (Deep Learning Toolbox).
Этот пример обучает сеть, чтобы классифицировать текстовые данные с основанным на времени расписанием скорости обучения затухания: для каждой итерации решатель использует скорость обучения, данную , где t является номером итерации, начальная скорость обучения, и k является затуханием.
Импортируйте данные об отчетах фабрики. Эти данные содержат помеченные текстовые описания событий фабрики. Чтобы импортировать текстовые данные как строки, задайте тип текста, чтобы быть 'string'
.
filename = "factoryReports.csv"; data = readtable(filename,'TextType','string'); head(data)
ans=8×5 table
Description Category Urgency Resolution Cost
_____________________________________________________________________ ____________________ ________ ____________________ _____
"Items are occasionally getting stuck in the scanner spools." "Mechanical Failure" "Medium" "Readjust Machine" 45
"Loud rattling and banging sounds are coming from assembler pistons." "Mechanical Failure" "Medium" "Readjust Machine" 35
"There are cuts to the power when starting the plant." "Electronic Failure" "High" "Full Replacement" 16200
"Fried capacitors in the assembler." "Electronic Failure" "High" "Replace Components" 352
"Mixer tripped the fuses." "Electronic Failure" "Low" "Add to Watch List" 55
"Burst pipe in the constructing agent is spraying coolant." "Leak" "High" "Replace Components" 371
"A fuse is blown in the mixer." "Electronic Failure" "Low" "Replace Components" 441
"Things continue to tumble off of the belt." "Mechanical Failure" "Low" "Readjust Machine" 38
Цель этого примера состоит в том, чтобы классифицировать события меткой в Category
столбец. Чтобы разделить данные на классы, преобразуйте эти метки в категориальный.
data.Category = categorical(data.Category);
Просмотрите распределение классов в данных с помощью гистограммы.
figure histogram(data.Category); xlabel("Class") ylabel("Frequency") title("Class Distribution")
Следующий шаг должен разделить его в наборы для обучения и валидации. Разделите данные в учебный раздел и протянутый раздел для валидации и тестирования. Задайте процент затяжки, чтобы быть 20%.
cvp = cvpartition(data.Category,'Holdout',0.2);
dataTrain = data(training(cvp),:);
dataValidation = data(test(cvp),:);
Извлеките текстовые данные и метки из разделенных таблиц.
textDataTrain = dataTrain.Description; textDataValidation = dataValidation.Description; YTrain = dataTrain.Category; YValidation = dataValidation.Category;
Чтобы проверять, что вы импортировали данные правильно, визуализируйте учебные текстовые данные с помощью облака слова.
figure
wordcloud(textDataTrain);
title("Training Data")
Просмотрите количество классов.
classes = categories(YTrain); numClasses = numel(classes)
numClasses = 4
Создайте функцию, которая маркирует и предварительно обрабатывает текстовые данные. Функциональный preprocessText
, перечисленный в конце примера, выполняет эти шаги:
Маркируйте текст с помощью tokenizedDocument
.
Преобразуйте текст в нижний регистр с помощью lower
.
Сотрите пунктуацию с помощью erasePunctuation
.
Предварительно обработайте обучающие данные и данные о валидации с помощью preprocessText
функция.
documentsTrain = preprocessText(textDataTrain); documentsValidation = preprocessText(textDataValidation);
Просмотрите первые несколько предварительно обработанных учебных материалов.
documentsTrain(1:5)
ans = 5×1 tokenizedDocument: 9 tokens: items are occasionally getting stuck in the scanner spools 10 tokens: loud rattling and banging sounds are coming from assembler pistons 5 tokens: fried capacitors in the assembler 4 tokens: mixer tripped the fuses 9 tokens: burst pipe in the constructing agent is spraying coolant
Создайте один datastore, который содержит и документы и метки путем создания arrayDatastore
объекты, затем комбинируя их использующий combine
функция.
dsDocumentsTrain = arrayDatastore(documentsTrain,'OutputType','cell'); dsYTrain = arrayDatastore(YTrain,'OutputType','cell'); dsTrain = combine(dsDocumentsTrain,dsYTrain);
Создайте datastore для данных о валидации с помощью тех же шагов.
dsDocumentsValidation = arrayDatastore(documentsValidation,'OutputType','cell'); dsYValidation = arrayDatastore(YValidation,'OutputType','cell'); dsValidation = combine(dsDocumentsValidation,dsYValidation);
Чтобы ввести документы в сеть BiLSTM, используйте кодирование слова, чтобы преобразовать документы в последовательности числовых индексов.
Чтобы создать кодирование слова, используйте wordEncoding
функция.
enc = wordEncoding(documentsTrain)
enc = wordEncoding with properties: NumWords: 421 Vocabulary: [1×421 string]
Задайте архитектуру сети BiLSTM. Чтобы ввести данные о последовательности в сеть, включайте входной слой последовательности и установите входной размер на 1. Затем включайте слой встраивания слова размерности 25 и то же количество слов как кодирование слова. Затем включайте слой BiLSTM и определите номер скрытых модулей к 40. Чтобы использовать слой BiLSTM для проблемы классификации последовательностей к метке, установите режим вывода на 'last'
. Наконец, добавьте полносвязный слой с тем же размером как количество классов и softmax слой.
inputSize = 1; embeddingDimension = 25; numHiddenUnits = 40; numWords = enc.NumWords; layers = [ sequenceInputLayer(inputSize,'Name','in') wordEmbeddingLayer(embeddingDimension,numWords,'Name','emb') bilstmLayer(numHiddenUnits,'OutputMode','last','Name','bilstm') fullyConnectedLayer(numClasses,'Name','fc') softmaxLayer('Name','sm')]
layers = 5×1 Layer array with layers: 1 'in' Sequence Input Sequence input with 1 dimensions 2 'emb' Word Embedding Layer Word embedding layer with 25 dimensions and 421 unique words 3 'bilstm' BiLSTM BiLSTM with 40 hidden units 4 'fc' Fully Connected 4 fully connected layer 5 'sm' Softmax softmax
Преобразуйте массив слоя в график слоев и создайте dlnetwork
объект.
lgraph = layerGraph(layers); dlnet = dlnetwork(lgraph)
dlnet = dlnetwork with properties: Layers: [5×1 nnet.cnn.layer.Layer] Connections: [4×2 table] Learnables: [6×3 table] State: [2×3 table] InputNames: {'in'} OutputNames: {'sm'}
Создайте функциональный modelGradients
, перечисленный в конце примера, который берет dlnetwork
объект, мини-пакет входных данных с соответствующими метками, и возвращают градиенты потери относительно настраиваемых параметров в сети и соответствующей потери.
Обучайтесь в течение 30 эпох с мини-пакетным размером 16.
numEpochs = 30; miniBatchSize = 16;
Задайте опции для оптимизации Адама. Укажите, что начальная буква изучает уровень 0,001 с затуханием 0,01, фактор затухания градиента 0.9 и фактор затухания градиента в квадрате 0.999.
initialLearnRate = 0.001; decay = 0.01; gradientDecayFactor = 0.9; squaredGradientDecayFactor = 0.999;
Обучите модель с помощью пользовательского учебного цикла.
Инициализируйте график процесса обучения.
figure lineLossTrain = animatedline('Color',[0.85 0.325 0.098]); lineLossValidation = animatedline( ... 'LineStyle','--', ... 'Marker','o', ... 'MarkerFaceColor','black'); ylim([0 inf]) xlabel("Iteration") ylabel("Loss") grid on
Инициализируйте параметры для Адама.
trailingAvg = []; trailingAvgSq = [];
Создайте minibatchqueue
возразите, что процессы и управляют мини-пакетами данных. Для каждого мини-пакета:
Используйте пользовательский мини-пакет, предварительно обрабатывающий функциональный preprocessMiniBatch
(заданный в конце этого примера), чтобы преобразовать документы последовательностям и одногорячий кодируют метки. Чтобы передать кодирование слова мини-пакету, создайте анонимную функцию, которая берет два входных параметров.
Формат предикторы с размерностью маркирует 'BTC'
(пакет, время, канал). minibatchqueue
объект, по умолчанию, преобразует данные в dlarray
объекты с базовым типом single
.
Обучайтесь на графическом процессоре, если вы доступны. minibatchqueue
объект, по умолчанию, преобразует каждый выход в gpuArray
если графический процессор доступен. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите.
mbq = minibatchqueue(dsTrain, ... 'MiniBatchSize',miniBatchSize,... 'MiniBatchFcn', @(X,Y) preprocessMiniBatch(X,Y,enc), ... 'MiniBatchFormat',{'BTC',''});
Создайте minibatchqueue
объект для данных о валидации с помощью тех же опций и также задает, чтобы возвратить частичные мини-пакеты.
mbqValidation = minibatchqueue(dsValidation, ... 'MiniBatchSize',miniBatchSize, ... 'MiniBatchFcn', @(X,Y) preprocessMiniBatch(X,Y,enc), ... 'MiniBatchFormat',{'BTC',''}, ... 'PartialMiniBatch','return');
Обучите сеть. В течение каждой эпохи переставьте данные и цикл по мини-пакетам данных. В конце каждой итерации отобразите прогресс обучения. В конце каждой эпохи проверьте сеть с помощью данных о валидации.
Для каждого мини-пакета:
Преобразуйте документы последовательностям целых чисел, и одногорячий кодируют метки.
Преобразуйте данные в dlarray
объекты с базовым одним типом и указывают, что размерность маркирует 'BTC'
(пакет, время, канал).
Для обучения графического процессора преобразуйте в gpuArray
объекты.
Оцените градиенты модели, состояние и потерю с помощью dlfeval
и modelGradients
функционируйте и обновите сетевое состояние.
Определите скорость обучения для основанного на времени расписания скорости обучения затухания.
Обновите сетевые параметры с помощью adamupdate
функция.
Обновите учебный график.
iteration = 0; start = tic; % Loop over epochs. for epoch = 1:numEpochs % Shuffle data. shuffle(mbq); % Loop over mini-batches. while hasdata(mbq) iteration = iteration + 1; % Read mini-batch of data. [dlX, dlY] = next(mbq); % Evaluate the model gradients, state, and loss using dlfeval and the % modelGradients function. [gradients,loss] = dlfeval(@modelGradients,dlnet,dlX,dlY); % Determine learning rate for time-based decay learning rate schedule. learnRate = initialLearnRate/(1 + decay*iteration); % Update the network parameters using the Adam optimizer. [dlnet,trailingAvg,trailingAvgSq] = adamupdate(dlnet, gradients, ... trailingAvg, trailingAvgSq, iteration, learnRate, ... gradientDecayFactor, squaredGradientDecayFactor); % Display the training progress. D = duration(0,0,toc(start),'Format','hh:mm:ss'); addpoints(lineLossTrain,iteration,loss) title("Epoch: " + epoch + ", Elapsed: " + string(D)) drawnow % Validate network. if iteration == 1 || ~hasdata(mbq) % Validation predictions. [~,lossValidation] = modelPredictions(dlnet,mbqValidation,classes); % Update plot. addpoints(lineLossValidation,iteration,lossValidation) drawnow end end end
Протестируйте точность классификации модели путем сравнения предсказаний на наборе валидации с истинными метками.
Классифицируйте данные о валидации с помощью modelPredictions
функция, перечисленная в конце примера.
dlYPred = modelPredictions(dlnet,mbqValidation,classes); YPred = onehotdecode(dlYPred,classes,1)';
Оцените точность классификации.
accuracy = mean(YPred == YValidation)
accuracy = 0.9167
Классифицируйте тип события трех новых отчетов. Создайте массив строк, содержащий новые отчеты.
reportsNew = [ "Coolant is pooling underneath sorter." "Sorter blows fuses at start up." "There are some very loud rattling sounds coming from the assembler."];
Предварительно обработайте текстовые данные с помощью шагов предварительной обработки в качестве учебных материалов.
documentsNew = preprocessText(reportsNew); dsNew = arrayDatastore(documentsNew,'OutputType','cell');
Создайте minibatchqueue
возразите, что процессы и управляют мини-пакетами данных. Для каждого мини-пакета:
Используйте пользовательский мини-пакет, предварительно обрабатывающий функциональный preprocessMiniBatchPredictors
(заданный в конце этого примера), чтобы преобразовать документы последовательностям. Эта функция предварительной обработки не требует данных о метке. Чтобы передать кодирование слова мини-пакету, создайте анонимную функцию, которая берет вход того только.
Формат предикторы с размерностью маркирует 'BTC'
(пакет, время, канал). minibatchqueue
объект, по умолчанию, преобразует данные в dlarray
объекты с базовым типом single
.
Чтобы сделать предсказания для всех наблюдений, возвратите любые частичные мини-пакеты.
mbqNew = minibatchqueue(dsNew, ... 'MiniBatchSize',miniBatchSize, ... 'MiniBatchFcn',@(X) preprocessMiniBatchPredictors(X,enc), ... 'MiniBatchFormat','BTC', ... 'PartialMiniBatch','return');
Классифицируйте текстовые данные с помощью modelPredictions
функция, перечисленная в конце примера и, находит классы с самыми высокими баллами.
dlYPred = modelPredictions(dlnet,mbqNew,classes); YPred = onehotdecode(dlYPred,classes,1)'
YPred = 3×1 categorical
Leak
Electronic Failure
Mechanical Failure
Функциональный preprocessText
выполняет эти шаги:
Маркируйте текст с помощью tokenizedDocument
.
Преобразуйте текст в нижний регистр с помощью lower
.
Сотрите пунктуацию с помощью erasePunctuation
.
function documents = preprocessText(textData) % Tokenize the text. documents = tokenizedDocument(textData); % Convert to lowercase. documents = lower(documents); % Erase punctuation. documents = erasePunctuation(documents); end
preprocessMiniBatch
функция преобразует мини-пакет документов последовательностям целых чисел, и одногорячий кодирует данные о метке.
function [X, Y] = preprocessMiniBatch(documentsCell,labelsCell,enc) % Preprocess predictors. X = preprocessMiniBatchPredictors(documentsCell,enc); % Extract labels from cell and concatenate. Y = cat(1,labelsCell{1:end}); % One-hot encode labels. Y = onehotencode(Y,2); % Transpose the encoded labels to match the network output. Y = Y'; end
preprocessMiniBatchPredictors
функция преобразует мини-пакет документов последовательностям целых чисел.
function X = preprocessMiniBatchPredictors(documentsCell,enc) % Extract documents from cell and concatenate. documents = cat(4,documentsCell{1:end}); % Convert documents to sequences of integers. X = doc2sequence(enc,documents); X = cat(1,X{:}); end
modelGradients
функционируйте берет dlnetwork
объект dlnet
, мини-пакет входных данных dlX
с соответствующей целью маркирует T
и возвращает градиенты потери относительно настраиваемых параметров в dlnet
, и потеря. Чтобы вычислить градиенты автоматически, используйте dlgradient
функция.
function [gradients,loss] = modelGradients(dlnet,dlX,T) dlYPred = forward(dlnet,dlX); loss = crossentropy(dlYPred,T); gradients = dlgradient(loss,dlnet.Learnables); loss = double(gather(extractdata(loss))); end
modelPredictions
функционируйте берет dlnetwork
объект dlnet
, мини-пакетная очередь и выходные параметры предсказания модели путем итерации по мини-пакетам в очереди. Чтобы оценить данные о валидации, эта функция опционально вычисляет потерю, когда дали мини-пакетная очередь с двумя выходными параметрами.
function [dlYPred,loss] = modelPredictions(dlnet,mbq,classes) % Initialize predictions. numClasses = numel(classes); outputCast = mbq.OutputCast{1}; dlYPred = dlarray(zeros(numClasses,0,outputCast),'CB'); % Reset mini-batch queue. reset(mbq); % For mini-batch queues with two ouputs, also compute the loss. if mbq.NumOutputs == 1 % Loop over mini-batches. while hasdata(mbq) % Make predictions. dlX = next(mbq); dlY = predict(dlnet,dlX); dlYPred = [dlYPred dlY]; end else % Initialize loss. numObservations = 0; loss = 0; % Loop over mini-batches. while hasdata(mbq) % Make predictions. [dlX,dlT] = next(mbq); dlY = predict(dlnet,dlX); dlYPred = [dlYPred dlY]; % Calculate unnormalized loss. miniBatchSize = size(dlX,2); loss = loss + miniBatchSize * crossentropy(dlY, dlT); % Count observations. numObservations = numObservations + miniBatchSize; end % Normalize loss. loss = loss / numObservations; % Convert to double. loss = double(gather(extractdata(loss))); end end
doc2sequence
| tokenizedDocument
| wordcloud
| wordEmbeddingLayer
| dlarray
(Deep Learning Toolbox) | dlfeval
(Deep Learning Toolbox) | dlgradient
(Deep Learning Toolbox) | lstmLayer
(Deep Learning Toolbox) | sequenceInputLayer
(Deep Learning Toolbox)