В этом примере показано, как классифицировать текстовые данные с помощью сети BiLSTM с глубоким обучением и двунаправленной долговременной памятью (BiLSTM) с пользовательским обучающим циклом.
При обучении сеть глубокого обучения с использованием trainNetwork функция, если trainingOptions не предоставляет необходимых опций (например, пользовательский график обучения), то можно определить собственный пользовательский цикл обучения с помощью автоматического дифференцирования. Для примера, показывающего, как классифицировать текстовые данные с помощью trainNetwork см. раздел Классификация текстовых данных с использованием глубокого обучения (панель инструментов глубокого обучения).
Этот пример обучает сеть классифицировать текстовые данные по расписанию скорости обучения затуханию на основе времени: для каждой итерации решатель использует скорость обучения, заданную t - итераций, α0 - начальная скорость обучения, а k - затухание.
Импортируйте данные производственных отчетов. Эти данные содержат помеченные текстовые описания заводских событий. Чтобы импортировать текстовые данные в виде строк, укажите тип текста 'string'.
filename = "factoryReports.csv"; data = readtable(filename,'TextType','string'); head(data)
ans=8×5 table
Description Category Urgency Resolution Cost
_____________________________________________________________________ ____________________ ________ ____________________ _____
"Items are occasionally getting stuck in the scanner spools." "Mechanical Failure" "Medium" "Readjust Machine" 45
"Loud rattling and banging sounds are coming from assembler pistons." "Mechanical Failure" "Medium" "Readjust Machine" 35
"There are cuts to the power when starting the plant." "Electronic Failure" "High" "Full Replacement" 16200
"Fried capacitors in the assembler." "Electronic Failure" "High" "Replace Components" 352
"Mixer tripped the fuses." "Electronic Failure" "Low" "Add to Watch List" 55
"Burst pipe in the constructing agent is spraying coolant." "Leak" "High" "Replace Components" 371
"A fuse is blown in the mixer." "Electronic Failure" "Low" "Replace Components" 441
"Things continue to tumble off of the belt." "Mechanical Failure" "Low" "Readjust Machine" 38
Целью этого примера является классификация событий по метке в Category столбец. Чтобы разделить данные на классы, преобразуйте эти метки в категориальные.
data.Category = categorical(data.Category);
Просмотрите распределение классов в данных с помощью гистограммы.
figure histogram(data.Category); xlabel("Class") ylabel("Frequency") title("Class Distribution")

