exponenta event banner

Обучение сети глубокого обучения с вложенными слоями

В этом примере показано, как обучить сеть вложенным слоям.

Чтобы создать пользовательский слой, который сам определяет график слоев, можно указать dlnetwork объект как обучаемый параметр. Этот способ известен как сетчатая композиция. Состав сетевого графика можно использовать для:

  • Создайте один пользовательский слой, представляющий блок обучаемых слоев, например остаточный блок.

  • Создайте сеть с потоком управления. Например, сеть с секцией, которая может динамически изменяться в зависимости от входных данных.

  • Создание сети с контурами. Например, сеть с секциями, которые подают выходные данные обратно в себя.

Дополнительные сведения см. в разделе Состав сети глубокого обучения.

В этом примере показано, как обучить сеть с помощью пользовательских слоев, представляющих остаточные блоки, каждый из которых содержит несколько слоев свертки, нормализации групп и ReLU с соединением пропуска. Пример создания остаточной сети без использования пользовательских слоев см. в разделе Подготовка остаточной сети для классификации изображений.

Остаточные соединения являются популярным элементом в сверточных нейронных сетевых архитектурах. Остаточная сеть - это тип сети, которая имеет остаточные (или короткие) соединения, которые обходят основные сетевые уровни. Использование остаточных соединений улучшает градиентный поток через сеть и позволяет обучать более глубокие сети. Эта увеличенная глубина сети может дать более высокую точность при выполнении более сложных задач.

В этом примере используется пользовательский слой residualBlockLayer, который содержит обучаемый блок уровней, состоящий из слоев свертки, нормализации группы, ReLU и уровней сложения, а также включает в себя соединение пропуска и необязательный уровень свертки и уровень нормализации группы в соединении пропуска. На этой схеме показана остаточная блочная структура.

Пример создания пользовательского слоя residualBlockLayer, см. раздел Определение вложенного слоя глубокого обучения.

Подготовка данных

Загрузите и извлеките набор данных Flowers [1].

url = 'http://download.tensorflow.org/example_images/flower_photos.tgz';
downloadFolder = tempdir;
filename = fullfile(downloadFolder,'flower_dataset.tgz');

imageFolder = fullfile(downloadFolder,'flower_photos');
if ~exist(imageFolder,'dir')
    disp('Downloading Flowers data set (218 MB)...')
    websave(filename,url);
    untar(filename,downloadFolder)
end

Создайте хранилище данных изображения, содержащее фотографии.

datasetFolder = fullfile(imageFolder);
imds = imageDatastore(datasetFolder, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

Разбиение данных на наборы данных обучения и проверки. Используйте 70% изображений для обучения и 30% для проверки.

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');

Просмотр количества классов набора данных.

classes = categories(imds.Labels);
numClasses = numel(classes)
numClasses = 5

Увеличение объема данных помогает предотвратить переоборудование сети и запоминание точных деталей обучающих изображений. Изменение размеров и увеличение изображений для обучения с помощью imageDataAugmenter объект:

  • Случайное отображение изображений по вертикальной оси.

  • Случайное перемещение изображений до 30 пикселей по вертикали и горизонтали.

  • Произвольный поворот изображений до 45 градусов по часовой стрелке и против часовой стрелки.

  • Случайное масштабирование изображений до 10% по вертикали и горизонтали.

pixelRange = [-30 30];
scaleRange = [0.9 1.1];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange, ...
    'RandRotation',[-45 45], ...
    'RandXScale',scaleRange, ...
    'RandYScale',scaleRange);

Создайте хранилище данных дополненного изображения, содержащее обучающие данные, с помощью центра увеличения данных изображения. Для автоматического изменения размеров изображений в соответствии с входным размером сети укажите высоту и ширину входного размера сети. В этом примере используется сеть с входным размером [224 224 3].

inputSize = [224 224 3];
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,'DataAugmentation',imageAugmenter);

Чтобы автоматически изменять размер изображений проверки без дальнейшего увеличения данных, используйте хранилище данных дополненного изображения без указания дополнительных операций предварительной обработки.

augimdsValidation = augmentedImageDatastore([224 224],imdsValidation);

Определение сетевой архитектуры

