Этот пример показывает, как использовать передачу, учащуюся переобучать сверточную нейронную сеть, чтобы классифицировать новый набор изображений.
Предварительно обученные сети классификации изображений были обучены на более чем миллионе изображений и могут классифицировать изображения в 1 000 категорий объектов, таких как клавиатура, кофейная кружка, карандаш и многие животные. Сети изучили богатые представления функции для широкого спектра изображений. Сеть берет изображение в качестве входа, и затем выводит метку для объекта в изображении вместе с вероятностями для каждой из категорий объектов.
Изучение передачи обычно используется в применении глубокого обучения. Можно взять предварительно обученную сеть и использовать ее в качестве отправной точки, чтобы изучить новую задачу. Подстройка сети с передачей, учащейся, обычно намного быстрее и легче, чем обучение сети с нуля со случайным образом инициализированными весами. Можно быстро передать изученные функции новой задаче с помощью меньшего числа учебных изображений.
Разархивируйте и загрузите новые изображения как datastore изображений. Этот очень небольшой набор данных содержит только 75 изображений. Разделите данные на наборы данных обучения и валидации. Используйте 70% изображений для обучения и 30% для валидации.
unzip('MerchData.zip'); imds = imageDatastore('MerchData', ... 'IncludeSubfolders',true, ... 'LabelSource','foldernames'); [imdsTrain,imdsValidation] = splitEachLabel(imds,0.7);
Загрузите предварительно обученную сеть GoogLeNet. Если Модель Deep Learning Toolbox™ для пакета Сетевой поддержки GoogLeNet не установлена, то программное обеспечение обеспечивает ссылку на загрузку.
Чтобы попробовать различную предварительно обученную сеть, откройте этот пример в MATLAB® и выберите различную сеть. Например, можно попробовать squeezenet
, сеть, которая еще быстрее, чем googlenet
. Можно запустить этот пример с другими предварительно обученными сетями. Для списка всех доступных сетей смотрите Предварительно обученные сети Загрузки.
net = googlenet;
Используйте analyzeNetwork
, чтобы отобразить интерактивную визуализацию сетевой архитектуры и подробной информации о сетевых слоях.
analyzeNetwork(net)
Первый элемент свойства Layers
сети является входным слоем изображений. Этот слой требует входных изображений размера 224 224 3, где 3 количество цветовых каналов.
net.Layers(1)
ans = ImageInputLayer with properties: Name: 'data' InputSize: [224 224 3] Hyperparameters DataAugmentation: 'none' Normalization: 'zerocenter' AverageImage: [224x224x3 single]
inputSize = net.Layers(1).InputSize;
Сверточные слои сетевого извлечения отображают функции что последний learnable слой и итоговое использование слоя классификации, чтобы классифицировать входное изображение. Эти два слоя, 'loss3-classifier'
и 'output'
в GoogLeNet, содержат информацию о том, как сочетать функции, которые сеть извлекает в вероятности класса, значение потерь и предсказанные метки. Чтобы переобучить предварительно обученную сеть, чтобы классифицировать новые изображения, замените эти два слоя на новые слои, адаптированные к новому набору данных.
Извлеките график слоя от обучившего сеть. Если сеть является объектом SeriesNetwork
, таким как AlexNet, VGG-16 или VGG-19, то преобразовывают список слоев в net.Layers
к графику слоя.
if isa(net,'SeriesNetwork') lgraph = layerGraph(net.Layers); else lgraph = layerGraph(net); end
Найдите, что имена этих двух слоев заменяют. Можно сделать это вручную, или можно использовать функцию поддержки findLayersToReplace, чтобы найти эти слои автоматически.
[learnableLayer,classLayer] = findLayersToReplace(lgraph); [learnableLayer,classLayer]
ans = 1x2 Layer array with layers: 1 'loss3-classifier' Fully Connected 1000 fully connected layer 2 'output' Classification Output crossentropyex with 'tench' and 999 other classes
В большинстве сетей последний слой с learnable весами является полносвязным слоем. Замените этот полносвязный слой на новый полносвязный слой с количеством выходных параметров, равных количеству классов в новом наборе данных (5 в этом примере). В некоторых сетях, таких как SqueezeNet, последний learnable слой является сверточным слоем 1 на 1 вместо этого. В этом случае замените сверточный слой на новый сверточный слой с количеством фильтров, равных количеству классов. Чтобы учиться быстрее в новом слое, чем в переданных слоях, увеличьте факторы темпа обучения слоя.
numClasses = numel(categories(imdsTrain.Labels)); if isa(learnableLayer,'nnet.cnn.layer.FullyConnectedLayer') newLearnableLayer = fullyConnectedLayer(numClasses, ... 'Name','new_fc', ... 'WeightLearnRateFactor',10, ... 'BiasLearnRateFactor',10); elseif isa(learnableLayer,'nnet.cnn.layer.Convolution2DLayer') newLearnableLayer = convolution2dLayer(1,numClasses, ... 'Name','new_conv', ... 'WeightLearnRateFactor',10, ... 'BiasLearnRateFactor',10); end lgraph = replaceLayer(lgraph,learnableLayer.Name,newLearnableLayer);
Слой классификации задает выходные классы сети. Замените слой классификации на новый без меток класса. trainNetwork
автоматически устанавливает выходные классы слоя в учебное время.
newClassLayer = classificationLayer('Name','new_classoutput'); lgraph = replaceLayer(lgraph,classLayer.Name,newClassLayer);
Чтобы проверять, что новые слои соединяются правильно, постройте новый график слоя и увеличьте масштаб последних слоев сети.
figure('Units','normalized','Position',[0.3 0.3 0.4 0.4]); plot(lgraph) ylim([0,10])
Сеть теперь готова быть переобученной на новом наборе изображений. Опционально, можно "заморозить" веса более ранних слоев в сети, установив темпы обучения в тех слоях обнулить. Во время обучения trainNetwork
не обновляет параметры блокированных слоев. Поскольку градиенты блокированных слоев не должны быть вычислены, замораживание весов многих начальных слоев может значительно ускорить сетевое обучение. Если новый набор данных является небольшим, то замораживание более ранних сетевых слоев может также препятствовать тому, чтобы те слои сверхсоответствовали к новому набору данных.
Извлеките слои и связи графика слоя и выбора который слои заморозиться. В GoogLeNet первые 10 слоев разбирают начальную 'основу' сети. Используйте функцию поддержки freezeWeights, чтобы обнулить темпы обучения в первых 10 слоях. Используйте функцию поддержки createLgraphUsingConnections, чтобы повторно подключить все слои в первоначальном заказе. Новый график слоя содержит те же слои, но с темпами обучения более ранних обнуленных слоев.
layers = lgraph.Layers; connections = lgraph.Connections; layers(1:10) = freezeWeights(layers(1:10)); lgraph = createLgraphUsingConnections(layers,connections);
Сеть требует входных изображений размера 224 224 3, но изображения в datastore изображений имеют различные размеры. Используйте увеличенный datastore изображений, чтобы автоматически изменить размер учебных изображений. Задайте дополнительные операции увеличения, чтобы выполнить на учебных изображениях: случайным образом инвертируйте учебные изображения вдоль вертикальной оси и случайным образом переведите их до 30 пикселей и масштабируйте их до 10% горизонтально и вертикально. Увеличение данных помогает препятствовать тому, чтобы сеть сверхсоответствовала и запомнила точные детали учебных изображений.
pixelRange = [-30 30]; scaleRange = [0.9 1.1]; imageAugmenter = imageDataAugmenter( ... 'RandXReflection',true, ... 'RandXTranslation',pixelRange, ... 'RandYTranslation',pixelRange, ... 'RandXScale',scaleRange, ... 'RandYScale',scaleRange); augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ... 'DataAugmentation',imageAugmenter);
Чтобы автоматически изменить размер изображений валидации, не выполняя дальнейшее увеличение данных, используйте увеличенный datastore изображений, не задавая дополнительных операций предварительной обработки.
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
Задайте опции обучения. Установите InitialLearnRate
на маленькое значение замедлять изучение в переданных слоях, которые уже не замораживаются. На предыдущем шаге вы увеличили факторы темпа обучения для последнего learnable слоя, чтобы ускорить изучение в новых последних слоях. Эта комбинация настроек темпа обучения приводит к быстрому изучению в новых слоях, медленнее учась в средних слоях и никакое изучение в ранее, блокированные слои.
Задайте номер эпох, чтобы обучаться для. При использовании обучение с переносом вы не должны обучаться для как много эпох. Эпоха является полным учебным циклом на целом обучающем наборе данных. Задайте мини-пакетный размер и данные о валидации. Вычислите точность валидации однажды в эпоху.
miniBatchSize = 10; valFrequency = floor(numel(augimdsTrain.Files)/miniBatchSize); options = trainingOptions('sgdm', ... 'MiniBatchSize',miniBatchSize, ... 'MaxEpochs',6, ... 'InitialLearnRate',3e-4, ... 'Shuffle','every-epoch', ... 'ValidationData',augimdsValidation, ... 'ValidationFrequency',valFrequency, ... 'Verbose',false, ... 'Plots','training-progress');
Обучите сеть с помощью данных тренировки. По умолчанию trainNetwork
использует графический процессор, если вы доступны (требует Parallel Computing Toolbox™, и CUDA® включил графический процессор с, вычисляют возможность 3.0 или выше). В противном случае trainNetwork
использует центральный процессор. Можно также задать среду выполнения при помощи аргумента пары "имя-значение" 'ExecutionEnvironment'
trainingOptions
. Поскольку набор данных является настолько небольшим, учебный быстро.
net = trainNetwork(augimdsTrain,lgraph,options);
Классифицируйте изображения валидации с помощью подстроенной сети и вычислите точность классификации.
[YPred,probs] = classify(net,augimdsValidation); accuracy = mean(YPred == imdsValidation.Labels)
accuracy = 0.9000
Отобразите четыре демонстрационных изображения валидации с предсказанными метками и предсказанными вероятностями изображений, имеющих те метки.
idx = randperm(numel(imdsValidation.Files),4); figure for i = 1:4 subplot(2,2,i) I = readimage(imdsValidation,idx(i)); imshow(I) label = YPred(idx(i)); title(string(label) + ", " + num2str(100*max(probs(idx(i),:)),3) + "%"); end
[1] Szegedy, христианин, Вэй Лю, Янцин Цзя, Пьер Сермане, Скотт Рид, Драгомир Ангуелов, Dumitru Erhan, Винсент Вэнхук и Эндрю Рэбинович. "Идя глубже со свертками". В Продолжениях конференции по IEEE по компьютерному зрению и распознаванию образов, стр 1-9. 2015.
[2] Модель BVLC GoogLeNet. https://github.com/BVLC/caffe/tree/master/models/bvlc_googlenet
DAGNetwork
| alexnet
| analyzeNetwork
| googlenet
| importCaffeLayers
| importCaffeNetwork
| layerGraph
| plot
| trainNetwork
| vgg16
| vgg19