В этом примере показано, как классифицировать текстовые данные с помощью сверточной нейронной сети.
Чтобы классифицировать текстовые данные с помощью сверток, необходимо преобразовать текстовые данные в изображения. Для этого дополните или обрезайте наблюдения, чтобы иметь постоянную длину S, и преобразуйте документы в последовательности векторов слов длины C с помощью вложения слов. Затем можно представлять документ как 1-by-S-by-C изображение (изображение с высотой 1, шириной S и каналами C).
Чтобы преобразовать текстовые данные из файла CSV в изображения, создайте tabularTextDatastore
объект. Преобразуйте данные, считанные из tabularTextDatastore
объект к изображениям для глубокого обучения по вызову transform
с пользовательской функцией преобразования. The transformTextData
функция, перечисленная в конце примера, берёт данные, считанные из datastore, и предварительно обученное вложение слова и преобразует каждое наблюдение в массив векторов слов.
Этот пример обучает сеть с 1-D сверточными фильтрами различной ширины. Ширина каждого фильтра соответствует количеству слов, которые может видеть фильтр (длина n-грамма). Сеть имеет несколько ветвей сверточных слоев, поэтому она может использовать различные длины n-граммов.
Загрузите предварительно обученное вложение слова fastText. Эта функция требует Text Analytics Toolbox™ Model для fastText English 16 млрд Token Word Embedding пакет поддержки. Если этот пакет поддержки не установлен, то функция предоставляет ссылку на загрузку.
emb = fastTextWordEmbedding;
Создайте табличный текст datastore из данных в factoryReports.csv
. Считайте данные из "Description"
и "Category"
только столбцы.
filenameTrain = "factoryReports.csv"; textName = "Description"; labelName = "Category"; ttdsTrain = tabularTextDatastore(filenameTrain,'SelectedVariableNames',[textName labelName]);
Предварительный просмотр datastore.
ttdsTrain.ReadSize = 8; preview(ttdsTrain)
ans=8×2 table
Description Category
_______________________________________________________________________ ______________________
{'Items are occasionally getting stuck in the scanner spools.' } {'Mechanical Failure'}
{'Loud rattling and banging sounds are coming from assembler pistons.'} {'Mechanical Failure'}
{'There are cuts to the power when starting the plant.' } {'Electronic Failure'}
{'Fried capacitors in the assembler.' } {'Electronic Failure'}
{'Mixer tripped the fuses.' } {'Electronic Failure'}
{'Burst pipe in the constructing agent is spraying coolant.' } {'Leak' }
{'A fuse is blown in the mixer.' } {'Electronic Failure'}
{'Things continue to tumble off of the belt.' } {'Mechanical Failure'}
Создайте пользовательскую функцию преобразования, которая преобразует данные, считанные из datastore, в таблицу, содержащую предикторы и ответы. The transformTextData
функция, перечисленная в конце примера, берёт данные, считанные из tabularTextDatastore
и возвращает таблицу предикторов и откликов. Предикторы 1-by- sequenceLength
-by-C массивы векторов слов, заданные словом embedding emb
, где C - размерность вложения. Ответы являются категориальными метками над классами в classNames
.
Считайте метки из обучающих данных с помощью readLabels
функция, перечисленная в конце примера, и найти уникальные имена классов.
labels = readLabels(ttdsTrain,labelName); classNames = unique(labels); numObservations = numel(labels);
Преобразуйте datastore с помощью transformTextData
и задайте длину последовательности 14.
sequenceLength = 14; tdsTrain = transform(ttdsTrain, @(data) transformTextData(data,sequenceLength,emb,classNames))
tdsTrain = TransformedDatastore with properties: UnderlyingDatastore: [1×1 matlab.io.datastore.TabularTextDatastore] SupportedOutputFormats: ["txt" "csv" "xlsx" "xls" "parquet" "parq" "png" "jpg" "jpeg" "tif" "tiff" "wav" "flac" "ogg" "mp4" "m4a"] Transforms: {@(data)transformTextData(data,sequenceLength,emb,classNames)} IncludeInfo: 0
Предварительный просмотр преобразованного datastore. Предикторы являются 1-by-S-by-C массивами, где S - длина последовательности, а C - количество функций (размерность встраивания). Ответы являются категориальными метками.
preview(tdsTrain)
ans=8×2 table
Predictors Responses
_________________ __________________
{1×14×300 single} Mechanical Failure
{1×14×300 single} Mechanical Failure
{1×14×300 single} Electronic Failure
{1×14×300 single} Electronic Failure
{1×14×300 single} Electronic Failure
{1×14×300 single} Leak
{1×14×300 single} Electronic Failure
{1×14×300 single} Mechanical Failure
Определите сетевую архитектуру для задачи классификации.
Следующие шаги описывают сетевую архитектуру.
Задайте размер входа 1-by-S-by-C, где S - длина последовательности, а C - количество функций (размерность встраивания).
Для n-граммовых длин 2, 3, 4 и 5 создайте блоки слоев, содержащих сверточный слой, нормализацию партии ., слой ReLU, выпадающий слой и максимальный слой объединения.
Для каждого блока задайте 200 сверточных фильтров размера 1-by-N и области объединения размера 1-by-S, где N является длиной n-грамма.
Соедините слой входа с каждым блоком и соедините выходы блоков с помощью слоя конкатенации глубин.
Чтобы классифицировать выходы, включите полносвязный слой с выходным размером K, слой softmax и слой классификации, где K - количество классов.
Во-первых, в массиве слоев задайте входной слой, первый блок для unigrams, слой конкатенации глубин, полносвязный слой, слой softmax и слой классификации.
numFeatures = emb.Dimension; inputSize = [1 sequenceLength numFeatures]; numFilters = 200; ngramLengths = [2 3 4 5]; numBlocks = numel(ngramLengths); numClasses = numel(classNames);
Создайте график слоев, содержащий входной слой. Установите опцию нормализации равным 'none'
и имя слоя, для 'input'
.
layer = imageInputLayer(inputSize,'Normalization','none','Name','input'); lgraph = layerGraph(layer);
Для каждого из n-граммовых длин создайте блок свертки, нормализации партии ., ReLU, выпадающих и максимальных слоев объединения. Соедините каждый блок с входом слоем.
for j = 1:numBlocks N = ngramLengths(j); block = [ convolution2dLayer([1 N],numFilters,'Name',"conv"+N,'Padding','same') batchNormalizationLayer('Name',"bn"+N) reluLayer('Name',"relu"+N) dropoutLayer(0.2,'Name',"drop"+N) maxPooling2dLayer([1 sequenceLength],'Name',"max"+N)]; lgraph = addLayers(lgraph,block); lgraph = connectLayers(lgraph,'input',"conv"+N); end
Просмотрите сетевую архитектуру на графике.
figure
plot(lgraph)
title("Network Architecture")
Добавьте слой конкатенации глубин, полностью соединенный слой, слой softmax и слой классификации.
layers = [ depthConcatenationLayer(numBlocks,'Name','depth') fullyConnectedLayer(numClasses,'Name','fc') softmaxLayer('Name','soft') classificationLayer('Name','classification')]; lgraph = addLayers(lgraph,layers); figure plot(lgraph) title("Network Architecture")
Соедините максимальные слои объединения с слоем конкатенации глубин и просмотрите окончательную сетевую архитектуру на графике.
for j = 1:numBlocks N = ngramLengths(j); lgraph = connectLayers(lgraph,"max"+N,"depth/in"+j); end figure plot(lgraph) title("Network Architecture")
Задайте опции обучения:
Train с мини-партией размером 128.
Не тасуйте данные, потому что datastore не тасуется.
Отображение графика процесса обучения и подавление подробного выхода.
miniBatchSize = 128; numIterationsPerEpoch = floor(numObservations/miniBatchSize); options = trainingOptions('adam', ... 'MiniBatchSize',miniBatchSize, ... 'Shuffle','never', ... 'Plots','training-progress', ... 'Verbose',false);
Обучите сеть с помощью trainNetwork
функция.
net = trainNetwork(tdsTrain,lgraph,options);
Классифицируйте тип события трех новых отчетов. Создайте строковые массивы, содержащий новые отчеты.
reportsNew = [ "Coolant is pooling underneath sorter." "Sorter blows fuses at start up." "There are some very loud rattling sounds coming from the assembler."];
Предварительно обработайте текстовые данные, используя шаги предварительной обработки в качестве обучающих документов.
XNew = preprocessText(reportsNew,sequenceLength,emb);
Классификация новых последовательностей с помощью обученной сети LSTM.
labelsNew = classify(net,XNew)
labelsNew = 3×1 categorical
Leak
Electronic Failure
Mechanical Failure
The readLabels
функция создает копию tabularTextDatastore
ttds объекта
и считывает метки из labelName
столбец.
function labels = readLabels(ttds,labelName) ttdsNew = copy(ttds); ttdsNew.SelectedVariableNames = labelName; tbl = readall(ttdsNew); labels = tbl.(labelName); end
The transformTextData
функция принимает данные, считанные из tabularTextDatastore
и возвращает таблицу предикторов и откликов. Предикторы 1-by- sequenceLength
-by-C массивы векторов слов, заданные словом embedding emb
, где C - размерность вложения. Ответы являются категориальными метками над классами в classNames
.
function dataTransformed = transformTextData(data,sequenceLength,emb,classNames) % Preprocess documents. textData = data{:,1}; % Prepocess text dataTransformed = preprocessText(textData,sequenceLength,emb); % Read labels. labels = data{:,2}; responses = categorical(labels,classNames); % Convert data to table. dataTransformed.Responses = responses; end
The preprocessTextData
функция принимает текстовые данные, длину последовательности и встраивание слова и выполняет следующие шаги:
Токенизируйте текст.
Преобразуйте текст в строчный.
Преобразует документы в последовательности векторов слов заданной длины с помощью встраивания.
Изменяет форму векторных последовательностей слов для ввода в сеть.
function tbl = preprocessText(textData,sequenceLength,emb) documents = tokenizedDocument(textData); documents = lower(documents); % Convert documents to embeddingDimension-by-sequenceLength-by-1 images. predictors = doc2sequence(emb,documents,'Length',sequenceLength); % Reshape images to be of size 1-by-sequenceLength-embeddingDimension. predictors = cellfun(@(X) permute(X,[3 2 1]),predictors,'UniformOutput',false); tbl = table; tbl.Predictors = predictors; end
doc2sequence
| fastTextWordEmbedding
| tokenizedDocument
| wordcloud
| wordEmbedding
| batchNormalizationLayer
(Deep Learning Toolbox) | convolution2dLayer
(Deep Learning Toolbox) | layerGraph
(Deep Learning Toolbox) | trainingOptions
(Deep Learning Toolbox) | trainNetwork
(Deep Learning Toolbox)