В этом примере показано, как обучить сеть с вложенными слоями.
Чтобы создать пользовательский слой, который сам определяет график слоев, можно задать dlnetwork
объект как настраиваемый параметр. Этот способ известен как состав сети. Можно использовать сетевую композицию для:
Создайте один пользовательский слой, который представляет блок обучаемых слоев, например, остаточный блок.
Создайте сеть с потоком управления. Например, сеть с секцией, которая может динамически изменяться в зависимости от входных данных.
Создайте сеть с циклами. Например, сеть с секциями, которые подают выход назад в себя.
Для получения дополнительной информации смотрите Нейронную сеть для глубокого обучения Composition.
В этом примере показано, как обучить сеть с помощью пользовательских слоев, представляющих остаточные блоки, каждый из которых содержит несколько слоев свертки, нормализации группы и ReLU с пропущенным соединением. Пример, показывающий создание невязки сети без использования пользовательских слоев, см. Train Невязки Network for Image Classification.
Остаточные соединения являются популярным элементом в архитектурах сверточных нейронных сетей. Невязка сеть является типом сети, которая имеет невязку (или ярлык) соединения, которые обходят основные слои сети. Использование остаточных соединений улучшает градиентный поток через сеть и позволяет обучать более глубокие сети. Эта увеличенная глубина сети может привести к более высоким точностям для более сложных задач.
Этот пример использует пользовательский слой residualBlockLayer
, который содержит обучаемый блок слоев, состоящий из свертки, нормализации группы, ReLU и слоев сложения, а также включает в себя пропускающее соединение и дополнительный слой свертки и слой нормализации группы в пропускающем соединении. Эта схема подсвечивает структуру остаточных блоков.
Для примера, показывающего, как создать пользовательский слой residualBlockLayer
, см. «Определение вложенного слоя глубокого обучения».
Загрузите и извлеките набор данных Flowers [1].
url = 'http://download.tensorflow.org/example_images/flower_photos.tgz'; downloadFolder = tempdir; filename = fullfile(downloadFolder,'flower_dataset.tgz'); imageFolder = fullfile(downloadFolder,'flower_photos'); if ~exist(imageFolder,'dir') disp('Downloading Flowers data set (218 MB)...') websave(filename,url); untar(filename,downloadFolder) end
Создайте datastore, содержащее фотографии.
datasetFolder = fullfile(imageFolder); imds = imageDatastore(datasetFolder, ... 'IncludeSubfolders',true, ... 'LabelSource','foldernames');
Разделите данные на наборы данных для обучения и валидации. Используйте 70% изображений для обучения и 30% для валидации.
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');
Просмотрите количество классов набора данных.
classes = categories(imds.Labels); numClasses = numel(classes)
numClasses = 5
Увеличение количества данных помогает предотвратить сверхподбор кривой сети и запоминание точных деталей обучающих изображений. Измените размер и увеличьте изображения для обучения с помощью imageDataAugmenter
объект:
Случайным образом отражайте изображения в вертикальной оси.
Случайным образом перемещайте изображения до 30 пикселей вертикально и горизонтально.
Случайным образом поверните изображения до 45 степеней по часовой стрелке и против часовой стрелки.
Случайным образом масштабируйте изображения до 10% по вертикали и горизонтали.
pixelRange = [-30 30]; scaleRange = [0.9 1.1]; imageAugmenter = imageDataAugmenter( ... 'RandXReflection',true, ... 'RandXTranslation',pixelRange, ... 'RandYTranslation',pixelRange, ... 'RandRotation',[-45 45], ... 'RandXScale',scaleRange, ... 'RandYScale',scaleRange);
Создайте дополненный datastore, содержащий обучающие данные, с помощью дополнителя изображений. Чтобы автоматически изменить размер изображений на размер входа сети, укажите высоту и ширину входного размера сети. Этот пример использует сеть с размером входа [224 224 3]
.
inputSize = [224 224 3];
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,'DataAugmentation',imageAugmenter);
Чтобы автоматически изменить размер изображений валидации, не выполняя дальнейшего увеличения данных, используйте хранилище datastore с дополненными изображениями, не задавая никаких дополнительных операций предварительной обработки.
augimdsValidation = augmentedImageDatastore([224 224],imdsValidation);
Определите невязку сеть с шестью блоками невязки с помощью пользовательского слоя residualBlockLayer
. Чтобы получить доступ к этому слою, откройте пример как live скрипт. Пример создания пользовательского слоя см. в разделе «Определение вложенного слоя глубокого обучения».
Поскольку необходимо задать вход размер входа слоя dlnetwork
Объект при создании слоя необходимо задать размер входа. Чтобы помочь определить размер входного сигнала для слоя, можно использовать analyzeNetwork
и проверяйте размер активаций предыдущего слоя.
numFilters = 32; layers = [ imageInputLayer(inputSize) convolution2dLayer(7,numFilters,'Stride',2,'Padding','same') groupNormalizationLayer('all-channels') reluLayer maxPooling2dLayer(3,'Stride',2) residualBlockLayer(numFilters) residualBlockLayer(numFilters) residualBlockLayer(2*numFilters,'Stride',2,'IncludeSkipConvolution',true) residualBlockLayer(2*numFilters) residualBlockLayer(4*numFilters,'Stride',2,'IncludeSkipConvolution',true) residualBlockLayer(4*numFilters) globalAveragePooling2dLayer fullyConnectedLayer(numClasses) softmaxLayer classificationLayer]
layers = 15×1 Layer array with layers: 1 '' Image Input 224×224×3 images with 'zerocenter' normalization 2 '' Convolution 32 7×7 convolutions with stride [2 2] and padding 'same' 3 '' Group Normalization Group normalization 4 '' ReLU ReLU 5 '' Max Pooling 3×3 max pooling with stride [2 2] and padding [0 0 0 0] 6 '' Residual Block Residual block with 32 filters, stride 1 7 '' Residual Block Residual block with 32 filters, stride 1 8 '' Residual Block Residual block with 64 filters, stride 2, and skip convolution 9 '' Residual Block Residual block with 64 filters, stride 1 10 '' Residual Block Residual block with 128 filters, stride 2, and skip convolution 11 '' Residual Block Residual block with 128 filters, stride 1 12 '' Global Average Pooling Global average pooling 13 '' Fully Connected 5 fully connected layer 14 '' Softmax softmax 15 '' Classification Output crossentropyex
Задайте опции обучения:
Обучите сеть с мини-пакетом размером 128.
Перетасовывайте данные каждую эпоху.
Проверяйте сеть один раз в эпоху, используя данные валидации.
Отображение процесса обучения на графике и отключение подробного выхода.
miniBatchSize = 128; numIterationsPerEpoch = floor(augimdsTrain.NumObservations/miniBatchSize); options = trainingOptions('adam', ... 'MiniBatchSize',miniBatchSize, ... 'Shuffle','every-epoch', ... 'ValidationData',augimdsValidation, ... 'ValidationFrequency',numIterationsPerEpoch, ... 'Plots','training-progress', ... 'Verbose',false);
Обучите сеть с помощью trainNetwork
функция. По умолчанию trainNetwork
использует графический процессор, если он доступен, в противном случае используется центральный процессор. Для обучения на графическом процессоре требуется Parallel Computing Toolbox™ и поддерживаемое устройство GPU. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox). Можно также задать окружение выполнения с помощью 'ExecutionEnvironment'
Аргумент пары "имя-значение" из trainingOptions
.
net = trainNetwork(augimdsTrain,layers,options);
Вычислите окончательную точность сети на набор обучающих данных (без увеличения данных) и наборе валидации. Точность является долей изображений, которые сеть классифицирует правильно.
YPred = classify(net,augimdsValidation); YValidation = imdsValidation.Labels; accuracy = mean(YPred == YValidation)
accuracy = 0.7230
Визуализируйте точность классификации в матрице неточностей. Отображение точности и отзыва для каждого класса с помощью сводных данных по столбцам и строкам.
figure confusionchart(YValidation,YPred, ... 'RowSummary','row-normalized', ... 'ColumnSummary','column-normalized');
Можно отобразить четыре изображения валидации образцов с предсказанными метками и предсказанными вероятностями изображений, имеющих эти метки, используя следующий код.
idx = randperm(numel(imdsValidation.Files),4); figure for i = 1:4 subplot(2,2,i) I = readimage(imdsValidation,idx(i)); imshow(I) label = YPred(idx(i)); title("Predicted class: " + string(label)); end
Команда TensorFlow. Цветы http://download.tensorflow.org/example_images/flower_photos.tgz
checkLayer
| trainingOptions
| trainNetwork