Классифицируйте текстовые данные Используя сверточную нейронную сеть

Этот пример показывает, как классифицировать текстовые данные с помощью сверточной нейронной сети.

Чтобы классифицировать текстовые данные с помощью сверток, необходимо преобразовать текстовые данные в изображения. Для этого заполните или обрежьте наблюдения, чтобы иметь постоянную длину S и преобразовать документы в последовательности векторов слова длины C использование встраивания слова. Можно затем представлять документ как 1 S изображением C (изображение с высотой 1, ширина S и каналы C).

Чтобы преобразовать текстовые данные от файла CSV до изображений, создайте объект tabularTextDatastore. Преобразование данные, считанные от объекта tabularTextDatastore до изображений для глубокого обучения путем вызова transform с пользовательской функцией преобразования. Функция transformTextData, перечисленная в конце примера, берет данные, считанные из datastore и предварительно обученного встраивания слова, и преобразовывает каждое наблюдение в массив векторов слова.

Этот пример обучает сеть с 1D сверточными фильтрами переменных ширин. Ширина каждого фильтра соответствует количество слов, которые фильтр видит (длина n-граммы). Сеть имеет несколько ответвлений сверточных слоев, таким образом, она может использовать различные длины n-граммы.

Загрузите предварительно обученный Word Embedding

Загрузите предварительно обученное fastText встраивание слова. Эта функция требует Модели Text Analytics Toolbox™ для fastText английских 16 миллиардов Лексем пакет поддержки Word Embedding. Если этот пакет поддержки не установлен, то функция обеспечивает ссылку на загрузку.

emb = fastTextWordEmbedding;

Загрузка данных

Создайте табличный текстовый datastore из данных в weatherReportsTrain.csv. Считайте данные из "event_narrative" и столбцов "event_type" только.

filenameTrain = "weatherReportsTrain.csv";
textName = "event_narrative";
labelName = "event_type";
ttdsTrain = tabularTextDatastore(filenameTrain,'SelectedVariableNames',[textName labelName]);

Предварительно просмотрите datastore.

ttdsTrain.ReadSize = 8;
preview(ttdsTrain)
ans=8×2 table
                                                                                             event_narrative                                                                                                 event_type     
    _________________________________________________________________________________________________________________________________________________________________________________________________    ___________________

    'Large tree down between Plantersville and Nettleton.'                                                                                                                                               'Thunderstorm Wind'
    'One to two feet of deep standing water developed on a street on the Winthrop University campus after more than an inch of rain fell in less than an hour. One vehicle was stalled in the water.'    'Heavy Rain'       
    'NWS Columbia relayed a report of trees blown down along Tom Hall St.'                                                                                                                               'Thunderstorm Wind'
    'Media reported two trees blown down along I-40 in the Old Fort area.'                                                                                                                               'Thunderstorm Wind'
    'A few tree limbs greater than 6 inches down on HWY 18 in Roseland.'                                                                                                                                 'Thunderstorm Wind'
    'Awning blown off a building on Lamar Avenue. Multiple trees down near the intersection of Winchester and Perkins.'                                                                                  'Thunderstorm Wind'
    'Tin roof ripped off house on Old Memphis Road near Billings Drive. Several large trees down in the area.'                                                                                           'Thunderstorm Wind'
    'Powerlines down at Walnut Grove and Cherry Lane roads.'                                                                                                                                             'Thunderstorm Wind'

Создайте пользовательское, преобразовывают функцию, которая преобразовывает данные, считанные от datastore до таблицы, содержащей предикторы и ответы. Функция transformTextData, перечисленная в конце примера, берет данные, считанные из объекта tabularTextDatastore, и возвращает таблицу предикторов и ответов. Предикторы 1 sequenceLength C массивами векторов слова, данных словом, встраивающим emb, где C является размерностью встраивания. Ответы являются категориальными метками по классам в classNames.

Считайте метки из данных тренировки с помощью функции readLabels, перечисленной в конце примера, и найдите уникальные имена классов.

labels = readLabels(ttdsTrain,labelName);
classNames = unique(labels);
numObservations = numel(labels);

