Передача обучения Используя предварительно обученную сеть

В этом примере показано, как подстроить предварительно обученную сверточную нейронную сеть GoogLeNet, чтобы выполнить классификацию на новом наборе изображений.

GoogLeNet был обучен на более чем миллионе изображений и может классифицировать изображения в 1 000 категорий объектов (таких как клавиатура, кофейная кружка, карандаш и многие животные). Сеть изучила богатые представления функции для широкого спектра изображений. Сеть берет изображение в качестве входа и выводит метку для объекта в изображении вместе с вероятностями для каждой из категорий объектов.

Передача обучения обычно используется в применении глубокого обучения. Можно взять предварительно обученную сеть и использовать ее в качестве начальной точки, чтобы изучить новую задачу. Подстройка сети с передачей обучения обычно намного быстрее и легче, чем обучение сети со случайным образом инициализированными весами с нуля. Можно быстро передать изученные функции новой задаче с помощью меньшего числа учебных изображений.

Загрузка данных

Разархивируйте и загрузите новые изображения как datastore изображений. imageDatastore автоматически помечает изображения на основе имен папок и хранит данные как ImageDatastore объект. Datastore изображений позволяет вам сохранить большие данные изображения, включая данные, которые не умещаются в памяти, и эффективно считать пакеты изображений во время обучения сверточной нейронной сети.

unzip('MerchData.zip');
imds = imageDatastore('MerchData', ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

Разделите данные на наборы данных обучения и валидации. Используйте 70% изображений для обучения и 30% для валидации. splitEachLabel разделяет datastore изображений в два новых хранилища данных.

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');

Этот очень небольшой набор данных теперь содержит 55 учебных изображений и 20 изображений валидации. Отобразите некоторые демонстрационные изображения.

numTrainImages = numel(imdsTrain.Labels);
idx = randperm(numTrainImages,16);
figure
for i = 1:16
    subplot(4,4,i)
    I = readimage(imdsTrain,idx(i));
    imshow(I)
end

Загрузите предварительно обученную сеть

Загрузите предварительно обученную нейронную сеть GoogLeNet. Если Модель Deep Learning Toolbox™ для Сети GoogLeNet не установлена, то программное обеспечение обеспечивает ссылку на загрузку.

net = googlenet;

Используйте deepNetworkDesigner отобразить интерактивную визуализацию сетевой архитектуры и подробной информации о слоях сети.

deepNetworkDesigner(net)

Первый слой, который является входным слоем изображений, требует входных изображений размера 224 224 3, где 3 количество цветовых каналов.

inputSize = net.Layers(1).InputSize
inputSize = 1×3

   224   224     3

Замените последние слои

Полносвязный слой и слой классификации предварительно обученной сети net сконфигурированы для 1 000 классов. Эти два слоя, loss3-classifier и output в GoogLeNet содержите информацию о том, как сочетать функции, которые сеть извлекает в вероятности класса, значение потерь и предсказанные метки. Чтобы переобучить предварительно обученную сеть, чтобы классифицировать новые изображения, замените эти два слоя на новые слои, адаптированные к новому набору данных.

Извлеките график слоев из обучившего сеть.

lgraph = layerGraph(net); 

Замените полносвязный слой на новый полносвязный слой, который имеет количество выходных параметров, равных количеству классов. Чтобы сделать изучение быстрее в новых слоях, чем в переданных слоях, увеличьте WeightLearnRateFactor и BiasLearnRateFactor значения полносвязного слоя.

numClasses = numel(categories(imdsTrain.Labels))
numClasses = 5
newLearnableLayer = fullyConnectedLayer(numClasses, ...
    'Name','new_fc', ...
    'WeightLearnRateFactor',10, ...
    'BiasLearnRateFactor',10);
    
lgraph = replaceLayer(lgraph,'loss3-classifier',newLearnableLayer);

Слой классификации задает выходные классы сети. Замените слой классификации на новый без меток класса. trainNetwork автоматически устанавливает выходные классы слоя в учебное время.

newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,'output',newClassLayer);

