Обучите Нейронную сеть для глубокого обучения с вложенными слоями

В этом примере показано, как обучить сеть с вложенными слоями.

Чтобы создать пользовательский слой, который сам определяет график слоев, можно задать dlnetwork объект как настраиваемый параметр. Этот способ известен как состав сети. Можно использовать сетевую композицию для:

  • Создайте один пользовательский слой, который представляет блок обучаемых слоев, например, остаточный блок.

  • Создайте сеть с потоком управления. Например, сеть с секцией, которая может динамически изменяться в зависимости от входных данных.

  • Создайте сеть с циклами. Например, сеть с секциями, которые подают выход назад в себя.

Для получения дополнительной информации смотрите Нейронную сеть для глубокого обучения Composition.

В этом примере показано, как обучить сеть с помощью пользовательских слоев, представляющих остаточные блоки, каждый из которых содержит несколько слоев свертки, нормализации группы и ReLU с пропущенным соединением. Пример, показывающий создание невязки сети без использования пользовательских слоев, см. Train Невязки Network for Image Classification.

Остаточные соединения являются популярным элементом в архитектурах сверточных нейронных сетей. Невязка сеть является типом сети, которая имеет невязку (или ярлык) соединения, которые обходят основные слои сети. Использование остаточных соединений улучшает градиентный поток через сеть и позволяет обучать более глубокие сети. Эта увеличенная глубина сети может привести к более высоким точностям для более сложных задач.

Этот пример использует пользовательский слой residualBlockLayer, который содержит обучаемый блок слоев, состоящий из свертки, нормализации группы, ReLU и слоев сложения, а также включает в себя пропускающее соединение и дополнительный слой свертки и слой нормализации группы в пропускающем соединении. Эта схема подсвечивает структуру остаточных блоков.

Для примера, показывающего, как создать пользовательский слой residualBlockLayer, см. «Определение вложенного слоя глубокого обучения».

Подготовка данных

Загрузите и извлеките набор данных Flowers [1].

url = 'http://download.tensorflow.org/example_images/flower_photos.tgz';
downloadFolder = tempdir;
filename = fullfile(downloadFolder,'flower_dataset.tgz');

imageFolder = fullfile(downloadFolder,'flower_photos');
if ~exist(imageFolder,'dir')
    disp('Downloading Flowers data set (218 MB)...')
    websave(filename,url);
    untar(filename,downloadFolder)
end

Создайте datastore, содержащее фотографии.

datasetFolder = fullfile(imageFolder);
imds = imageDatastore(datasetFolder, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

Разделите данные на наборы данных для обучения и валидации. Используйте 70% изображений для обучения и 30% для валидации.

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');

Просмотрите количество классов набора данных.

classes = categories(imds.Labels);
numClasses = numel(classes)
numClasses = 5

Увеличение количества данных помогает предотвратить сверхподбор кривой сети и запоминание точных деталей обучающих изображений. Измените размер и увеличьте изображения для обучения с помощью imageDataAugmenter объект:

  • Случайным образом отражайте изображения в вертикальной оси.

  • Случайным образом перемещайте изображения до 30 пикселей вертикально и горизонтально.

  • Случайным образом поверните изображения до 45 степеней по часовой стрелке и против часовой стрелки.

  • Случайным образом масштабируйте изображения до 10% по вертикали и горизонтали.

pixelRange = [-30 30];
scaleRange = [0.9 1.1];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange, ...
    'RandRotation',[-45 45], ...
    'RandXScale',scaleRange, ...
    'RandYScale',scaleRange);

Создайте дополненный datastore, содержащий обучающие данные, с помощью дополнителя изображений. Чтобы автоматически изменить размер изображений на размер входа сети, укажите высоту и ширину входного размера сети. Этот пример использует сеть с размером входа [224 224 3].

inputSize = [224 224 3];
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,'DataAugmentation',imageAugmenter);

Чтобы автоматически изменить размер изображений валидации, не выполняя дальнейшего увеличения данных, используйте хранилище datastore с дополненными изображениями, не задавая никаких дополнительных операций предварительной обработки.

augimdsValidation = augmentedImageDatastore([224 224],imdsValidation);

Определение сетевой архитектуры

Определите невязку сеть с шестью блоками невязки с помощью пользовательского слоя residualBlockLayer. Чтобы получить доступ к этому слою, откройте пример как live скрипт. Пример создания пользовательского слоя см. в разделе «Определение вложенного слоя глубокого обучения».

Поскольку необходимо задать вход размер входа слоя dlnetwork Объект при создании слоя необходимо задать размер входа. Чтобы помочь определить размер входного сигнала для слоя, можно использовать analyzeNetwork и проверяйте размер активаций предыдущего слоя.

numFilters = 32;

layers = [
    imageInputLayer(inputSize)
    convolution2dLayer(7,numFilters,'Stride',2,'Padding','same')
    groupNormalizationLayer('all-channels')
    reluLayer
    maxPooling2dLayer(3,'Stride',2)
    residualBlockLayer(numFilters)
    residualBlockLayer(numFilters)
    residualBlockLayer(2*numFilters,'Stride',2,'IncludeSkipConvolution',true)
    residualBlockLayer(2*numFilters)
    residualBlockLayer(4*numFilters,'Stride',2,'IncludeSkipConvolution',true)
    residualBlockLayer(4*numFilters)
    globalAveragePooling2dLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]
layers = 
  15×1 Layer array with layers:

     1   ''   Image Input              224×224×3 images with 'zerocenter' normalization
     2   ''   Convolution              32 7×7 convolutions with stride [2  2] and padding 'same'
     3   ''   Group Normalization      Group normalization
     4   ''   ReLU                     ReLU
     5   ''   Max Pooling              3×3 max pooling with stride [2  2] and padding [0  0  0  0]
     6   ''   Residual Block           Residual block with 32 filters, stride 1
     7   ''   Residual Block           Residual block with 32 filters, stride 1
     8   ''   Residual Block           Residual block with 64 filters, stride 2, and skip convolution
     9   ''   Residual Block           Residual block with 64 filters, stride 1
    10   ''   Residual Block           Residual block with 128 filters, stride 2, and skip convolution
    11   ''   Residual Block           Residual block with 128 filters, stride 1
    12   ''   Global Average Pooling   Global average pooling
    13   ''   Fully Connected          5 fully connected layer
    14   ''   Softmax                  softmax
    15   ''   Classification Output    crossentropyex

Обучите сеть

Задайте опции обучения:

  • Обучите сеть с мини-пакетом размером 128.

  • Перетасовывайте данные каждую эпоху.

  • Проверяйте сеть один раз в эпоху, используя данные валидации.

  • Отображение процесса обучения на графике и отключение подробного выхода.

miniBatchSize = 128;
numIterationsPerEpoch = floor(augimdsTrain.NumObservations/miniBatchSize);

options = trainingOptions('adam', ...
    'MiniBatchSize',miniBatchSize, ...
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',numIterationsPerEpoch, ...
    'Plots','training-progress', ...
    'Verbose',false);

Обучите сеть с помощью trainNetwork функция. По умолчанию trainNetwork использует графический процессор, если он доступен, в противном случае используется центральный процессор. Для обучения на графическом процессоре требуется Parallel Computing Toolbox™ и поддерживаемое устройство GPU. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox). Можно также задать окружение выполнения с помощью 'ExecutionEnvironment' Аргумент пары "имя-значение" из trainingOptions.

net = trainNetwork(augimdsTrain,layers,options);

Оценка обученной сети

Вычислите окончательную точность сети на набор обучающих данных (без увеличения данных) и наборе валидации. Точность является долей изображений, которые сеть классифицирует правильно.

YPred = classify(net,augimdsValidation);
YValidation = imdsValidation.Labels;
accuracy = mean(YPred == YValidation)
accuracy = 0.7230

Визуализируйте точность классификации в матрице неточностей. Отображение точности и отзыва для каждого класса с помощью сводных данных по столбцам и строкам.

figure
confusionchart(YValidation,YPred, ...
    'RowSummary','row-normalized', ...
    'ColumnSummary','column-normalized');

Можно отобразить четыре изображения валидации образцов с предсказанными метками и предсказанными вероятностями изображений, имеющих эти метки, используя следующий код.

idx = randperm(numel(imdsValidation.Files),4);
figure
for i = 1:4
    subplot(2,2,i)
    I = readimage(imdsValidation,idx(i));
    imshow(I)
    label = YPred(idx(i));
    title("Predicted class: " + string(label));
end

Ссылки

См. также

| |

Похожие темы