Преобразуйте datastore с помощью transformTextData, функционируют и задают длину последовательности 100.

sequenceLength = 100;
tdsTrain = transform(ttdsTrain, @(data) transformTextData(data,sequenceLength,emb,classNames))
tdsTrain = 
  TransformedDatastore with properties:

    UnderlyingDatastore: [1×1 matlab.io.datastore.TabularTextDatastore]
             Transforms: {@(data)transformTextData(data,sequenceLength,emb,classNames)}
            IncludeInfo: 0

Предварительно просмотрите преобразованный datastore. Предикторы 1 S C массивами, где S является длиной последовательности, и C является количеством функций (размерность встраивания). Ответы являются категориальными метками.

preview(tdsTrain)
ans=8×2 table
        predictors            responses    
    __________________    _________________

    [1×100×300 single]    Thunderstorm Wind
    [1×100×300 single]    Heavy Rain       
    [1×100×300 single]    Thunderstorm Wind
    [1×100×300 single]    Thunderstorm Wind
    [1×100×300 single]    Thunderstorm Wind
    [1×100×300 single]    Thunderstorm Wind
    [1×100×300 single]    Thunderstorm Wind
    [1×100×300 single]    Thunderstorm Wind

Создайте преобразованный datastore, содержащий данные о валидации в weatherReportsValidation.csv с помощью тех же шагов.

filenameValidation = "weatherReportsValidation.csv";
ttdsValidation = tabularTextDatastore(filenameValidation,'SelectedVariableNames',[textName labelName]);

tdsValidation = transform(ttdsValidation, @(data) transformTextData(data,sequenceLength,emb,classNames))
tdsValidation = 
  TransformedDatastore with properties:

    UnderlyingDatastore: [1×1 matlab.io.datastore.TabularTextDatastore]
             Transforms: {@(data)transformTextData(data,sequenceLength,emb,classNames)}
            IncludeInfo: 0

Архитектура сети Define

Задайте сетевую архитектуру для задачи классификации.

Следующие шаги описывают сетевую архитектуру.

  • Задайте входной размер 1 S C, где S является длиной последовательности, и C является количеством функций (размерность встраивания).

  • Для длин n-граммы 2, 3, 4, и 5, создают блоки слоев, содержащих сверточный слой, пакетный слой нормализации, слой ReLU, слой уволенного и макс. слой объединения.

  • Для каждого блока задайте 200 сверточных фильтров размера 1 на n и областей объединения размера 1 S, где N является длиной n-граммы.

  • Соедините входной слой с каждым блоком и конкатенируйте выходные параметры блоков с помощью слоя конкатенации глубины.

  • Чтобы классифицировать выходные параметры, включайте полносвязный слой с выходным размером K, softmax слоем и слоем классификации, где K является количеством классов.

Во-первых, в массиве слоя, задайте входной слой, первый блок для униграмм, слоя конкатенации глубины, полносвязного слоя, softmax слоя и слоя классификации.

numFeatures = emb.Dimension;
inputSize = [1 sequenceLength numFeatures];
numFilters = 200;

ngramLengths = [2 3 4 5];
numBlocks = numel(ngramLengths);

numClasses = numel(classNames);

Создайте график слоя, содержащий входной слой. Установите опцию нормализации на 'none' и имя слоя к 'input'.

layer = imageInputLayer(inputSize,'Normalization','none','Name','input');
lgraph = layerGraph(layer);

Для каждой из длин n-граммы создайте блок свертки, обработайте в пакетном режиме нормализацию, ReLU, уволенного и макс. объединение слоев. Соедините каждый блок с входным слоем.

for j = 1:numBlocks
    N = ngramLengths(j);
    
    block = [
        convolution2dLayer([1 N],numFilters,'Name',"conv"+N,'Padding','same')
        batchNormalizationLayer('Name',"bn"+N)
        reluLayer('Name',"relu"+N)
        dropoutLayer(0.2,'Name',"drop"+N)
        maxPooling2dLayer([1 sequenceLength],'Name',"max"+N)];
    
    lgraph = addLayers(lgraph,block);
    lgraph = connectLayers(lgraph,'input',"conv"+N);
