Обучите нейронную сеть для глубокого обучения для классификации новых изображений

Этот пример показывает, как использовать передачу обучения для переобучения сверточной нейронной сети для классификации нового набора изображений.

Предварительно обученные сети классификации изображений были обучены на более чем миллионе изображений и могут классифицировать изображения в 1000 категорий объектов, таких как клавиатура, кофейная кружка, карандаш и многие животные. Сети узнали представления богатых функций для широкой области значений изображений. Сеть принимает изображение как вход, а затем выводит метку для объекта в изображении вместе с вероятностями для каждой из категорий объектов.

Передача обучения обычно используется в применениях глубокого обучения. Можно взять предварительно обученную сеть и использовать ее как начальная точка для изучения новой задачи. Подстройка сети с передачей обучения обычно намного быстрее и проще, чем настройка сети с нуля со случайным образом инициализированными весами. Можно быстро перенести выученные функции в новую задачу с помощью меньшего количества обучающих изображений.

Загрузка данных

Разархивируйте и загружайте новые изображения как image datastore. Этот очень маленький набор данных содержит всего 75 изображений. Разделите данные на наборы данных для обучения и валидации. Используйте 70% изображений для обучения и 30% для валидации.

unzip('MerchData.zip');
imds = imageDatastore('MerchData', ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames'); 
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7);

Загрузка предварительно обученной сети

Загрузите предварительно обученную сеть GoogLeNet. Если Модель Deep Learning Toolbox™ для пакета поддержки GoogLeNet Network не установлена, то программное обеспечение предоставляет ссылку загрузки.

Чтобы попробовать другую предварительно обученную сеть, откройте этот пример в MATLAB ® и выберите другую сеть. Например, можно попробовать squeezenet, сеть, которая даже быстрее, чем googlenet. Этот пример можно запустить с другими предварительно обученными сетями. Список всех доступных сетей см. в разделе Загрузка предварительно обученных сетей.

net = googlenet;

Использование analyzeNetwork отображение интерактивной визуализации сетевой архитектуры и подробной информации о слоях сети.

analyzeNetwork(net)

Первый элемент Layers свойство сети является входным слоем изображения. Для сети GoogLeNet этот слой требует изображений входа размера 224 на 224 на 3, где 3 количество цветных каналов. Другие сети могут потребовать входных изображений с различными размерами. Для примера сети Xception требуются изображения размера 299 299 3.

net.Layers(1)
ans = 
  ImageInputLayer with properties:

                      Name: 'data'
                 InputSize: [224 224 3]

   Hyperparameters
          DataAugmentation: 'none'
             Normalization: 'zerocenter'
    NormalizationDimension: 'auto'
                      Mean: [224×224×3 single]

inputSize = net.Layers(1).InputSize;

Замена конечных слоев

Сверточные слои сети извлекают изображение, функции последний выучиваемый слой и конечный слой классификации используют для классификации входа изображения. Эти два слоя, 'loss3-classifier' и 'output' в GoogLeNet содержат информацию о том, как объединить функции, которые сеть извлекает в вероятности классов, значение потерь и предсказанные метки. Чтобы переобучить предварительно обученную сеть для классификации новых изображений, замените эти два слоя новыми слоями, адаптированными к новому набору данных.

Извлеките график слоев из обученной сети. Если сеть является SeriesNetwork объект, такой как AlexNet, VGG-16 или VGG-19, затем преобразует список слоев в net.Layers в график слоев.

if isa(net,'SeriesNetwork') 
  lgraph = layerGraph(net.Layers); 
else
  lgraph = layerGraph(net);
end 

Найдите имена двух слоев, которые нужно заменить. Вы можете сделать это вручную или использовать вспомогательную функцию findLayersToReplace, чтобы автоматически найти эти слои.

[learnableLayer,classLayer] = findLayersToReplace(lgraph);
[learnableLayer,classLayer] 
ans = 
  1×2 Layer array with layers:

     1   'loss3-classifier'   Fully Connected         1000 fully connected layer
     2   'output'             Classification Output   crossentropyex with 'tench' and 999 other classes

В большинстве сетей последний слой с усвояемыми весами является полносвязным слоем. Замените этот полностью соединенный слой новым полностью соединенным слоем с количеством выходов, равным количеству классов в новом наборе данных (5, в этом примере). В некоторых сетях, таких как SqueezeNet, последний обучаемый слой является сверточным слоем 1 на 1. В этом случае замените сверточный слой на новый сверточный слой с количеством фильтров, равным количеству классов. Чтобы учиться в новом слое быстрее, чем в переданных слоях, увеличьте коэффициенты скорости обучения слоя.

numClasses = numel(categories(imdsTrain.Labels));

if isa(learnableLayer,'nnet.cnn.layer.FullyConnectedLayer')
    newLearnableLayer = fullyConnectedLayer(numClasses, ...
        'Name','new_fc', ...
        'WeightLearnRateFactor',10, ...
        'BiasLearnRateFactor',10);
    
elseif isa(learnableLayer,'nnet.cnn.layer.Convolution2DLayer')
    newLearnableLayer = convolution2dLayer(1,numClasses, ...
        'Name','new_conv', ...
        'WeightLearnRateFactor',10, ...
        'BiasLearnRateFactor',10);
end

lgraph = replaceLayer(lgraph,learnableLayer.Name,newLearnableLayer);

Слой классификации задает выходные классы сети. Замените слой классификации новым слоем без меток классов. trainNetwork автоматически устанавливает выходные классы слоя во время обучения.

newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,classLayer.Name,newClassLayer);

Чтобы проверить, правильно ли соединены новые слои, постройте график нового графика слоев и увеличьте изображение последних слоев сети.

figure('Units','normalized','Position',[0.3 0.3 0.4 0.4]);
plot(lgraph)
ylim([0,10])

Замораживание начальных слоев

Теперь сеть готова к переобучению на новом наборе изображений. Опционально можно «заморозить» веса более ранних слоев в сети, установив скорости обучения в этих слоях на нуль. Во время обучения, trainNetwork не обновляет параметры замороженных слоев. Поскольку градиенты замороженных слоев не нужно вычислять, замораживание весов многих начальных слоев может значительно ускорить сетевое обучение. Если новый набор данных является маленьким, то замораживание ранее слоев сети может также предотвратить сверхподбор кривой этих слоев к новому набору данных.

Извлеките слои и связи графика слоев и выберите, какие слои замораживать. В GoogLeNet первые 10 слоев разобрали начальный 'ствол' сети. Используйте вспомогательную функцию freezeWeights, чтобы задать нулевые скорости обучения в первых 10 слоях. Используйте вспомогательную функцию createLgraphUsingConnections, чтобы повторно соединить все слои в исходном порядке. Новый график слоев содержит те же слои, но со скоростями обучения более ранних слоев, установленной на нуль.

layers = lgraph.Layers;
connections = lgraph.Connections;

layers(1:10) = freezeWeights(layers(1:10));
lgraph = createLgraphUsingConnections(layers,connections);

Обучите сеть

Сеть требует изображений входа размера 224 на 224 на 3, но у изображений в изображении datastore есть различные размеры. Используйте хранилище данных дополненных изображений, чтобы автоматически изменить размер обучающих изображений. Задайте дополнительные операции увеличения для выполнения на обучающих изображениях: случайным образом разверните обучающие изображения вдоль вертикальной оси и случайным образом перемещайте их до 30 пикселей и масштабируйте до 10% по горизонтали и вертикали. Увеличение количества данных помогает предотвратить сверхподбор кривой сети и запоминание точных деталей обучающих изображений.

pixelRange = [-30 30];
scaleRange = [0.9 1.1];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange, ...
    'RandXScale',scaleRange, ...
    'RandYScale',scaleRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    'DataAugmentation',imageAugmenter);

Чтобы автоматически изменить размер изображений валидации, не выполняя дальнейшего увеличения данных, используйте хранилище datastore с дополненными изображениями, не задавая никаких дополнительных операций предварительной обработки.

augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

Задайте опции обучения. Задайте InitialLearnRate к небольшому значению для замедления обучения в перенесенных слоях, которые еще не заморожены. На предыдущем шаге вы увеличили коэффициенты скорости обучения для последнего обучаемого слоя, чтобы ускорить обучение в новых конечных слоях. Эта комбинация настроек скорости обучения приводит к быстрому обучению в новых слоях, более медленному обучению в средних слоях и отсутствию обучения в более ранних, замороженных слоях.

Укажите количество эпох для обучения. При выполнении передачи обучения вам не нужно тренироваться на столько эпох. Эпоха - это полный цикл обучения на целом наборе обучающих данных. Укажите размер мини-пакета и данные валидации. Вычислите точность валидации один раз в эпоху.

miniBatchSize = 10;
valFrequency = floor(numel(augimdsTrain.Files)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',6, ...
    'InitialLearnRate',3e-4, ...
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',valFrequency, ...
    'Verbose',false, ...
    'Plots','training-progress');

Обучите сеть с помощью обучающих данных. По умолчанию trainNetwork использует графический процессор, если он доступен. Для этого требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox). В противном случае trainNetwork использует центральный процессор. Можно также задать окружение выполнения с помощью 'ExecutionEnvironment' Аргумент пары "имя-значение" из trainingOptions. Поскольку набор данных настолько мал, обучение происходит быстро.

net = trainNetwork(augimdsTrain,lgraph,options);

Классификация изображений валидации

Классифицируйте изображения валидации с помощью тонкой настройки сети и вычислите точность классификации.

[YPred,probs] = classify(net,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels)
accuracy = 0.9000

Отобразите четыре изображения валидации выборки с предсказанными метками и предсказанными вероятностями изображений, имеющих эти метки.

idx = randperm(numel(imdsValidation.Files),4);
figure
for i = 1:4
    subplot(2,2,i)
    I = readimage(imdsValidation,idx(i));
    imshow(I)
    label = YPred(idx(i));
    title(string(label) + ", " + num2str(100*max(probs(idx(i),:)),3) + "%");
end

Ссылки

[1] Сегеди, Кристиан, Вэй Лю, Янцин Цзи, Пьер Сермане, Скотт Рид, Драгомир Ангуэлов, Думитру Эрхан, Винсент Ванхукке и Эндрю Рабинович. «Все глубже со свертками». В материалах конференции IEEE по компьютерному зрению и распознаванию шаблонов, стр. 1-9. 2015.

См. также

| | | | | | | | | |

Похожие темы