В этом примере показано, как классифицировать текстовые данные с помощью сети глубокого обучения с двунаправленной длинной краткосрочной памятью (BiLSTM) с помощью пользовательского цикла обучения.
При обучении нейронной сети для глубокого обучения с помощью trainNetwork
функция, если trainingOptions
не предоставляет необходимые опции (для примера, пользовательское расписание скорости обучения), тогда можно задать свой собственный пользовательский цикл обучения с помощью автоматической дифференциации. Для примера, показывающего, как классифицировать текстовые данные с помощью trainNetwork
функция, см. Классификация текстовых данных с использованием глубокого обучения (Deep Learning Toolbox).
Этот пример обучает сеть классифицировать текстовые данные с основанным на времени расписанием скорости обучения с распадом: для каждой итерации решатель использует скорость обучения, заданную как , где t - число итерации, является начальной скоростью обучения, и k является распадом.
Импортируйте данные заводских отчетов. Эти данные содержат маркированные текстовые описания заводских событий. Чтобы импортировать текстовые данные как строки, задайте тип текста, который будет 'string'
.
filename = "factoryReports.csv"; data = readtable(filename,'TextType','string'); head(data)
ans=8×5 table
Description Category Urgency Resolution Cost
_____________________________________________________________________ ____________________ ________ ____________________ _____
"Items are occasionally getting stuck in the scanner spools." "Mechanical Failure" "Medium" "Readjust Machine" 45
"Loud rattling and banging sounds are coming from assembler pistons." "Mechanical Failure" "Medium" "Readjust Machine" 35
"There are cuts to the power when starting the plant." "Electronic Failure" "High" "Full Replacement" 16200
"Fried capacitors in the assembler." "Electronic Failure" "High" "Replace Components" 352
"Mixer tripped the fuses." "Electronic Failure" "Low" "Add to Watch List" 55
"Burst pipe in the constructing agent is spraying coolant." "Leak" "High" "Replace Components" 371
"A fuse is blown in the mixer." "Electronic Failure" "Low" "Replace Components" 441
"Things continue to tumble off of the belt." "Mechanical Failure" "Low" "Readjust Machine" 38
Цель этого примера состоит в том, чтобы классифицировать события по меткам в Category
столбец. Чтобы разделить данные на классы, преобразуйте эти метки в категориальные.
data.Category = categorical(data.Category);
Просмотр распределения классов в данных с помощью гистограммы.
figure histogram(data.Category); xlabel("Class") ylabel("Frequency") title("Class Distribution")
Следующий шаг - разбить его на наборы для обучения и валидации. Разделите данные на обучающий раздел и удерживаемый раздел для валидации и проверки. Задайте процент удержания 20%.
cvp = cvpartition(data.Category,'Holdout',0.2);
dataTrain = data(training(cvp),:);
dataValidation = data(test(cvp),:);
Извлеките текстовые данные и метки из секционированных таблиц.
textDataTrain = dataTrain.Description; textDataValidation = dataValidation.Description; YTrain = dataTrain.Category; YValidation = dataValidation.Category;
Чтобы проверить, правильно ли вы импортировали данные, визуализируйте обучающие текстовые данные с помощью облака слов.
figure
wordcloud(textDataTrain);
title("Training Data")
Просмотрите количество классов.
classes = categories(YTrain); numClasses = numel(classes)
numClasses = 4
Создайте функцию, которая токенизирует и предварительно обрабатывает текстовые данные. Функция preprocessText
, перечисленный в конце примера, выполняет следующие шаги:
Токенизация текста с помощью tokenizedDocument
.
Преобразуйте текст в нижний регистр с помощью lower
.
Удалите пунктуацию с помощью erasePunctuation
.
Предварительно обработайте обучающие данные и данные валидации с помощью preprocessText
функция.
documentsTrain = preprocessText(textDataTrain); documentsValidation = preprocessText(textDataValidation);
Просмотрите первые несколько предварительно обработанных обучающих документов.
documentsTrain(1:5)
ans = 5×1 tokenizedDocument: 9 tokens: items are occasionally getting stuck in the scanner spools 10 tokens: loud rattling and banging sounds are coming from assembler pistons 5 tokens: fried capacitors in the assembler 4 tokens: mixer tripped the fuses 9 tokens: burst pipe in the constructing agent is spraying coolant
Создайте один datastore, который содержит как документы, так и метки путем создания arrayDatastore
объекты, затем объединение их с помощью combine
функция.
dsDocumentsTrain = arrayDatastore(documentsTrain,'OutputType','cell'); dsYTrain = arrayDatastore(YTrain,'OutputType','cell'); dsTrain = combine(dsDocumentsTrain,dsYTrain);
Создайте datastore для данных валидации с помощью тех же шагов.
dsDocumentsValidation = arrayDatastore(documentsValidation,'OutputType','cell'); dsYValidation = arrayDatastore(YValidation,'OutputType','cell'); dsValidation = combine(dsDocumentsValidation,dsYValidation);
Для ввода документов в сеть BiLSTM используйте кодировку слов, чтобы преобразовать документы в последовательности числовых индексов.
Чтобы создать кодировку слов, используйте wordEncoding
функция.
enc = wordEncoding(documentsTrain)
enc = wordEncoding with properties: NumWords: 421 Vocabulary: [1×421 string]
Определите сетевую архитектуру BiLSTM. Чтобы ввести данные последовательности в сеть, включите входной слой последовательности и установите размер входа равным 1. Затем включает в себя слой встраивания слов размерности 25 и такое же количество слов, как и в кодировании слов. Затем включите слой BiLSTM и установите количество скрытых модулей 40. Чтобы использовать слой BiLSTM для задачи классификации от последовательности до метки, установите режим выхода равным 'last'
. Наконец, добавьте полносвязный слой с таким же размером, как и количество классов, и слой softmax.
inputSize = 1; embeddingDimension = 25; numHiddenUnits = 40; numWords = enc.NumWords; layers = [ sequenceInputLayer(inputSize,'Name','in') wordEmbeddingLayer(embeddingDimension,numWords,'Name','emb') bilstmLayer(numHiddenUnits,'OutputMode','last','Name','bilstm') fullyConnectedLayer(numClasses,'Name','fc') softmaxLayer('Name','sm')]
layers = 5×1 Layer array with layers: 1 'in' Sequence Input Sequence input with 1 dimensions 2 'emb' Word Embedding Layer Word embedding layer with 25 dimensions and 421 unique words 3 'bilstm' BiLSTM BiLSTM with 40 hidden units 4 'fc' Fully Connected 4 fully connected layer 5 'sm' Softmax softmax
Преобразуйте массив слоев в график слоев и создайте dlnetwork
объект.
lgraph = layerGraph(layers); dlnet = dlnetwork(lgraph)
dlnet = dlnetwork with properties: Layers: [5×1 nnet.cnn.layer.Layer] Connections: [4×2 table] Learnables: [6×3 table] State: [2×3 table] InputNames: {'in'} OutputNames: {'sm'}
Создайте функцию modelGradients
, перечисленный в конце примера, который принимает dlnetwork
объект, мини-пакет входных данных с соответствующими метками и возвращает градиенты потерь относительно настраиваемых параметров в сети и соответствующих потерь.
Обучайте на 30 эпох с мини-партией размером 16.
numEpochs = 30; miniBatchSize = 16;
Задайте опции для оптимизации Adam. Задайте начальную скорость обучения 0,001 с распадом 0,01, коэффициент градиентного распада 0,9 и квадратный коэффициент градиентного распада 0,999.
initialLearnRate = 0.001; decay = 0.01; gradientDecayFactor = 0.9; squaredGradientDecayFactor = 0.999;
Обучите модель с помощью пользовательского цикла обучения.
Инициализируйте график процесса обучения.
figure lineLossTrain = animatedline('Color',[0.85 0.325 0.098]); lineLossValidation = animatedline( ... 'LineStyle','--', ... 'Marker','o', ... 'MarkerFaceColor','black'); ylim([0 inf]) xlabel("Iteration") ylabel("Loss") grid on
Инициализируйте параметры для Adam.
trailingAvg = []; trailingAvgSq = [];
Создайте minibatchqueue
объект, который обрабатывает и управляет мини-пакетами данных. Для каждого мини-пакета:
Используйте пользовательскую функцию мини-пакетной предварительной обработки preprocessMiniBatch
(определено в конце этого примера) для преобразования документов в последовательности и одноразового кодирования меток. Чтобы передать кодировку слова в мини-пакет, создайте анонимную функцию, которая принимает два входов.
Форматируйте предикторы с метками размерности 'BTC'
(пакет, время, канал). The minibatchqueue
объект по умолчанию преобразует данные в dlarray
объекты с базовым типом single
.
Обучите на графическом процессоре, если он доступен. The minibatchqueue
объект по умолчанию преобразует каждый выход в gpuArray
при наличии графический процессор. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах см. раздел.
mbq = minibatchqueue(dsTrain, ... 'MiniBatchSize',miniBatchSize,... 'MiniBatchFcn', @(X,Y) preprocessMiniBatch(X,Y,enc), ... 'MiniBatchFormat',{'BTC',''});
Создайте minibatchqueue
объект для данных валидации, используя те же опции, а также задающий, чтобы вернуть частичные мини-пакеты.
mbqValidation = minibatchqueue(dsValidation, ... 'MiniBatchSize',miniBatchSize, ... 'MiniBatchFcn', @(X,Y) preprocessMiniBatch(X,Y,enc), ... 'MiniBatchFormat',{'BTC',''}, ... 'PartialMiniBatch','return');
Обучите сеть. Для каждой эпохи перетасуйте данные и закольцовывайте по мини-пакетам данных. В конце каждой итерации отобразите процесс обучения. В конце каждой эпохи проверьте сеть с помощью данных валидации.
Для каждого мини-пакета:
Преобразуйте документы в последовательности целых чисел и закодируйте метки с одним контактом.
Преобразуйте данные в dlarray
объекты с базовым типом одинарные и задают метки размерностей 'BTC'
(пакет, время, канал).
Для обучения графический процессор преобразуйте в gpuArray
объекты.
Оцените градиенты модели, состояние и потери с помощью dlfeval
и modelGradients
функционировать и обновлять состояние сети.
Определите скорость обучения для основанного на времени расписания скорости обучения с распадом.
Обновляйте параметры сети с помощью adamupdate
функция.
Обновите график обучения.
iteration = 0; start = tic; % Loop over epochs. for epoch = 1:numEpochs % Shuffle data. shuffle(mbq); % Loop over mini-batches. while hasdata(mbq) iteration = iteration + 1; % Read mini-batch of data. [dlX, dlY] = next(mbq); % Evaluate the model gradients, state, and loss using dlfeval and the % modelGradients function. [gradients,loss] = dlfeval(@modelGradients,dlnet,dlX,dlY); % Determine learning rate for time-based decay learning rate schedule. learnRate = initialLearnRate/(1 + decay*iteration); % Update the network parameters using the Adam optimizer. [dlnet,trailingAvg,trailingAvgSq] = adamupdate(dlnet, gradients, ... trailingAvg, trailingAvgSq, iteration, learnRate, ... gradientDecayFactor, squaredGradientDecayFactor); % Display the training progress. D = duration(0,0,toc(start),'Format','hh:mm:ss'); addpoints(lineLossTrain,iteration,loss) title("Epoch: " + epoch + ", Elapsed: " + string(D)) drawnow % Validate network. if iteration == 1 || ~hasdata(mbq) % Validation predictions. [~,lossValidation] = modelPredictions(dlnet,mbqValidation,classes); % Update plot. addpoints(lineLossValidation,iteration,lossValidation) drawnow end end end
Протестируйте классификационную точность модели путем сравнения предсказаний на наборе валидации с истинными метками.
Классифицируйте данные валидации с помощью modelPredictions
функции, перечисленной в конце примера.
dlYPred = modelPredictions(dlnet,mbqValidation,classes); YPred = onehotdecode(dlYPred,classes,1)';
Оцените точность классификации.
accuracy = mean(YPred == YValidation)
accuracy = 0.9167
Классифицируйте тип события трех новых отчетов. Создайте строковые массивы, содержащий новые отчеты.
reportsNew = [ "Coolant is pooling underneath sorter." "Sorter blows fuses at start up." "There are some very loud rattling sounds coming from the assembler."];
Предварительно обработайте текстовые данные, используя шаги предварительной обработки в качестве обучающих документов.
documentsNew = preprocessText(reportsNew); dsNew = arrayDatastore(documentsNew,'OutputType','cell');
Создайте minibatchqueue
объект, который обрабатывает и управляет мини-пакетами данных. Для каждого мини-пакета:
Используйте пользовательскую функцию мини-пакетной предварительной обработки preprocessMiniBatchPredictors
(определено в конце этого примера) для преобразования документов в последовательности. Эта функция предварительной обработки не требует данных о метке. Чтобы передать кодировку слова в мини-пакет, создайте анонимную функцию, которая принимает только один вход.
Форматируйте предикторы с метками размерности 'BTC'
(пакет, время, канал). The minibatchqueue
объект по умолчанию преобразует данные в dlarray
объекты с базовым типом single
.
Чтобы сделать предсказания для всех наблюдений, верните любые частичные мини-пакеты.
mbqNew = minibatchqueue(dsNew, ... 'MiniBatchSize',miniBatchSize, ... 'MiniBatchFcn',@(X) preprocessMiniBatchPredictors(X,enc), ... 'MiniBatchFormat','BTC', ... 'PartialMiniBatch','return');
Классифицировать текстовые данные можно используя modelPredictions
функция, перечисленная в конце примера, и найти классы с самыми высокими счетами.
dlYPred = modelPredictions(dlnet,mbqNew,classes); YPred = onehotdecode(dlYPred,classes,1)'
YPred = 3×1 categorical
Leak
Electronic Failure
Mechanical Failure
Функция preprocessText
выполняет следующие шаги:
Токенизация текста с помощью tokenizedDocument
.
Преобразуйте текст в нижний регистр с помощью lower
.
Удалите пунктуацию с помощью erasePunctuation
.
function documents = preprocessText(textData) % Tokenize the text. documents = tokenizedDocument(textData); % Convert to lowercase. documents = lower(documents); % Erase punctuation. documents = erasePunctuation(documents); end
The preprocessMiniBatch
функция преобразует мини-пакет документов в последовательности целых чисел и данные метки с одним горячим кодом.
function [X, Y] = preprocessMiniBatch(documentsCell,labelsCell,enc) % Preprocess predictors. X = preprocessMiniBatchPredictors(documentsCell,enc); % Extract labels from cell and concatenate. Y = cat(1,labelsCell{1:end}); % One-hot encode labels. Y = onehotencode(Y,2); % Transpose the encoded labels to match the network output. Y = Y'; end
The preprocessMiniBatchPredictors
функция преобразует мини-пакет документов в последовательности целых чисел.
function X = preprocessMiniBatchPredictors(documentsCell,enc) % Extract documents from cell and concatenate. documents = cat(4,documentsCell{1:end}); % Convert documents to sequences of integers. X = doc2sequence(enc,documents); X = cat(1,X{:}); end
The modelGradients
функция принимает dlnetwork
dlnet объекта
мини-пакет входных данных dlX
с соответствующими целевыми метками T
и возвращает градиенты потерь относительно настраиваемых параметров в dlnet
и потеря. Чтобы вычислить градиенты автоматически, используйте dlgradient
функция.
function [gradients,loss] = modelGradients(dlnet,dlX,T) dlYPred = forward(dlnet,dlX); loss = crossentropy(dlYPred,T); gradients = dlgradient(loss,dlnet.Learnables); loss = double(gather(extractdata(loss))); end
The modelPredictions
функция принимает dlnetwork
dlnet объекта
, мини-очередь пакетов и выводит предсказания модели путем итерации по мини-пакетам в очереди. Чтобы оценить данные валидации, эта функция опционально вычисляет потерю при задании мини-очереди пакетов с двумя выходами.
function [dlYPred,loss] = modelPredictions(dlnet,mbq,classes) % Initialize predictions. numClasses = numel(classes); outputCast = mbq.OutputCast{1}; dlYPred = dlarray(zeros(numClasses,0,outputCast),'CB'); % Reset mini-batch queue. reset(mbq); % For mini-batch queues with two ouputs, also compute the loss. if mbq.NumOutputs == 1 % Loop over mini-batches. while hasdata(mbq) % Make predictions. dlX = next(mbq); dlY = predict(dlnet,dlX); dlYPred = [dlYPred dlY]; end else % Initialize loss. numObservations = 0; loss = 0; % Loop over mini-batches. while hasdata(mbq) % Make predictions. [dlX,dlT] = next(mbq); dlY = predict(dlnet,dlX); dlYPred = [dlYPred dlY]; % Calculate unnormalized loss. miniBatchSize = size(dlX,2); loss = loss + miniBatchSize * crossentropy(dlY, dlT); % Count observations. numObservations = numObservations + miniBatchSize; end % Normalize loss. loss = loss / numObservations; % Convert to double. loss = double(gather(extractdata(loss))); end end
doc2sequence
| tokenizedDocument
| wordcloud
| wordEmbeddingLayer
| dlarray
(Deep Learning Toolbox) | dlfeval
(Deep Learning Toolbox) | dlgradient
(Deep Learning Toolbox) | lstmLayer
(Deep Learning Toolbox) | sequenceInputLayer
(Deep Learning Toolbox)