Определение остаточной сети с шестью остаточными блоками с использованием пользовательского слоя residualBlockLayer. Чтобы получить доступ к этому слою, откройте пример как живой сценарий. Пример создания этого пользовательского слоя см. в разделе Определение вложенного слоя глубокого обучения.

Поскольку необходимо указать входной размер входного слоя dlnetwork при создании слоя необходимо указать размер ввода. Для определения входного размера слоя можно использовать analyzeNetwork и проверьте размер активизаций предыдущего слоя.

numFilters = 32;

layers = [
    imageInputLayer(inputSize)
    convolution2dLayer(7,numFilters,'Stride',2,'Padding','same')
    groupNormalizationLayer('all-channels')
    reluLayer
    maxPooling2dLayer(3,'Stride',2)
    residualBlockLayer(numFilters)
    residualBlockLayer(numFilters)
    residualBlockLayer(2*numFilters,'Stride',2,'IncludeSkipConvolution',true)
    residualBlockLayer(2*numFilters)
    residualBlockLayer(4*numFilters,'Stride',2,'IncludeSkipConvolution',true)
    residualBlockLayer(4*numFilters)
    globalAveragePooling2dLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]
layers = 
  15×1 Layer array with layers:

     1   ''   Image Input              224×224×3 images with 'zerocenter' normalization
     2   ''   Convolution              32 7×7 convolutions with stride [2  2] and padding 'same'
     3   ''   Group Normalization      Group normalization
     4   ''   ReLU                     ReLU
     5   ''   Max Pooling              3×3 max pooling with stride [2  2] and padding [0  0  0  0]
     6   ''   Residual Block           Residual block with 32 filters, stride 1
     7   ''   Residual Block           Residual block with 32 filters, stride 1
     8   ''   Residual Block           Residual block with 64 filters, stride 2, and skip convolution
     9   ''   Residual Block           Residual block with 64 filters, stride 1
    10   ''   Residual Block           Residual block with 128 filters, stride 2, and skip convolution
    11   ''   Residual Block           Residual block with 128 filters, stride 1
    12   ''   Global Average Pooling   Global average pooling
    13   ''   Fully Connected          5 fully connected layer
    14   ''   Softmax                  softmax
    15   ''   Classification Output    crossentropyex

Железнодорожная сеть

Укажите параметры обучения:

  • Обучение сети с размером мини-партии 128.

  • Тасуйте данные каждую эпоху.

  • Проверяйте сеть один раз в эпоху, используя данные проверки.

  • Отображение хода обучения на графике и отключение подробных выходных данных.

miniBatchSize = 128;
numIterationsPerEpoch = floor(augimdsTrain.NumObservations/miniBatchSize);

options = trainingOptions('adam', ...
    'MiniBatchSize',miniBatchSize, ...
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',numIterationsPerEpoch, ...
    'Plots','training-progress', ...
    'Verbose',false);

Обучение сети с помощью trainNetwork функция. По умолчанию trainNetwork использует графический процессор, если он доступен, в противном случае использует центральный процессор. Для обучения графическому процессору требуются параллельные вычислительные Toolbox™ и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox). Можно также указать среду выполнения с помощью 'ExecutionEnvironment' аргумент пары имя-значение trainingOptions.

net = trainNetwork(augimdsTrain,layers,options);

Оценка обученной сети

Рассчитайте окончательную точность сети на обучающем наборе (без увеличения данных) и валидационном наборе. Точность - это доля изображений, которые сеть классифицирует правильно.

YPred = classify(net,augimdsValidation);
YValidation = imdsValidation.Labels;
accuracy = mean(YPred == YValidation)
accuracy = 0.7230

Визуализация точности классификации в матрице путаницы. Отображение точности и отзыва для каждого класса с помощью сводок столбцов и строк.

figure
confusionchart(YValidation,YPred, ...
    'RowSummary','row-normalized', ...
    'ColumnSummary','column-normalized');

Можно отобразить четыре образца подтверждающих изображений с прогнозируемыми метками и прогнозируемыми вероятностями изображений с этими метками, используя следующий код.

idx = randperm(numel(imdsValidation.Files),4);
figure
for i = 1:4
    subplot(2,2,i)
    I = readimage(imdsValidation,idx(i));
    imshow(I)
    label = YPred(idx(i));
    title("Predicted class: " + string(label));
end

Ссылки

См. также

| |

Связанные темы