end

Просмотрите сетевую архитектуру в графике.

figure
plot(lgraph)
title("Network Architecture")

Добавьте слой конкатенации глубины, полносвязный слой, softmax слой и слой классификации.

layers = [
    depthConcatenationLayer(numBlocks,'Name','depth')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','soft')
    classificationLayer('Name','classification')];

lgraph = addLayers(lgraph,layers);

figure
plot(lgraph)
title("Network Architecture")

Соедините макс. слои объединения со слоем конкатенации глубины и просмотрите итоговую сетевую архитектуру в графике.

for j = 1:numBlocks
    N = ngramLengths(j);
    lgraph = connectLayers(lgraph,"max"+N,"depth/in"+j);
end

figure
plot(lgraph)
title("Network Architecture")

Обучение сети

Задайте опции обучения:

  • Обучайтесь в течение 10 эпох с мини-пакетным размером 128.

  • Подтвердите сеть в каждую эпоху путем установки частоты валидации на количество итераций в эпоху.

  • Отобразите учебный график прогресса и подавите многословный вывод.

miniBatchSize = 128;
numIterationsPerEpoch = floor(numObservations/miniBatchSize);

options = trainingOptions('adam', ...
    'MaxEpochs',10, ...
    'MiniBatchSize',miniBatchSize, ...
    'ValidationData',tdsValidation, ...
    'ValidationFrequency',numIterationsPerEpoch, ...
    'Plots','training-progress', ...
    'Verbose',false);

Обучите сеть с помощью функции trainNetwork.

net = trainNetwork(tdsTrain,lgraph,options);

Тестирование сети

Создайте преобразованный datastore, содержащий протянутые тестовые данные в weatherReportsTest.csv.

filenameTest = "weatherReportsTest.csv";
ttdsTest = tabularTextDatastore(filenameTest,'SelectedVariableNames',[textName labelName]);

tdsTest = transform(ttdsTest, @(data) transformTextData(data,sequenceLength,emb,classNames))
tdsTest = 
  TransformedDatastore with properties:

    UnderlyingDatastore: [1×1 matlab.io.datastore.TabularTextDatastore]
             Transforms: {@(data)transformTextData(data,sequenceLength,emb,classNames)}
            IncludeInfo: 0

Считайте метки из tabularTextDatastore.

labelsTest = readLabels(ttdsTest,labelName);
YTest = categorical(labelsTest,classNames);

Сделайте прогнозы на тестовых данных с помощью обучившего сеть.

YPred = classify(net,tdsTest);

Вычислите точность классификации на тестовые данные.

accuracy = mean(YPred == YTest)
accuracy = 0.8670

Функции

Функция readLabels создает копию объекта tabularTextDatastore ttds и читает метки из столбца labelName.

function labels = readLabels(ttds,labelName)

ttdsNew = copy(ttds);
ttdsNew.SelectedVariableNames = labelName;
tbl = readall(ttdsNew);
labels = tbl.(labelName);

end

Функция transformTextData берет данные, считанные из объекта tabularTextDatastore, и возвращает таблицу предикторов и ответов. Предикторы 1 sequenceLength C массивами векторов слова, данных словом, встраивающим emb, где C является размерностью встраивания. Ответы являются категориальными метками по классам в classNames.

function dataTransformed = transformTextData(data,sequenceLength,emb,classNames)

% Preprocess documents.
textData = data{:,1};
textData = lower(textData);
documents = tokenizedDocument(textData);

% Convert documents to embeddingDimension-by-sequenceLength-by-1 images.
predictors = doc2sequence(emb,documents,'Length',sequenceLength);

% Reshape images to be of size 1-by-sequenceLength-embeddingDimension.
predictors = cellfun(@(X) permute(X,[3 2 1]),predictors,'UniformOutput',false);

% Read labels.
labels = data{:,2};
responses = categorical(labels,classNames);

% Convert data to table.
dataTransformed = table(predictors,responses);

end

Смотрите также

| | | | | | | | | |

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте