exponenta event banner

Обучение сети глубокого обучения классификации новых образов

В этом примере показано, как использовать обучение передаче для переподготовки сверточной нейронной сети для классификации нового набора изображений.

Предварительно подготовленные сети классификации изображений были обучены более чем миллиону изображений и могут классифицировать изображения на 1000 категорий объектов, таких как клавиатура, кофейная кружка, карандаш и многие животные. Сети изучили богатые представления функций для широкого спектра изображений. Сеть принимает изображение в качестве входного, а затем выводит метку для объекта в изображении вместе с вероятностями для каждой из категорий объектов.

Transfer learning обычно используется в приложениях для глубокого обучения. Предварительно подготовленную сеть можно использовать в качестве отправной точки для изучения новой задачи. Точная настройка сети с обучением переносу обычно намного быстрее и проще, чем обучение сети с нуля с произвольно инициализированными весами. Вы можете быстро перенести изученные функции на новую задачу, используя меньшее количество обучающих изображений.

Загрузить данные

Распакуйте и загрузите новые образы как хранилище данных образов. Этот очень маленький набор данных содержит только 75 изображений. Разделите данные на наборы данных обучения и проверки. Используйте 70% изображений для обучения и 30% для проверки.

unzip('MerchData.zip');
imds = imageDatastore('MerchData', ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames'); 
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7);

Загрузить предварительно обученную сеть

Загрузите предварительно подготовленную сеть GoogLeNet. Если не установлен пакет поддержки Deep Learning Toolbox™ Model для сети GoogLeNet, то программное обеспечение предоставляет ссылку на загрузку.

Чтобы попробовать другую предварительно подготовленную сеть, откройте этот пример в MATLAB ® и выберите другую сеть. Например, можно попробоватьsqueezenet, сеть, которая даже быстрее, чем googlenet. Этот пример можно выполнить с другими предварительно подготовленными сетями. Список всех доступных сетей см. в разделе Загрузка предварительно обученных сетей.

net = googlenet;

Использовать analyzeNetwork отображение интерактивной визуализации сетевой архитектуры и подробной информации о сетевых уровнях.

analyzeNetwork(net)

Первый элемент Layers свойство сети - уровень ввода изображения. Для сети GoogLeNet этот слой требует входных изображений размера 224 на 224 на 3, где 3 количество цветных каналов. Для других сетей могут потребоваться входные изображения различных размеров. Например, сеть Xception требует изображений размера 299 на 299 на 3.

net.Layers(1)
ans = 
  ImageInputLayer with properties:

                      Name: 'data'
                 InputSize: [224 224 3]

   Hyperparameters
          DataAugmentation: 'none'
             Normalization: 'zerocenter'
    NormalizationDimension: 'auto'
                      Mean: [224×224×3 single]

inputSize = net.Layers(1).InputSize;

Заменить конечные слои

Сверточные уровни сетевого извлеченного изображения отличаются тем, что последний обучаемый уровень и конечный уровень классификации используют для классификации входного изображения. Эти два слоя, 'loss3-classifier' и 'output' в GoogLeNet содержат информацию о том, как объединить функции, извлекаемые сетью, в вероятности классов, значение потерь и прогнозируемые метки. Для переподготовки заранее обученной сети для классификации новых изображений замените эти два слоя новыми слоями, адаптированными к новому набору данных.

Извлеките график уровня из обученной сети. Если сеть является сетью SeriesNetwork объекты, такие как AlexNet, VGG-16 или VGG-19, затем преобразуют список слоев в net.Layers к графу слоев.

if isa(net,'SeriesNetwork') 
  lgraph = layerGraph(net.Layers); 
else
  lgraph = layerGraph(net);
end 

Найдите имена двух заменяемых слоев. Это можно сделать вручную или использовать вспомогательную функцию findSharingToReplace для автоматического поиска этих слоев.

[learnableLayer,classLayer] = findLayersToReplace(lgraph);
[learnableLayer,classLayer] 
ans = 
  1×2 Layer array with layers:

     1   'loss3-classifier'   Fully Connected         1000 fully connected layer
     2   'output'             Classification Output   crossentropyex with 'tench' and 999 other classes

В большинстве сетей последний уровень с обучаемыми весами является полностью подключенным. Замените этот полностью подключенный уровень новым полностью подключенным уровнем с количеством выходов, равным количеству классов в новом наборе данных (5, в данном примере). В некоторых сетях, таких как SqueeEcnet, последним обучаемым уровнем является сверточный уровень 1 на 1. В этом случае замените сверточный слой новым сверточным слоем с количеством фильтров, равным числу классов. Чтобы быстрее учиться в новом слое, чем в перенесенных слоях, увеличьте коэффициенты скорости обучения слоя.

numClasses = numel(categories(imdsTrain.Labels));

if isa(learnableLayer,'nnet.cnn.layer.FullyConnectedLayer')
    newLearnableLayer = fullyConnectedLayer(numClasses, ...
        'Name','new_fc', ...
        'WeightLearnRateFactor',10, ...
        'BiasLearnRateFactor',10);
    
elseif isa(learnableLayer,'nnet.cnn.layer.Convolution2DLayer')
    newLearnableLayer = convolution2dLayer(1,numClasses, ...
        'Name','new_conv', ...
        'WeightLearnRateFactor',10, ...
        'BiasLearnRateFactor',10);
end

lgraph = replaceLayer(lgraph,learnableLayer.Name,newLearnableLayer);

Уровень классификации определяет выходные классы сети. Замените классификационный слой новым без меток классов. trainNetwork автоматически устанавливает выходные классы слоя во время обучения.

newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,classLayer.Name,newClassLayer);

Чтобы проверить правильность соединения новых слоев, постройте график новых слоев и увеличьте изображение последних слоев сети.

figure('Units','normalized','Position',[0.3 0.3 0.4 0.4]);
plot(lgraph)
ylim([0,10])

Заморозить начальные слои

Сеть теперь готова переквалифицироваться на новый набор изображений. Дополнительно можно «заморозить» веса более ранних слоев в сети, установив нулевые показатели обучения в этих слоях. Во время обучения, trainNetwork не обновляет параметры замороженных слоев. Поскольку градиенты замороженных слоев не нуждаются в вычислении, замораживание весов многих начальных слоев может значительно ускорить тренировку сети. Если новый набор данных мал, то замораживание более ранних сетевых уровней также может предотвратить перенастройку этих уровней в новый набор данных.

Извлеките слои и соединения графика слоев и выберите слои для замораживания. В GoogLeNet первые 10 уровней составляют начальный «стебель» сети. Используйте функцию поддержки freezeWeights, чтобы установить темпы обучения в ноль в первых 10 слоях. Используйте вспомогательную функцию createLgraphUsingConnections для повторного подключения всех слоев в исходном порядке. Новый график слоев содержит те же самые слои, но скорость обучения более ранних слоев установлена равной нулю.

layers = lgraph.Layers;
connections = lgraph.Connections;

layers(1:10) = freezeWeights(layers(1:10));
lgraph = createLgraphUsingConnections(layers,connections);

Железнодорожная сеть

Сеть требует входных изображений размера 224 на 224 на 3, но у изображений в хранилище данных изображения есть различные размеры. Используйте хранилище данных дополненного изображения для автоматического изменения размеров обучающих изображений. Укажите дополнительные операции по дополнению обучающих изображений: случайным образом переверните обучающие изображения вдоль вертикальной оси и произвольно переместите их до 30 пикселей и масштабируйте до 10% по горизонтали и вертикали. Увеличение объема данных помогает предотвратить переоборудование сети и запоминание точных деталей обучающих изображений.

pixelRange = [-30 30];
scaleRange = [0.9 1.1];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange, ...
    'RandXScale',scaleRange, ...
    'RandYScale',scaleRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    'DataAugmentation',imageAugmenter);

Чтобы автоматически изменять размер изображений проверки без дальнейшего увеличения данных, используйте хранилище данных дополненного изображения без указания дополнительных операций предварительной обработки.

augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

Укажите параметры обучения. Набор InitialLearnRate до небольшого значения, чтобы замедлить обучение в перенесенных слоях, которые еще не заморожены. На предыдущем шаге вы увеличили коэффициенты скорости обучения для последнего обучаемого уровня, чтобы ускорить обучение на новых конечных уровнях. Это сочетание настроек скорости обучения приводит к быстрому обучению в новых слоях, более медленному обучению в средних слоях и отсутствию обучения в более ранних замороженных слоях.

Укажите количество эпох для обучения. При выполнении трансферного обучения не нужно тренироваться на столько же эпох. Эпоха - это полный цикл обучения по всему набору данных обучения. Укажите размер мини-партии и данные проверки. Вычислите точность проверки один раз в эпоху.

miniBatchSize = 10;
valFrequency = floor(numel(augimdsTrain.Files)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',6, ...
    'InitialLearnRate',3e-4, ...
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',valFrequency, ...
    'Verbose',false, ...
    'Plots','training-progress');

Обучение сети с использованием данных обучения. По умолчанию trainNetwork использует графический процессор, если он доступен. Для этого требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox). В противном случае trainNetwork использует ЦП. Можно также указать среду выполнения с помощью 'ExecutionEnvironment' аргумент пары имя-значение trainingOptions. Поскольку набор данных настолько мал, обучение происходит быстро.

net = trainNetwork(augimdsTrain,lgraph,options);

Классифицировать изображения проверки

Классифицируйте изображения проверки с помощью отлаженной сети и рассчитайте точность классификации.

[YPred,probs] = classify(net,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels)
accuracy = 0.9000

Отображение четырех образцов проверочных изображений с предсказанными метками и предсказанными вероятностями изображений, имеющих эти метки.

idx = randperm(numel(imdsValidation.Files),4);
figure
for i = 1:4
    subplot(2,2,i)
    I = readimage(imdsValidation,idx(i));
    imshow(I)
    label = YPred(idx(i));
    title(string(label) + ", " + num2str(100*max(probs(idx(i),:)),3) + "%");
end

Ссылки

[1] Сегеди, Кристиан, Вэй Лю, Янцин Цзя, Пьер Серманет, Скотт Рид, Драгомир Ангуэлов, Думитру Эрхан, Венсан Ванхуке и Эндрю Рабинович. «Углубляюсь со свертками.» В материалах конференции IEEE по компьютерному зрению и распознаванию образов, стр. 1-9. 2015.

См. также

| | | | | | | | | |

Связанные темы