Этот пример показывает, как подстроить предварительно обученную сеть GoogLeNet, чтобы классифицировать новый набор изображений. Этот процесс называется изучением передачи и обычно намного быстрее и легче, чем обучение новой сети, потому что можно применить изученные функции к новой задаче с помощью меньшего числа учебных изображений. Чтобы в интерактивном режиме подготовить сеть к изучению передачи, используйте Deep Network Designer.
Загрузите предварительно обученную сеть GoogLeNet. Если необходимо загрузить сеть, используйте ссылку на загрузку.
net = googlenet;
Открытый Deep Network Designer.
deepNetworkDesigner
Нажмите Import и выберите сеть из рабочей области. Deep Network Designer отображает уменьшивший масштаб представление целой сети. Исследуйте сетевой график. Чтобы увеличить масштаб с мышью, используйте колесо Ctrl+scroll.
Чтобы переобучить предварительно обученную сеть, чтобы классифицировать новые изображения, замените последние слои на новые слои, адаптированные к новому набору данных. Необходимо изменить количество классов, чтобы оно совпадало с вашими данными.
Перетащите новый fullyConnectedLayer из Библиотеки Слоя на холст. Отредактируйте OutputSize
к количеству классов в новых данных, в этом примере, 5.
Отредактируйте темпы обучения, чтобы учиться быстрее в новых слоях, чем в переданных слоях. Установите WeightLearnRateFactor
и BiasLearnRateFactor
к 10. Удалите последнее, полностью соединенное, и соедините свой новый слой вместо этого.
Замените выходной слой. Прокрутите в конец Библиотеки Слоя и перетащите новый classificationLayer на холст. Удалите исходный слой output
и соедините свой новый слой вместо этого.
Чтобы убедиться ваша отредактированная сеть готова к обучению, нажмите Analyze и обеспечьте Нейронной сети для глубокого обучения нулевые ошибки отчетов Анализатора.
Возвратитесь к Deep Network Designer и нажмите Export. Deep Network Designer экспортирует сеть в новую переменную под названием lgraph_1
, содержащий отредактированные сетевые слои. Можно теперь предоставить переменную слоя к функции trainNetwork
. Можно также сгенерировать код MATLAB®, который воссоздает сетевую архитектуру и возвращает ее как объект layerGraph
или массив Layer
в рабочем пространстве MATLAB.
Разархивируйте и загрузите новые изображения как datastore изображений. Разделите данные на 70% данных тренировки и 30%-х данных о валидации.
unzip('MerchData.zip'); imds = imageDatastore('MerchData','IncludeSubfolders',true,'LabelSource','foldernames'); [imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');
Измените размер изображений, чтобы совпадать с входным размером предварительно обученной сети.
augimdsTrain = augmentedImageDatastore([224 224],imdsTrain); augimdsValidation = augmentedImageDatastore([224 224],imdsValidation);
Задайте опции обучения.
Задайте мини-пакетный размер, то есть, сколько изображений, чтобы использовать в каждой итерации.
Задайте небольшое количество эпох. Эпоха является полным учебным циклом на целом обучающем наборе данных. Для изучения передачи вы не должны обучаться для как много эпох. Переставьте данные каждая эпоха.
Установите InitialLearnRate
на маленькое значение замедлять изучение в переданных слоях.
Задайте данные о валидации и маленькую частоту валидации.
Включите учебный график контролировать прогресс, в то время как вы обучаетесь.
options = trainingOptions('sgdm', ... 'MiniBatchSize',10, ... 'MaxEpochs',6, ... 'Shuffle','every-epoch', ... 'InitialLearnRate',1e-4, ... 'ValidationData',augimdsValidation, ... 'ValidationFrequency',6, ... 'Verbose',false, ... 'Plots','training-progress');
Чтобы обучить сеть, предоставьте слои, экспортируемые из приложения, lgraph_1
, учебных изображений и опций, к функции trainNetwork
. По умолчанию trainNetwork
использует графический процессор при наличии (требует Parallel Computing Toolbox™). В противном случае это использует центральный процессор. Обучение быстро, потому что набор данных является настолько небольшим.
netTransfer = trainNetwork(augimdsTrain,lgraph_1,options);
Классифицируйте изображения валидации с помощью подстроенной сети и вычислите точность классификации.
[YPred,probs] = classify(netTransfer,augimdsValidation); accuracy = mean(YPred == imdsValidation.Labels)
accuracy = 1
Отобразите четыре демонстрационных изображения валидации с предсказанными метками и предсказанными вероятностями.
idx = randperm(numel(augimdsValidation.Files),4); figure for i = 1:4 subplot(2,2,i) I = readimage(imdsValidation,idx(i)); imshow(I) label = YPred(idx(i)); title(string(label) + ", " + num2str(100*max(probs(idx(i),:)),3) + "%"); end
Чтобы узнать больше и попробовать другие предварительно обученные сети, смотрите Deep Network Designer.