В этом примере показано, как обучить сеть, чтобы классифицировать изображения объектов с помощью циклического расписания скорости обучения и создать снимки ensembling для лучшей тестовой точности. В примере вы изучаете, как использовать косинусную функцию для расписания скорости обучения, взять снимки состояния сети во время обучения создать ансамбль модели и добавить регуляризацию L2-нормы (затухание веса) к учебной потере.
Этот пример обучает остаточную сеть [1] на наборе данных CIFAR-10 [2] с пользовательской циклической скоростью обучения: для каждой итерации решатель использует скорость обучения, данную переключенной косинусной функцией [3] alpha(t) = (alpha0/2)*cos(pi*mod(t-1,T/M)/(T/M)+1)
, где t
номер итерации, T
общее количество учебных итераций, alpha0
начальная скорость обучения и M
количество циклов/снимков состояния. Это расписание скорости обучения эффективно разделяет учебный процесс в M
циклы. Каждый цикл начинается с большой скорости обучения, которая затухает монотонно, обеспечивая сеть, чтобы исследовать различные локальные минимумы. В конце каждого учебного цикла вы берете снимок состояния сети (то есть, вы сохраняете модель в этой итерации), и более позднее среднее значение предсказания всех моделей снимка состояния, также известных как ensembling [4] снимка состояния, чтобы улучшить точность завершающего испытания.
Загрузите набор данных CIFAR-10 [2]. Набор данных содержит 60 000 изображений. Каждое изображение находится 32 32 в размере и имеет три цветовых канала (RGB). Размер набора данных составляет 175 Мбайт. В зависимости от вашего интернет-соединения может занять время процесс загрузки.
datadir = tempdir; downloadCIFARData(datadir);
Загрузите обучение CIFAR-10 и протестируйте изображения как 4-D массивы. Набор обучающих данных содержит 50 000 изображений, и набор тестов содержит 10 000 изображений.
[XTrain,YTrain,XTest,YTest] = loadCIFARData(datadir); classes = categories(YTrain); numClasses = numel(classes);
Можно отобразить случайную выборку учебных изображений с помощью следующего кода.
figure;
idx = randperm(size(XTrain,4),20);
im = imtile(XTrain(:,:,:,idx),'ThumbnailSize',[96,96]);
imshow(im)
Создайте augmentedImageDatastore
возразите, чтобы использовать для сетевого обучения. Во время обучения datastore случайным образом инвертирует учебные изображения вдоль вертикальной оси и случайным образом переводит их до четырех пикселей горизонтально и вертикально. Увеличение данных помогает препятствовать тому, чтобы сеть сверхсоответствовала и запомнила точные детали учебных изображений.
imageSize = [32 32 3]; pixelRange = [-4 4]; imageAugmenter = imageDataAugmenter( ... 'RandXReflection',true, ... 'RandXTranslation',pixelRange, ... 'RandYTranslation',pixelRange); augimdsTrain = augmentedImageDatastore(imageSize,XTrain,YTrain, ... 'DataAugmentation',imageAugmenter);
Создайте остаточную сеть [1] с шестью стандартными сверточными модулями (два модуля на этап) и ширина 16. Общая сетевая глубина 2*6+2 = 14. Кроме того, задайте среднее изображение с помощью 'Mean'
опция в изображении ввела слой.
netWidth = 16; layers = [ imageInputLayer(imageSize,'Name','input','Mean', mean(XTrain,4)) convolution2dLayer(3,netWidth,'Padding','same','Name','convInp') batchNormalizationLayer('Name','BNInp') reluLayer('Name','reluInp') convolutionalUnit(netWidth,1,'S1U1') additionLayer(2,'Name','add11') reluLayer('Name','relu11') convolutionalUnit(netWidth,1,'S1U2') additionLayer(2,'Name','add12') reluLayer('Name','relu12') convolutionalUnit(2*netWidth,2,'S2U1') additionLayer(2,'Name','add21') reluLayer('Name','relu21') convolutionalUnit(2*netWidth,1,'S2U2') additionLayer(2,'Name','add22') reluLayer('Name','relu22') convolutionalUnit(4*netWidth,2,'S3U1') additionLayer(2,'Name','add31') reluLayer('Name','relu31') convolutionalUnit(4*netWidth,1,'S3U2') additionLayer(2,'Name','add32') reluLayer('Name','relu32') averagePooling2dLayer(8,'Name','globalPool') fullyConnectedLayer(10,'Name','fcFinal') ]; lgraph = layerGraph(layers); lgraph = connectLayers(lgraph,'reluInp','add11/in2'); lgraph = connectLayers(lgraph,'relu11','add12/in2'); skip1 = [ convolution2dLayer(1,2*netWidth,'Stride',2,'Name','skipConv1') batchNormalizationLayer('Name','skipBN1')]; lgraph = addLayers(lgraph,skip1); lgraph = connectLayers(lgraph,'relu12','skipConv1'); lgraph = connectLayers(lgraph,'skipBN1','add21/in2'); lgraph = connectLayers(lgraph,'relu21','add22/in2'); skip2 = [ convolution2dLayer(1,4*netWidth,'Stride',2,'Name','skipConv2') batchNormalizationLayer('Name','skipBN2')]; lgraph = addLayers(lgraph,skip2); lgraph = connectLayers(lgraph,'relu22','skipConv2'); lgraph = connectLayers(lgraph,'skipBN2','add31/in2'); lgraph = connectLayers(lgraph,'relu31','add32/in2');
Постройте архитектуру ResNet.
figure; plot(lgraph)
Создайте dlnetwork
объект от графика слоев.
dlnet = dlnetwork(lgraph);
Создайте функцию помощника modelGradients
, перечисленный в конце примера. Функция берет в dlnetwork
объект dlnet
and
мини-пакет входных данных dlX
с соответствием маркирует Y,
и возвращает градиенты потери относительно настраиваемых параметров в dlnet
. Эта функция также возвращает потерю и состояние nonlearnable параметров сети в данной итерации.
Задайте опции обучения. Обучайтесь в течение 200 эпох с мини-пакетным размером 64.
numEpochs = 200; miniBatchSize = 64; numObservations = numel(YTrain); velocity = []; momentum = 0.9; weightDecay = 1e-4;
Задайте опции обучения, характерные для циклической скорости обучения. Alpha0
начальная скорость обучения и numSnapshots
количество циклов или создает снимки взятый во время обучения.
alpha0 = 0.1;
numSnapshots = 5;
epochsPerSnapshot = numEpochs./numSnapshots;
iterationsPerSnapshot = ceil(numObservations./miniBatchSize)*numEpochs./numSnapshots;
modelPrefix = "SnapshotEpoch";
Визуализируйте процесс обучения в графике.
plots = "training-progress";
Инициализируйте учебную фигуру.
if plots == "training-progress" [lossLine,learnRateLine] = plotLossAndLearnRate(); end
Используйте minibatchqueue
обработать и управлять мини-пакетами изображений во время обучения. Для каждого мини-пакета:
Используйте пользовательский мини-пакет, предварительно обрабатывающий функциональный preprocessMiniBatch
(заданный в конце этого примера) к одногорячему кодируют метки класса.
Формат данные изображения с размерностью маркирует 'SSCB'
(пространственный, пространственный, канал, пакет). По умолчанию, minibatchqueue
объект преобразует данные в dlarray
объекты с базовым типом single
. Не добавляйте формат в метки класса.
Обучайтесь на графическом процессоре, если вы доступны. По умолчанию, minibatchqueue
объект преобразует каждый выход в gpuArray
если графический процессор доступен. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).
augimdsTrain.MiniBatchSize = miniBatchSize; mbqTrain = minibatchqueue(augimdsTrain,... 'MiniBatchSize',miniBatchSize,... 'MiniBatchFcn', @preprocessMiniBatch,... 'MiniBatchFormat',{'SSCB',''});
Обучите модель с помощью пользовательского учебного цикла. В течение каждой эпохи переставьте datastore, цикл по мини-пакетам данных, и сохраните модель (снимок состояния), если текущая эпоха является кратной epochsPerSnapshot
. В конце каждой эпохи отобразите прогресс обучения. Для каждого мини-пакета:
Оцените градиенты модели и потерю с помощью dlfeval
и modelGradients
функция.
Обновите состояние nonlearnable параметров сети.
Определите скорость обучения для циклического расписания скорости обучения.
Обновите сетевые параметры с помощью sgdmupdate
функция.
Постройте потерю и скорость обучения в каждой итерации.
В данном примере обучение заняло приблизительно 14 часов на NVIDIA™ TITAN RTX.
iteration = 0; start = tic; % Loop over epochs. for epoch = 1:numEpochs % Shuffle data. shuffle(mbqTrain); % Save snapshot model. if ~mod(epoch,epochsPerSnapshot) save(modelPrefix + epoch + ".mat",'dlnet'); end % Loop over mini-batches. while hasdata(mbqTrain) iteration = iteration + 1; % Read mini-batch of data. [dlX,dlY] = next(mbqTrain); % Evaluate the model gradients and loss using dlfeval and the % modelGradients function. [gradients, loss, state] = dlfeval(@modelGradients,dlnet,dlX,dlY,weightDecay); % Update the state of nonlearnable parameters. dlnet.State = state; % Determine learning rate for cyclical learning rate schedule. learnRate = 0.5*alpha0*(cos((pi*mod(iteration-1,iterationsPerSnapshot)./iterationsPerSnapshot))+1); % Update the network parameters using the SGDM optimizer. [dlnet.Learnables, velocity] = sgdmupdate(dlnet.Learnables, gradients, velocity, learnRate, momentum); % Display the training progress. if plots == "training-progress" D = duration(0,0,toc(start),'Format','hh:mm:ss'); addpoints(lossLine,iteration,double(gather(extractdata(loss)))) addpoints(learnRateLine, iteration, learnRate); sgtitle("Epoch: " + epoch + ", Elapsed: " + string(D)) drawnow end end end
Объедините снимки состояния M сети, взятой во время обучения сформировать итоговый ансамбль и протестировать точность классификации модели. Предсказания ансамбля соответствуют среднему значению выхода полносвязного слоя из всех отдельных моделей M.
Протестируйте модель на тестовых данных, которым предоставляют набор данных CIFAR-10. Управляйте набором тестовых данных с помощью minibatchqueue
объект с той же установкой как обучающие данные.
augimdsTest = augmentedImageDatastore(imageSize,XTest,YTest); augimdsTest.MiniBatchSize = miniBatchSize; mbqTest = minibatchqueue(augimdsTest,... 'MiniBatchSize',miniBatchSize,... 'MiniBatchFcn', @preprocessMiniBatch,... 'MiniBatchFormat',{'SSCB',''});
Оцените точность каждой сети снимка состояния. Используйте modelPredictions
функция, определяемая в конце этого примера, чтобы выполнить итерации по всем данным в наборе тестовых данных. Функция возвращает выходной параметр полносвязного слоя из модели, предсказанных классов и сравнения с истинным классом.
modelName = cell(numSnapshots+1,1); fcOutput = zeros(numClasses,numel(YTest),numSnapshots+1); classPredictions = cell(1,numSnapshots+1); modelAccuracy = zeros(numSnapshots+1,1); for m = 1:numSnapshots modelName{m} = modelPrefix + m*epochsPerSnapshot; load(modelName{m} + ".mat"); reset(mbqTest); [fcOutputTest,classPredTest,classCorrTest] = modelPredictions(dlnet,mbqTest,classes); fcOutput(:,:,m) = fcOutputTest; classPredictions{m} = classPredTest; modelAccuracy(m) = 100*mean(classCorrTest); disp(modelName{m} + " accuracy: " + modelAccuracy(m) + "%") end
SnapshotEpoch40 accuracy: 88.35% SnapshotEpoch80 accuracy: 89.93% SnapshotEpoch120 accuracy: 90.51% SnapshotEpoch160 accuracy: 90.33% SnapshotEpoch200 accuracy: 90.63%
Чтобы определить выход сетей ансамбля, вычислите среднее значение полностью связанного выхода каждой сети снимка состояния. Найдите предсказанные классы от сети ансамбля использованием onehotdecode
функция. Сравните с истинными классами, чтобы оценить точность ансамбля.
fcOutput(:,:,end) = mean(fcOutput(:,:,1:end-1),3); classPredictions{end} = onehotdecode(softmax(fcOutput(:,:,end)),classes,1,'categorical'); classCorrEnsemble = classPredictions{end} == YTest'; modelAccuracy(end) = 100*mean(classCorrEnsemble); modelName{end} = "Ensemble model"; disp("Ensemble accuracy: " + modelAccuracy(end) + "%")
Ensemble accuracy: 91.59%
Постройте точность на наборе тестовых данных для всех моделей снимка состояния и модели ансамбля.
figure;bar(modelAccuracy); ylabel('Accuracy (%)'); xticklabels(modelName) xtickangle(45) title('Model accuracy')
modelGradients
функционируйте берет в dlnetwork
объект dlnet
, мини-пакет входных данных dlX
, маркирует Y,
и параметр для затухания веса. Функция возвращает градиенты, потерю и состояние nonlearnable параметров. Чтобы вычислить градиенты автоматически, используйте dlgradient
функция.
function [gradients,loss,state] = modelGradients(dlnet,dlX,Y,weightDecay) [dlYPred,state] = forward(dlnet,dlX); dlYPred = softmax(dlYPred); loss = crossentropy(dlYPred, Y); % L2-regularization (weight decay) allParams = dlnet.Learnables(dlnet.Learnables.Parameter == "Weights" | dlnet.Learnables.Parameter == "Scale",:).Value; l2Norm = cellfun(@(x) sum(x.^2,'All'),allParams,'UniformOutput',false); l2Norm = sum(cat(1,l2Norm{:})); loss = loss + weightDecay*0.5*l2Norm; gradients = dlgradient(loss,dlnet.Learnables); end
modelPredictions
функционируйте берет в качестве входа dlnetwork
объект dlnet
, minibatchqueue
из входных данных mbq
, и вычисляет предсказания модели путем итерации по всем данным в minibatchqueue
. Функция использует onehotdecode
функционируйте, чтобы найти предсказанный класс с самым высоким счетом и затем сравнить предсказание с истинным классом. Функция возвращает сетевой выходной параметр, предсказания класса и вектор из единиц и нулей, который представляет правильные и неправильные предсказания.
function [rawPredictions,classPredictions,classCorr] = modelPredictions(dlnet,mbq,classes) rawPredictions = []; classPredictions = []; classCorr = []; while hasdata(mbq) [dlX,dlY] = next(mbq); % Make predictions dlYPred = predict(dlnet,dlX); rawPredictions = [rawPredictions extractdata(gather(dlYPred))]; % Convert network output to probabilities and determine predicted % classes dlYPred = softmax(dlYPred); YPredBatch = onehotdecode(dlYPred,classes,1); classPredictions = [classPredictions YPredBatch]; % Compare predicted and true classes Y = onehotdecode(dlY,classes,1); classCorr = [classCorr YPredBatch == Y]; end end
plotLossAndLearnRate
функционируйте initiliaizes графики для отображения потери и скорости обучения в каждой итерации во время обучения.
function [lossLine, learnRateLine] = plotLossAndLearnRate() figure subplot(2,1,1); lossLine = animatedline('Color',[0.85 0.325 0.098]); title('Loss'); xlabel('Iteration') ylabel('Loss') grid on subplot(2,1,2); learnRateLine = animatedline('Color',[0 0.447 0.741]); title('Learning rate'); xlabel('Iteration') ylabel('Learning rate') grid on end
convolutionalUnit(numF,stride,tag)
функция создает массив слоев с двумя сверточными слоями и соответствующей нормализацией партии. и слоями ReLU. numF
количество сверточных фильтров, stride
шаг первого сверточного слоя и tag
тег, который предварительно ожидается ко всем именам слоя.
function layers = convolutionalUnit(numF,stride,tag) layers = [ convolution2dLayer(3,numF,'Padding','same','Stride',stride,'Name',[tag,'conv1']) batchNormalizationLayer('Name',[tag,'BN1']) reluLayer('Name',[tag,'relu1']) convolution2dLayer(3,numF,'Padding','same','Name',[tag,'conv2']) batchNormalizationLayer('Name',[tag,'BN2'])]; end
preprocessMiniBatch
функция предварительно обрабатывает данные с помощью следующих шагов:
Извлеките данные изображения из массива входящей ячейки и конкатенируйте в числовой массив. Конкатенация данных изображения по четвертой размерности добавляет третью размерность в каждое изображение, чтобы использоваться в качестве одноэлементной размерности канала.
Извлеките данные о метке из массивов входящей ячейки и конкатенируйте в категориальный массив вдоль второго измерения.
Одногорячий кодируют категориальные метки в числовые массивы. Кодирование в первую размерность производит закодированный массив, который совпадает с формой сетевого выхода.
function [X,Y] = preprocessMiniBatch(XCell,YCell) % Extract image data from cell and concatenate X = cat(4,XCell{:}); % Extract label data from cell and concatenate Y = cat(2,YCell{:}); % One-hot encode labels Y = onehotencode(Y,1); end
[1] Он, Kaiming, Сянюй Чжан, Шаоцин Жэнь и Цзянь Сунь. "Глубокая невязка, учащаяся для распознавания изображений". В Продолжениях конференции по IEEE по компьютерному зрению и распознаванию образов, стр 770-778. 2016.
[2] Krizhevsky, Алекс. "Изучая несколько слоев функций от крошечных изображений". (2009). https://www.cs.toronto.edu / ~ kriz/learning-features-2009-TR.pdf
[3] Лощилов, Илья и Франк Хуттер. "Sgdr: Стохастический градиентный спуск с горячими перезапусками". (2016). arXiv предварительно распечатывают arXiv:1608.03983.
[4] Хуан, Гао, Йиксуэн Ли, Джефф Плейсс, Чжуан Лю, Джон Э. Хопкрофт и Килиан К. Вайнбергер. "Создайте снимки ансамбли: Обучите 1, получите m бесплатно". (2017). arXiv предварительно распечатывают arXiv:1704.00109.
dlnetwork
| layerGraph
| dlarray
| sgdmupdate
| dlfeval
| dlgradient
| sigmoid
| minibatchqueue
| onehotencode
| onehotdecode