В этом примере показано, как подстроить предварительно обученную сеть GoogLeNet, чтобы классифицировать новый набор изображений. Этот процесс называется изучением передачи и обычно намного быстрее и легче, чем обучение новой сети, потому что можно применить изученные функции к новой задаче с помощью меньшего числа учебных изображений. Чтобы в интерактивном режиме подготовить сеть к изучению передачи, используйте Deep Network Designer.
Загрузите предварительно обученную сеть GoogLeNet. Если необходимо загрузить сеть, используйте ссылку на загрузку.
net = googlenet;
Открытый Deep Network Designer.
deepNetworkDesigner
Нажмите Import и выберите сеть из рабочей области. Deep Network Designer отображает уменьшивший масштаб представление целой сети. Исследуйте сетевой график. Чтобы увеличить масштаб с мышью, используйте колесо Ctrl+scroll.
Чтобы переобучить предварительно обученную сеть, чтобы классифицировать новые изображения, замените последние слои на новые слои, адаптированные к новому набору данных. Необходимо изменить количество классов, чтобы оно совпадало с вашими данными.
Перетащите новый fullyConnectedLayer из Библиотеки Слоя на холст. Отредактируйте OutputSize
к количеству классов в новых данных, в этом примере, 5.
Отредактируйте темпы обучения, чтобы учиться быстрее в новых слоях, чем в переданных слоях. Установите WeightLearnRateFactor
и BiasLearnRateFactor
к 10. Удалите последнее, полностью соединенное, и соедините свой новый слой вместо этого.
Замените выходной слой. Прокрутите в конец Библиотеки Слоя и перетащите новый classificationLayer на холст. Удалите исходный output
слой и подключение ваш новый слой вместо этого.
Чтобы убедиться ваша отредактированная сеть готова к обучению, нажмите Analyze и обеспечьте Нейронной сети для глубокого обучения нулевые ошибки отчетов Анализатора.
Возвратитесь к Deep Network Designer и нажмите Export. Deep Network Designer экспортирует сеть в новую переменную под названием lgraph_1
содержа отредактированные сетевые слои. Можно теперь предоставить переменную слоя к trainNetwork
функция. Можно также сгенерировать код MATLAB®, который воссоздает сетевую архитектуру и возвращает ее как layerGraph
возразите или Layer
массив в рабочем пространстве MATLAB.
Разархивируйте и загрузите новые изображения как datastore изображений. Разделите данные на 70% обучающих данных и 30%-х данных о валидации.
unzip('MerchData.zip'); imds = imageDatastore('MerchData','IncludeSubfolders',true,'LabelSource','foldernames'); [imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');
Измените размер изображений, чтобы совпадать с входным размером предварительно обученной сети.
augimdsTrain = augmentedImageDatastore([224 224],imdsTrain); augimdsValidation = augmentedImageDatastore([224 224],imdsValidation);
Задайте опции обучения.
Задайте мини-пакетный размер, то есть, сколько изображений, чтобы использовать в каждой итерации.
Задайте небольшое количество эпох. Эпоха является полным учебным циклом на целом обучающем наборе данных. Для изучения передачи вы не должны обучаться для как много эпох. Переставьте данные каждая эпоха.
Установите InitialLearnRate
к маленькому значению, чтобы замедлить изучение в переданных слоях.
Задайте данные о валидации и маленькую частоту валидации.
Включите учебный график контролировать прогресс, в то время как вы обучаетесь.
options = trainingOptions('sgdm', ... 'MiniBatchSize',10, ... 'MaxEpochs',6, ... 'Shuffle','every-epoch', ... 'InitialLearnRate',1e-4, ... 'ValidationData',augimdsValidation, ... 'ValidationFrequency',6, ... 'Verbose',false, ... 'Plots','training-progress');
Чтобы обучить сеть, предоставьте слои, экспортируемые из приложения, lgraph_1
, учебные изображения и опции, к trainNetwork
функция. По умолчанию, trainNetwork
использует графический процессор при наличии (требует Parallel Computing Toolbox™). В противном случае это использует центральный процессор. Обучение быстро, потому что набор данных так мал.
netTransfer = trainNetwork(augimdsTrain,lgraph_1,options);
Классифицируйте изображения валидации с помощью подстроенной сети и вычислите точность классификации.
[YPred,probs] = classify(netTransfer,augimdsValidation); accuracy = mean(YPred == imdsValidation.Labels)
accuracy = 1
Отобразите четыре демонстрационных изображения валидации с предсказанными метками и предсказанными вероятностями.
idx = randperm(numel(augimdsValidation.Files),4); figure for i = 1:4 subplot(2,2,i) I = readimage(imdsValidation,idx(i)); imshow(I) label = YPred(idx(i)); title(string(label) + ", " + num2str(100*max(probs(idx(i),:)),3) + "%"); end
Чтобы узнать больше и попробовать другие предварительно обученные сети, смотрите Deep Network Designer.