Обучение сети

Сеть требует входных изображений размера 224 224 3, но изображения в хранилищах данных изображений имеют различные размеры. Используйте увеличенный datastore изображений, чтобы автоматически изменить размер учебных изображений. Задайте дополнительные операции увеличения, чтобы выполнить на учебных изображениях: случайным образом инвертируйте учебные изображения вдоль вертикальной оси, и случайным образом переведите их до 30 пикселей горизонтально и вертикально. Увеличение данных помогает препятствовать тому, чтобы сеть сверхсоответствовала и запомнила точные детали учебных изображений.

pixelRange = [-30 30];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    'DataAugmentation',imageAugmenter);

Чтобы автоматически изменить размер изображений валидации, не выполняя дальнейшее увеличение данных, используйте увеличенный datastore изображений, не задавая дополнительных операций предварительной обработки.

augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

Задайте опции обучения. Для передачи обучения сохраните функции от ранних слоев предварительно обученной сети (переданные веса слоя). Чтобы замедлить изучение в переданных слоях, установите начальную скорость обучения на маленькое значение. На предыдущем шаге вы увеличили факторы скорости обучения для полносвязного слоя, чтобы ускорить изучение в новых последних слоях. Эта комбинация настроек скорости обучения приводит к быстрому изучению только в новых слоях и более медленном изучении в других слоях. При использовании обучение с переносом вы не должны обучаться для как много эпох. Эпоха является полным учебным циклом на целом обучающем наборе данных. Задайте мини-пакетный размер и данные о валидации. Программное обеспечение проверяет сеть каждый ValidationFrequency итерации во время обучения.

options = trainingOptions('sgdm', ...
    'MiniBatchSize',10, ...
    'MaxEpochs',6, ...
    'InitialLearnRate',1e-4, ...
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',3, ...
    'Verbose',false, ...
    'Plots','training-progress');

Обучите сеть, состоящую из переданных и новых слоев. По умолчанию, trainNetwork использует графический процессор, если вы доступны (требует Parallel Computing Toolbox™, и CUDA® включил графический процессор с, вычисляют возможность 3.0 или выше). В противном случае это использует центральный процессор. Можно также задать среду выполнения при помощи 'ExecutionEnvironment' аргумент пары "имя-значение" trainingOptions.

netTransfer = trainNetwork(augimdsTrain,lgraph,options);

Классифицируйте изображения валидации

Классифицируйте изображения валидации с помощью подстроенной сети.

[YPred,scores] = classify(netTransfer,augimdsValidation);

Отобразите четыре демонстрационных изображения валидации с их предсказанными метками.

idx = randperm(numel(imdsValidation.Files),4);
figure
for i = 1:4
    subplot(2,2,i)
    I = readimage(imdsValidation,idx(i));
    imshow(I)
    label = YPred(idx(i));
    title(string(label));
end

Вычислите точность классификации на набор валидации. Точность является частью меток, которые сеть предсказывает правильно.

YValidation = imdsValidation.Labels;
accuracy = mean(YPred == YValidation)
accuracy = 1

Для советов на улучшающейся точности классификации смотрите Советы Глубокого обучения и Приемы.

Ссылки

[1] Krizhevsky, Алекс, Илья Сутскевер и Джеффри Э. Хинтон. "Классификация ImageNet с Глубокими Сверточными нейронными сетями". Усовершенствования в нейронных системах обработки информации 25 (2012).

[2] Szegedy, христианин, Вэй Лю, Янцин Цзя, Пьер Сермане, Скотт Рид, Драгомир Ангуелов, Dumitru Erhan, Винсент Вэнхук и Эндрю Рэбинович. "Идя глубже со свертками". Продолжения конференции по IEEE по компьютерному зрению и распознаванию образов (2015): 1–9.

Смотрите также

| | | | |

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте