Передача обучения с использованием предварительно обученной сети

В этом примере показано, как подстроить предварительно обученную сверточную нейронную сеть GoogLeNet для выполнения классификации на новом наборе изображений.

GoogLeNet обучен на более чем миллионе изображений и может классифицировать изображения в 1000 категорий объектов (таких как клавиатура, кофейная кружка, карандаш и многие животные). Сеть изучила представления богатых функций для широкой области значений изображений. Сеть принимает изображение как вход и выводит метку для объекта в изображении вместе с вероятностями для каждой из категорий объектов.

Передача обучения обычно используется в применениях глубокого обучения. Можно взять предварительно обученную сеть и использовать ее как начальная точка для изучения новой задачи. Подстройка сети с передачей обучения обычно намного быстрее и проще, чем обучение сети со случайным образом инициализированными весами с нуля. Можно быстро перенести выученные функции в новую задачу с помощью меньшего количества обучающих изображений.

Загрузка данных

Разархивируйте и загружайте новые изображения как image datastore. imageDatastore автоматически помечает изображения на основе имен папок и сохраняет данные как ImageDatastore объект. image datastore позволяет вам хранить большие данные изображения, включая данные, которые не помещаются в памяти, и эффективно считывать пакеты изображений во время обучения сверточной нейронной сети.

unzip('MerchData.zip');
imds = imageDatastore('MerchData', ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

Разделите данные на наборы данных для обучения и валидации. Используйте 70% изображений для обучения и 30% для валидации. splitEachLabel разделяет image datastore на два новых хранилища данных.

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');

Этот очень небольшой набор данных теперь содержит 55 обучающих изображений и 20 изображений для валидации. Отобразите некоторые образцовые изображения.

numTrainImages = numel(imdsTrain.Labels);
idx = randperm(numTrainImages,16);
figure
for i = 1:16
    subplot(4,4,i)
    I = readimage(imdsTrain,idx(i));
    imshow(I)
end

Загрузка предварительно обученной сети

Загрузите предварительно обученную нейронную сеть GoogLeNet. Если Deep Learning Toolbox™ Model для GoogLeNet Network не установлена, то программное обеспечение предоставляет ссылку загрузки.

net = googlenet;

Использование deepNetworkDesigner отображение интерактивной визуализации сетевой архитектуры и подробной информации о слоях сети.

deepNetworkDesigner(net)

Первый слой, который является слоем входа изображений, требует входа изображений размера 224 224 3, где 3 количество цветовых каналов.

inputSize = net.Layers(1).InputSize
inputSize = 1×3

   224   224     3

Замена конечных слоев

Слой полносвязного слоя и классификации предварительно обученной сети net настроены для 1000 классов. Эти два слоя, loss3-classifier и output в GoogLeNet содержат информацию о том, как объединить функции, которые сеть извлекает в вероятности классов, значение потерь и предсказанные метки. Чтобы переобучить предварительно обученную сеть для классификации новых изображений, замените эти два слоя новыми слоями, адаптированными к новому набору данных.

Извлеките график слоев из обученной сети.

lgraph = layerGraph(net); 

Замените полностью соединенный слой новым полностью соединенным слоем, количество выходов которого равняется количеству классов. Чтобы сделать обучение быстрее в новых слоях, чем в переданных слоях, увеличьте WeightLearnRateFactor и BiasLearnRateFactor значения полносвязного слоя.

numClasses = numel(categories(imdsTrain.Labels))
numClasses = 5
newLearnableLayer = fullyConnectedLayer(numClasses, ...
    'Name','new_fc', ...
    'WeightLearnRateFactor',10, ...
    'BiasLearnRateFactor',10);
    
lgraph = replaceLayer(lgraph,'loss3-classifier',newLearnableLayer);

Слой классификации задает выходные классы сети. Замените слой классификации новым слоем без меток классов. trainNetwork автоматически устанавливает выходные классы слоя во время обучения.

newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,'output',newClassLayer);

Обучите сеть

Сеть требует изображений входа размера 224 на 224 на 3, но у изображений в хранилищах данных изображения есть различные размеры. Используйте хранилище данных дополненных изображений, чтобы автоматически изменить размер обучающих изображений. Задайте дополнительные операции увеличения для выполнения на обучающих изображениях: случайным образом разверните обучающие изображения вдоль вертикальной оси и случайным образом переведите их до 30 пикселей горизонтально и вертикально. Увеличение количества данных помогает предотвратить сверхподбор кривой сети и запоминание точных деталей обучающих изображений.

pixelRange = [-30 30];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    'DataAugmentation',imageAugmenter);

Чтобы автоматически изменить размер изображений валидации, не выполняя дальнейшего увеличения данных, используйте хранилище datastore с дополненными изображениями, не задавая никаких дополнительных операций предварительной обработки.

augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

Задайте опции обучения. Для передачи обучения сохраните функции из ранних слоев предварительно обученной сети (веса переданных слоев). Чтобы замедлить обучение в перенесенных слоях, установите начальную скорость обучения на небольшое значение. На предыдущем шаге вы увеличили коэффициенты скорости обучения для полносвязного слоя, чтобы ускорить обучение в новых конечных слоях. Эта комбинация настроек скорости обучения приводит к быстрому обучению только в новых слоях и более медленному обучению в других слоях. При выполнении передачи обучения вам не нужно тренироваться на столько эпох. Эпоха - это полный цикл обучения на целом наборе обучающих данных. Укажите размер мини-пакета и данные валидации. Программное обеспечение проверяет сеть каждый ValidationFrequency итерации во время обучения.

options = trainingOptions('sgdm', ...
    'MiniBatchSize',10, ...
    'MaxEpochs',6, ...
    'InitialLearnRate',1e-4, ...
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',3, ...
    'Verbose',false, ...
    'Plots','training-progress');

Обучите сеть, состоящую из переданного и нового слоев. По умолчанию trainNetwork использует графический процессор, если он доступен. Для этого требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox). В противном случае используется центральный процессор. Можно также задать окружение выполнения с помощью 'ExecutionEnvironment' Аргумент пары "имя-значение" из trainingOptions.

netTransfer = trainNetwork(augimdsTrain,lgraph,options);

Классификация изображений валидации

Классифицируйте изображения валидации с помощью тонкой настройки сети.

[YPred,scores] = classify(netTransfer,augimdsValidation);

Отобразите четыре изображения валидации образцов с их предсказанными метками.

idx = randperm(numel(imdsValidation.Files),4);
figure
for i = 1:4
    subplot(2,2,i)
    I = readimage(imdsValidation,idx(i));
    imshow(I)
    label = YPred(idx(i));
    title(string(label));
end

Вычислите точность классификации на наборе валидации. Точность является частью меток, которые сеть предсказывает правильно.

YValidation = imdsValidation.Labels;
accuracy = mean(YPred == YValidation)
accuracy = 1

Для советов по повышению точности классификации смотрите Советы по глубокому обучению и Рекомендации.

Ссылки

[1] Крижевский, Алекс, Илья Суцкевер, и Джеффри Э. Хинтон. «Классификация ImageNet с глубокими сверточными нейронными сетями». Усовершенствования в системах нейронной обработки информации 25 (2012).

[2] Сегеди, Кристиан, Вэй Лю, Янцин Цзи, Пьер Сермане, Скотт Рид, Драгомир Ангуэлов, Думитру Эрхан, Винсент Ванхукке и Эндрю Рабинович. «Все глубже со свертками». Материалы конференции IEEE по компьютерному зрению и распознаванию шаблонов (2015): 1-9.

См. также

| | | | |

Похожие темы