Следующим шагом является разделение его на наборы для обучения и проверки. Разбиение данных на раздел обучения и раздел удержания для проверки и тестирования. Укажите процент удержания равным 20%.
cvp = cvpartition(data.Category,'Holdout',0.2);
dataTrain = data(training(cvp),:);
dataValidation = data(test(cvp),:);Извлеките текстовые данные и метки из секционированных таблиц.
textDataTrain = dataTrain.Description; textDataValidation = dataValidation.Description; YTrain = dataTrain.Category; YValidation = dataValidation.Category;
Чтобы убедиться, что данные импортированы правильно, визуализируйте текстовые данные обучения с помощью облака слов.
figure
wordcloud(textDataTrain);
title("Training Data")
Просмотр количества классов.
classes = categories(YTrain); numClasses = numel(classes)
numClasses = 4
Создайте функцию, которая маркирует и предварительно обрабатывает текстовые данные. Функция preprocessText, перечисленное в конце примера, выполняет следующие шаги:
Маркировка текста с помощью tokenizedDocument.
Преобразование текста в нижний регистр с помощью lower.
Стереть пунктуацию с помощью erasePunctuation.
Предварительная обработка данных обучения и данных проверки с помощью preprocessText функция.
documentsTrain = preprocessText(textDataTrain); documentsValidation = preprocessText(textDataValidation);
Просмотрите первые несколько предварительно обработанных учебных документов.
documentsTrain(1:5)
ans =
5×1 tokenizedDocument:
9 tokens: items are occasionally getting stuck in the scanner spools
10 tokens: loud rattling and banging sounds are coming from assembler pistons
5 tokens: fried capacitors in the assembler
4 tokens: mixer tripped the fuses
9 tokens: burst pipe in the constructing agent is spraying coolant
Создание единого хранилища данных, содержащего как документы, так и метки. arrayDatastore объекты, затем их объединение с помощью combine функция.
dsDocumentsTrain = arrayDatastore(documentsTrain,'OutputType','cell'); dsYTrain = arrayDatastore(YTrain,'OutputType','cell'); dsTrain = combine(dsDocumentsTrain,dsYTrain);
Создайте хранилище данных для данных проверки, используя те же шаги.
dsDocumentsValidation = arrayDatastore(documentsValidation,'OutputType','cell'); dsYValidation = arrayDatastore(YValidation,'OutputType','cell'); dsValidation = combine(dsDocumentsValidation,dsYValidation);
Для ввода документов в сеть BiLSTM используйте кодировку слов для преобразования документов в последовательности числовых индексов.
Чтобы создать кодировку слова, используйте wordEncoding функция.
enc = wordEncoding(documentsTrain)
enc =
wordEncoding with properties:
NumWords: 421
Vocabulary: [1×421 string]
Определите архитектуру сети BiLSTM. Чтобы ввести данные последовательности в сеть, включите входной уровень последовательности и установите размер ввода равным 1. Затем включают в себя слой встраивания слов размерности 25 и то же количество слов, что и кодирование слов. Затем включите уровень BiLSTM и установите число скрытых единиц равным 40. Чтобы использовать уровень BiLSTM для задачи классификации «последовательность-метка», установите режим вывода в значение 'last'. Наконец, добавьте полностью подключенный слой того же размера, что и количество классов, и слой softmax.
inputSize = 1;
embeddingDimension = 25;
numHiddenUnits = 40;
numWords = enc.NumWords;
layers = [
sequenceInputLayer(inputSize,'Name','in')
wordEmbeddingLayer(embeddingDimension,numWords,'Name','emb')
bilstmLayer(numHiddenUnits,'OutputMode','last','Name','bilstm')
fullyConnectedLayer(numClasses,'Name','fc')
softmaxLayer('Name','sm')]layers =
5×1 Layer array with layers:
1 'in' Sequence Input Sequence input with 1 dimensions
2 'emb' Word Embedding Layer Word embedding layer with 25 dimensions and 421 unique words
3 'bilstm' BiLSTM BiLSTM with 40 hidden units
4 'fc' Fully Connected 4 fully connected layer
5 'sm' Softmax softmax
Преобразование массива слоев в график слоев и создание dlnetwork объект.
lgraph = layerGraph(layers); dlnet = dlnetwork(lgraph)
dlnet =
dlnetwork with properties:
Layers: [5×1 nnet.cnn.layer.Layer]
Connections: [4×2 table]
Learnables: [6×3 table]
State: [2×3 table]
InputNames: {'in'}
OutputNames: {'sm'}
Создание функции modelGradients, перечисленных в конце примера, который принимает dlnetwork объект, мини-пакет входных данных с соответствующими метками и возвращает градиенты потерь относительно обучаемых параметров в сети и соответствующих потерь.
Поезд на 30 эпох с размером мини-партии 16.
numEpochs = 30; miniBatchSize = 16;
Укажите параметры оптимизации Adam. Укажите начальную скорость обучения 0,001 с распадом 0,01, коэффициентом градиентного распада 0,9 и коэффициентом градиентного распада 0,999 в квадрате.
initialLearnRate = 0.001; decay = 0.01; gradientDecayFactor = 0.9; squaredGradientDecayFactor = 0.999;
Обучение модели с помощью пользовательского цикла обучения.
Инициализируйте график хода обучения.
figure lineLossTrain = animatedline('Color',[0.85 0.325 0.098]); lineLossValidation = animatedline( ... 'LineStyle','--', ... 'Marker','o', ... 'MarkerFaceColor','black'); ylim([0 inf]) xlabel("Iteration") ylabel("Loss") grid on
Инициализируйте параметры для Adam.
trailingAvg = []; trailingAvgSq = [];
Создать minibatchqueue объект, который обрабатывает мини-пакеты данных и управляет ими. Для каждой мини-партии:
Использование пользовательской функции предварительной обработки мини-партии preprocessMiniBatch (определено в конце этого примера) для преобразования документов в последовательности и одноконтактного кодирования меток. Чтобы передать кодировку слова мини-пакету, создайте анонимную функцию, которая принимает два входа.
Форматирование предикторов с метками измерения 'BTC' (партия, время, канал). minibatchqueue по умолчанию преобразует данные в dlarray объекты с базовым типом single.
Обучение на GPU, если он доступен. minibatchqueue по умолчанию преобразует каждый вывод в gpuArray если графический процессор доступен. Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе.
mbq = minibatchqueue(dsTrain, ... 'MiniBatchSize',miniBatchSize,... 'MiniBatchFcn', @(X,Y) preprocessMiniBatch(X,Y,enc), ... 'MiniBatchFormat',{'BTC',''});
Создать minibatchqueue объект для данных проверки с использованием тех же опций, а также укажите возврат частичных мини-пакетов.
mbqValidation = minibatchqueue(dsValidation, ... 'MiniBatchSize',miniBatchSize, ... 'MiniBatchFcn', @(X,Y) preprocessMiniBatch(X,Y,enc), ... 'MiniBatchFormat',{'BTC',''}, ... 'PartialMiniBatch','return');
Обучение сети. Для каждой эпохи тасуйте данные и закольцовывайте мини-пакеты данных. В конце каждой итерации просмотрите ход обучения. В конце каждой эпохи проверьте сеть с помощью данных проверки.
Для каждой мини-партии:
Преобразование документов в последовательности целых чисел и одноконтактное кодирование меток.
Преобразовать данные в dlarray объекты с одним базовым типом и указать метки размеров 'BTC' (партия, время, канал).
Для обучения GPU, конвертировать в gpuArray объекты.
Оценка градиентов, состояния и потерь модели с помощью dlfeval и modelGradients и обновить состояние сети.
Определите скорость обучения для графика скорости обучения с учетом времени.
Обновление параметров сети с помощью adamupdate функция.
Обновите график обучения.
iteration = 0; start = tic; % Loop over epochs. for epoch = 1:numEpochs % Shuffle data. shuffle(mbq); % Loop over mini-batches. while hasdata(mbq) iteration = iteration + 1; % Read mini-batch of data. [dlX, dlY] = next(mbq); % Evaluate the model gradients, state, and loss using dlfeval and the % modelGradients function. [gradients,loss] = dlfeval(@modelGradients,dlnet,dlX,dlY); % Determine learning rate for time-based decay learning rate schedule. learnRate = initialLearnRate/(1 + decay*iteration); % Update the network parameters using the Adam optimizer. [dlnet,trailingAvg,trailingAvgSq] = adamupdate(dlnet, gradients, ... trailingAvg, trailingAvgSq, iteration, learnRate, ... gradientDecayFactor, squaredGradientDecayFactor); % Display the training progress. D = duration(0,0,toc(start),'Format','hh:mm:ss'); addpoints(lineLossTrain,iteration,loss) title("Epoch: " + epoch + ", Elapsed: " + string(D)) drawnow % Validate network. if iteration == 1 || ~hasdata(mbq) % Validation predictions. [~,lossValidation] = modelPredictions(dlnet,mbqValidation,classes); % Update plot. addpoints(lineLossValidation,iteration,lossValidation) drawnow end end end

Проверьте точность классификации модели путем сравнения прогнозов на наборе проверки с истинными метками.
Классифицировать данные проверки с помощью modelPredictions функция, перечисленная в конце примера.
dlYPred = modelPredictions(dlnet,mbqValidation,classes); YPred = onehotdecode(dlYPred,classes,1)';
Оцените точность классификации.
accuracy = mean(YPred == YValidation)
accuracy = 0.9167
Классифицируйте тип события для трех новых отчетов. Создайте строковый массив, содержащий новые отчеты.
reportsNew = [
"Coolant is pooling underneath sorter."
"Sorter blows fuses at start up."
"There are some very loud rattling sounds coming from the assembler."];Предварительная обработка текстовых данных с использованием шагов предварительной обработки в качестве учебных документов.
documentsNew = preprocessText(reportsNew); dsNew = arrayDatastore(documentsNew,'OutputType','cell');
Создать minibatchqueue объект, который обрабатывает мини-пакеты данных и управляет ими. Для каждой мини-партии:
Использование пользовательской функции предварительной обработки мини-партии preprocessMiniBatchPredictors (определено в конце этого примера) для преобразования документов в последовательности. Эта функция предварительной обработки не требует данных метки. Чтобы передать кодировку слова мини-пакету, создайте анонимную функцию, которая принимает только один ввод.
Форматирование предикторов с метками измерения 'BTC' (партия, время, канал). minibatchqueue по умолчанию преобразует данные в dlarray объекты с базовым типом single.
Чтобы сделать прогнозы для всех наблюдений, верните любые частичные мини-партии.
mbqNew = minibatchqueue(dsNew, ... 'MiniBatchSize',miniBatchSize, ... 'MiniBatchFcn',@(X) preprocessMiniBatchPredictors(X,enc), ... 'MiniBatchFormat','BTC', ... 'PartialMiniBatch','return');
Классифицировать текстовые данные с помощью modelPredictions функция, перечисленная в конце примера, и найдите классы с самыми высокими баллами.
dlYPred = modelPredictions(dlnet,mbqNew,classes); YPred = onehotdecode(dlYPred,classes,1)'
YPred = 3×1 categorical
Leak
Electronic Failure
Mechanical Failure
Функция preprocessText выполняет следующие шаги:
Маркировка текста с помощью tokenizedDocument.
Преобразование текста в нижний регистр с помощью lower.
Стереть пунктуацию с помощью erasePunctuation.
function documents = preprocessText(textData) % Tokenize the text. documents = tokenizedDocument(textData); % Convert to lowercase. documents = lower(documents); % Erase punctuation. documents = erasePunctuation(documents); end
preprocessMiniBatch функция преобразует мини-пакет документов в последовательности целых чисел и одноступенчато кодирует данные метки.
function [X, Y] = preprocessMiniBatch(documentsCell,labelsCell,enc) % Preprocess predictors. X = preprocessMiniBatchPredictors(documentsCell,enc); % Extract labels from cell and concatenate. Y = cat(1,labelsCell{1:end}); % One-hot encode labels. Y = onehotencode(Y,2); % Transpose the encoded labels to match the network output. Y = Y'; end
preprocessMiniBatchPredictors функция преобразует мини-пакет документов в последовательности целых чисел.
function X = preprocessMiniBatchPredictors(documentsCell,enc) % Extract documents from cell and concatenate. documents = cat(4,documentsCell{1:end}); % Convert documents to sequences of integers. X = doc2sequence(enc,documents); X = cat(1,X{:}); end
modelGradients функция принимает dlnetwork объект dlnet, мини-пакет входных данных dlX с соответствующими целевыми метками T и возвращает градиенты потерь относительно обучаемых параметров в dlnetи потери. Для автоматического вычисления градиентов используйте dlgradient функция.
function [gradients,loss] = modelGradients(dlnet,dlX,T) dlYPred = forward(dlnet,dlX); loss = crossentropy(dlYPred,T); gradients = dlgradient(loss,dlnet.Learnables); loss = double(gather(extractdata(loss))); end
modelPredictions функция принимает dlnetwork объект dlnet, очередь мини-пакетов и выводит прогнозы модели путем итерации по мини-пакетам в очереди. Для оценки данных проверки эта функция дополнительно вычисляет потери при наличии мини-пакетной очереди с двумя выходами.
function [dlYPred,loss] = modelPredictions(dlnet,mbq,classes) % Initialize predictions. numClasses = numel(classes); outputCast = mbq.OutputCast{1}; dlYPred = dlarray(zeros(numClasses,0,outputCast),'CB'); % Reset mini-batch queue. reset(mbq); % For mini-batch queues with two ouputs, also compute the loss. if mbq.NumOutputs == 1 % Loop over mini-batches. while hasdata(mbq) % Make predictions. dlX = next(mbq); dlY = predict(dlnet,dlX); dlYPred = [dlYPred dlY]; end else % Initialize loss. numObservations = 0; loss = 0; % Loop over mini-batches. while hasdata(mbq) % Make predictions. [dlX,dlT] = next(mbq); dlY = predict(dlnet,dlX); dlYPred = [dlYPred dlY]; % Calculate unnormalized loss. miniBatchSize = size(dlX,2); loss = loss + miniBatchSize * crossentropy(dlY, dlT); % Count observations. numObservations = numObservations + miniBatchSize; end % Normalize loss. loss = loss / numObservations; % Convert to double. loss = double(gather(extractdata(loss))); end end
doc2sequence | tokenizedDocument | wordcloud | wordEmbeddingLayer | dlarray (инструментарий для глубокого обучения) | dlfeval (инструментарий для глубокого обучения) | dlgradient (инструментарий для глубокого обучения) | lstmLayer (инструментарий для глубокого обучения) | sequenceInputLayer (инструментарий для глубокого обучения)