Обучите нейронную сеть для глубокого обучения классифицировать новые изображения

В этом примере показано, как использовать передачу обучения, чтобы переобучить сверточную нейронную сеть, чтобы классифицировать новый набор изображений.

Предварительно обученные сети классификации изображений были обучены на более чем миллионе изображений и могут классифицировать изображения в 1 000 категорий объектов, таких как клавиатура, кофейная кружка, карандаш и многие животные. Сети изучили богатые представления функции для широкого спектра изображений. Сеть берет изображение в качестве входа, и затем выводит метку для объекта в изображении вместе с вероятностями для каждой из категорий объектов.

Передача обучения обычно используется в применении глубокого обучения. Можно взять предварительно обученную сеть и использовать ее в качестве начальной точки, чтобы изучить новую задачу. Подстройка сети с передачей обучения обычно намного быстрее и легче, чем обучение сети с нуля со случайным образом инициализированными весами. Можно быстро передать изученные функции новой задаче с помощью меньшего числа учебных изображений.

Загрузка данных

Разархивируйте и загрузите новые изображения как datastore изображений. Этот очень небольшой набор данных содержит только 75 изображений. Разделите данные на наборы данных обучения и валидации. Используйте 70% изображений для обучения и 30% для валидации.

unzip('MerchData.zip');
imds = imageDatastore('MerchData', ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames'); 
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7);

Загрузите предварительно обученную сеть

Загрузите предварительно обученную сеть GoogLeNet. Если Модель Deep Learning Toolbox™ для пакета Сетевой поддержки GoogLeNet не установлена, то программное обеспечение обеспечивает ссылку на загрузку.

Чтобы попробовать различную предварительно обученную сеть, откройте этот пример в MATLAB® и выберите различную сеть. Например, можно попробовать squeezenet, сеть, которая является четной быстрее, чем googlenet. Можно запустить этот пример с другими предварительно обученными сетями. Для списка всех доступных сетей смотрите Предварительно обученные сети Загрузки.

net = googlenet;

Используйте analyzeNetwork отобразить интерактивную визуализацию сетевой архитектуры и подробной информации о слоях сети.

analyzeNetwork(net)

Первый элемент Layers свойство сети является входным слоем изображений. Для сети GoogLeNet этот слой требует входных изображений размера 224 224 3, где 3 количество цветовых каналов. Другие сети могут потребовать входных изображений с различными размерами. Например, сеть Xception требует изображений размера 299 299 3.

net.Layers(1)
ans = 
  ImageInputLayer with properties:

                      Name: 'data'
                 InputSize: [224 224 3]

   Hyperparameters
          DataAugmentation: 'none'
             Normalization: 'zerocenter'
    NormalizationDimension: 'auto'
                      Mean: [224×224×3 single]

inputSize = net.Layers(1).InputSize;

Замените последние слои

Сверточные слои сетевого извлечения отображают функции что последний learnable слой и итоговое использование слоя классификации, чтобы классифицировать входное изображение. Эти два слоя, 'loss3-classifier' и 'output' в GoogLeNet содержите информацию о том, как сочетать функции, которые сеть извлекает в вероятности класса, значение потерь и предсказанные метки. Чтобы переобучить предварительно обученную сеть, чтобы классифицировать новые изображения, замените эти два слоя на новые слои, адаптированные к новому набору данных.

Преобразуйте обучивший сеть в график слоев.

lgraph = layerGraph(net);

Найдите, что имена этих двух слоев заменяют. Можно сделать это вручную, или можно использовать функцию поддержки findLayersToReplace, чтобы найти эти слои автоматически.

[learnableLayer,classLayer] = findLayersToReplace(lgraph);
[learnableLayer,classLayer] 
ans = 
  1×2 Layer array with layers:

     1   'loss3-classifier'   Fully Connected         1000 fully connected layer
     2   'output'             Classification Output   crossentropyex with 'tench' and 999 other classes

В большинстве сетей последний слой с learnable весами является полносвязным слоем. Замените этот полносвязный слой на новый полносвязный слой с количеством выходных параметров, равных количеству классов в новом наборе данных (5 в этом примере). В некоторых сетях, таких как SqueezeNet, последний learnable слой является сверточным слоем 1 на 1 вместо этого. В этом случае замените сверточный слой на новый сверточный слой с количеством фильтров, равных количеству классов. Чтобы учиться быстрее в новом слое, чем в переданных слоях, увеличьте факторы скорости обучения слоя.

numClasses = numel(categories(imdsTrain.Labels));

if isa(learnableLayer,'nnet.cnn.layer.FullyConnectedLayer')
    newLearnableLayer = fullyConnectedLayer(numClasses, ...
        'Name','new_fc', ...
        'WeightLearnRateFactor',10, ...
        'BiasLearnRateFactor',10);
    
elseif isa(learnableLayer,'nnet.cnn.layer.Convolution2DLayer')
    newLearnableLayer = convolution2dLayer(1,numClasses, ...
        'Name','new_conv', ...
        'WeightLearnRateFactor',10, ...
        'BiasLearnRateFactor',10);
end

lgraph = replaceLayer(lgraph,learnableLayer.Name,newLearnableLayer);

