В этом примере показано, как обучить сеть с вложенными слоями.
Чтобы создать пользовательский слой, который сам задает график слоев, можно задать dlnetwork
возразите как настраиваемый параметр. Этот метод известен как сетевой состав. Можно использовать сетевой состав для:
Создайте один пользовательский слой, который представляет блок learnable слоев, например, остаточный блок.
Создайте сеть с потоком управления. Например, сеть с разделом, который может динамически измениться в зависимости от входных данных.
Создайте сеть с циклами. Например, сеть с разделами, которые подают выход назад в себя.
Для получения дополнительной информации смотрите Состав Нейронной сети для глубокого обучения.
В этом примере показано, как обучить сеть с помощью пользовательских слоев, представляющих остаточные блоки, каждый содержащий несколько свертка, нормализация группы и слои ReLU со связью пропуска. Для примера, показывающего, как создать остаточную сеть, не используя пользовательские слои, смотрите, Обучают Остаточную Сеть для Классификации Изображений.
Остаточные связи являются популярным элементом в архитектурах сверточной нейронной сети. Остаточная сеть является типом сети, которая имеет невязку (или ярлык) связи, которые обходят основные слоя сети. Используя остаточные связи улучшает поток градиента через сеть и включает обучение более глубоких сетей. Эта увеличенная сетевая глубина может дать к более высокой точности на более трудных задачах.
Этот пример использует пользовательский слой residualBlockLayer
, который содержит learnable блок слоев, состоящих из свертки, нормализации группы, ReLU и слоев сложения, и также включает связь пропуска и дополнительный слой свертки и слой нормализации группы в связи пропуска. Эта схема подсвечивает остаточную блочную структуру.
Для примера, показывающего, как создать пользовательский слой residualBlockLayer
, смотрите Задают Вложенный Слой Глубокого обучения.
Загрузите и извлеките Цветочный набор данных [1].
url = 'http://download.tensorflow.org/example_images/flower_photos.tgz'; downloadFolder = tempdir; filename = fullfile(downloadFolder,'flower_dataset.tgz'); imageFolder = fullfile(downloadFolder,'flower_photos'); if ~exist(imageFolder,'dir') disp('Downloading Flowers data set (218 MB)...') websave(filename,url); untar(filename,downloadFolder) end
Создайте datastore изображений, содержащий фотографии.
datasetFolder = fullfile(imageFolder); imds = imageDatastore(datasetFolder, ... 'IncludeSubfolders',true, ... 'LabelSource','foldernames');
Разделите данные в наборы данных обучения и валидации. Используйте 70% изображений для обучения и 30% для валидации.
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');
Просмотрите количество классов набора данных.
classes = categories(imds.Labels); numClasses = numel(classes)
numClasses = 5
Увеличение данных помогает препятствовать тому, чтобы сеть сверхсоответствовала и запомнила точные детали учебных изображений. Измените размер и увеличьте изображения для обучения с помощью imageDataAugmenter
объект:
Случайным образом отразите изображения на вертикальной оси.
Случайным образом переведите изображения до 30 пикселей вертикально и горизонтально.
Случайным образом вращайте изображения до 45 градусов по часовой стрелке и против часовой стрелки.
Случайным образом масштабируйте изображения до 10% вертикально и горизонтально.
pixelRange = [-30 30]; scaleRange = [0.9 1.1]; imageAugmenter = imageDataAugmenter( ... 'RandXReflection',true, ... 'RandXTranslation',pixelRange, ... 'RandYTranslation',pixelRange, ... 'RandRotation',[-45 45], ... 'RandXScale',scaleRange, ... 'RandYScale',scaleRange);
Создайте увеличенный datastore изображений, содержащий обучающие данные с помощью увеличения данных изображения. Чтобы автоматически изменить размер изображений к сетевому входному размеру, задайте высоту и ширину входного размера сети. Этот пример использует сеть с входным размером [224 224 3]
.
inputSize = [224 224 3];
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,'DataAugmentation',imageAugmenter);
Чтобы автоматически изменить размер изображений валидации, не выполняя дальнейшее увеличение данных, используйте увеличенный datastore изображений, не задавая дополнительных операций предварительной обработки.
augimdsValidation = augmentedImageDatastore([224 224],imdsValidation);
Задайте остаточную сеть с шестью остаточными блоками с помощью пользовательского слоя residualBlockLayer
. Чтобы получить доступ к этому слою, откройте пример как live скрипт. Для примера, показывающего, как создать этот пользовательский слой, смотрите, Задают Вложенный Слой Глубокого обучения.
Поскольку необходимо задать входной размер входного слоя dlnetwork
объект, необходимо задать входной размер при создании слоя. Чтобы помочь определить входной размер к слою, можно использовать analyzeNetwork
функционируйте и проверяйте размер активаций предыдущего слоя.
numFilters = 32; layers = [ imageInputLayer(inputSize) convolution2dLayer(7,numFilters,'Stride',2,'Padding','same') groupNormalizationLayer('all-channels') reluLayer maxPooling2dLayer(3,'Stride',2) residualBlockLayer(numFilters) residualBlockLayer(numFilters) residualBlockLayer(2*numFilters,'Stride',2,'IncludeSkipConvolution',true) residualBlockLayer(2*numFilters) residualBlockLayer(4*numFilters,'Stride',2,'IncludeSkipConvolution',true) residualBlockLayer(4*numFilters) globalAveragePooling2dLayer fullyConnectedLayer(numClasses) softmaxLayer classificationLayer]
layers = 15×1 Layer array with layers: 1 '' Image Input 224×224×3 images with 'zerocenter' normalization 2 '' Convolution 32 7×7 convolutions with stride [2 2] and padding 'same' 3 '' Group Normalization Group normalization 4 '' ReLU ReLU 5 '' Max Pooling 3×3 max pooling with stride [2 2] and padding [0 0 0 0] 6 '' Residual Block Residual block with 32 filters, stride 1 7 '' Residual Block Residual block with 32 filters, stride 1 8 '' Residual Block Residual block with 64 filters, stride 2, and skip convolution 9 '' Residual Block Residual block with 64 filters, stride 1 10 '' Residual Block Residual block with 128 filters, stride 2, and skip convolution 11 '' Residual Block Residual block with 128 filters, stride 1 12 '' Global Average Pooling Global average pooling 13 '' Fully Connected 5 fully connected layer 14 '' Softmax softmax 15 '' Classification Output crossentropyex
Задайте опции обучения:
Обучите сеть с мини-пакетным размером 128.
Переставьте данные каждая эпоха.
Проверьте сеть однажды в эпоху с помощью данных о валидации.
Отобразите прогресс обучения в графике и отключите многословный выход.
miniBatchSize = 128; numIterationsPerEpoch = floor(augimdsTrain.NumObservations/miniBatchSize); options = trainingOptions('adam', ... 'MiniBatchSize',miniBatchSize, ... 'Shuffle','every-epoch', ... 'ValidationData',augimdsValidation, ... 'ValidationFrequency',numIterationsPerEpoch, ... 'Plots','training-progress', ... 'Verbose',false);
Обучите сеть с помощью trainNetwork
функция. По умолчанию, trainNetwork
использует графический процессор, если вы доступны, в противном случае, он использует центральный процессор. Обучение на графическом процессоре требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox). Можно также задать среду выполнения при помощи 'ExecutionEnvironment'
аргумент пары "имя-значение" trainingOptions
.
net = trainNetwork(augimdsTrain,layers,options);
Вычислите итоговую точность сети на наборе обучающих данных (без увеличения данных) и набор валидации. Точность является пропорцией изображений, которые сеть классифицирует правильно.
YPred = classify(net,augimdsValidation); YValidation = imdsValidation.Labels; accuracy = mean(YPred == YValidation)
accuracy = 0.7230
Визуализируйте точность классификации в матрице беспорядка. Отобразите точность и отзыв для каждого класса при помощи сводных данных строки и столбца.
figure confusionchart(YValidation,YPred, ... 'RowSummary','row-normalized', ... 'ColumnSummary','column-normalized');
Можно отобразить четыре демонстрационных изображения валидации с предсказанными метками и предсказанными вероятностями изображений, имеющих те метки с помощью следующего кода.
idx = randperm(numel(imdsValidation.Files),4); figure for i = 1:4 subplot(2,2,i) I = readimage(imdsValidation,idx(i)); imshow(I) label = YPred(idx(i)); title("Predicted class: " + string(label)); end
Команда TensorFlow. Цветы http://download.tensorflow.org/example_images/flower_photos.tgz
checkLayer
| trainNetwork
| trainingOptions
| analyzeNetwork
| dlnetwork