Слой классификации задает выходные классы сети. Замените слой классификации на новый без меток класса. trainNetwork автоматически устанавливает выходные классы слоя в учебное время.

newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,classLayer.Name,newClassLayer);

Чтобы проверять, что новые слои соединяются правильно, постройте новый график слоев и увеличьте масштаб последних слоев сети.

figure('Units','normalized','Position',[0.3 0.3 0.4 0.4]);
plot(lgraph)
ylim([0,10])

Заморозьте начальные слои

Сеть теперь готова быть переобученной на новом наборе изображений. Опционально, можно "заморозить" веса более ранних слоев в сети, установив скорости обучения в тех слоях обнулить. Во время обучения, trainNetwork не обновляет параметры блокированных слоев. Поскольку градиенты блокированных слоев не должны быть вычислены, замораживание весов многих начальных слоев может значительно ускорить сетевое обучение. Если новый набор данных мал, то замораживание более ранних слоев сети может также препятствовать тому, чтобы те слои сверхсоответствовали к новому набору данных.

Извлеките слои и связи графика слоев и выбора который слои заморозиться. В GoogLeNet первые 10 слоев разбирают начальную 'основу' сети. Используйте функцию поддержки freezeWeights, чтобы обнулить скорости обучения в первых 10 слоях. Используйте функцию поддержки createLgraphUsingConnections, чтобы повторно подключить все слои в первоначальном заказе. Новый график слоев содержит те же слои, но со скоростями обучения более ранних обнуленных слоев.

layers = lgraph.Layers;
connections = lgraph.Connections;

layers(1:10) = freezeWeights(layers(1:10));
lgraph = createLgraphUsingConnections(layers,connections);

Обучение сети

Сеть требует входных изображений размера 224 224 3, но изображения в datastore изображений имеют различные размеры. Используйте увеличенный datastore изображений, чтобы автоматически изменить размер учебных изображений. Задайте дополнительные операции увеличения, чтобы выполнить на учебных изображениях: случайным образом инвертируйте учебные изображения вдоль вертикальной оси и случайным образом переведите их до 30 пикселей и масштабируйте их до 10% горизонтально и вертикально. Увеличение данных помогает препятствовать тому, чтобы сеть сверхсоответствовала и запомнила точные детали учебных изображений.

pixelRange = [-30 30];
scaleRange = [0.9 1.1];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange, ...
    'RandXScale',scaleRange, ...
    'RandYScale',scaleRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    'DataAugmentation',imageAugmenter);

Чтобы автоматически изменить размер изображений валидации, не выполняя дальнейшее увеличение данных, используйте увеличенный datastore изображений, не задавая дополнительных операций предварительной обработки.

augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

Задайте опции обучения. Установите InitialLearnRate к маленькому значению, чтобы замедлить изучение в переданных слоях, которые уже не замораживаются. На предыдущем шаге вы увеличили факторы скорости обучения для последнего learnable слоя, чтобы ускорить изучение в новых последних слоях. Эта комбинация настроек скорости обучения приводит к быстрому изучению в новых слоях, медленнее учась в средних слоях и никакое изучение в ранее, блокированные слои.

Задайте номер эпох, чтобы обучаться для. При использовании обучение с переносом вы не должны обучаться для как много эпох. Эпоха является полным учебным циклом на целом обучающем наборе данных. Задайте мини-пакетный размер и данные о валидации. Вычислите точность валидации однажды в эпоху.

miniBatchSize = 10;
valFrequency = floor(numel(augimdsTrain.Files)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',6, ...
    'InitialLearnRate',3e-4, ...
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',valFrequency, ...
    'Verbose',false, ...
    'Plots','training-progress');

Обучите сеть с помощью обучающих данных. По умолчанию, trainNetwork использует графический процессор, если вы доступны. Это требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox). В противном случае, trainNetwork использует центральный процессор. Можно также задать среду выполнения при помощи 'ExecutionEnvironment' аргумент пары "имя-значение" trainingOptions. Поскольку набор данных так мал, обучение быстро.

net = trainNetwork(augimdsTrain,lgraph,options);

Классифицируйте изображения валидации

Классифицируйте изображения валидации с помощью подстроенной сети и вычислите точность классификации.

[YPred,probs] = classify(net,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels)
accuracy = 0.9000

Отобразите четыре демонстрационных изображения валидации с предсказанными метками и предсказанными вероятностями изображений, имеющих те метки.

idx = randperm(numel(imdsValidation.Files),4);
figure
for i = 1:4
    subplot(2,2,i)
    I = readimage(imdsValidation,idx(i));
    imshow(I)
    label = YPred(idx(i));
    title(string(label) + ", " + num2str(100*max(probs(idx(i),:)),3) + "%");
end

Ссылки

[1] Szegedy, христианин, Вэй Лю, Янцин Цзя, Пьер Сермане, Скотт Рид, Драгомир Ангуелов, Dumitru Erhan, Винсент Вэнхук и Эндрю Рэбинович. "Идя глубже со свертками". В Продолжениях конференции по IEEE по компьютерному зрению и распознаванию образов, стр 1-9. 2015.

Смотрите также

| | | | | | | | |

Похожие темы