В этом примере показано, как сохранить сети контрольной точки, в то время как обучение нейронная сеть для глубокого обучения и возобновляет обучение от ранее сохраненной сети.
Загрузите выборочные данные как 4-D массив. digitTrain4DArrayData
загружает набор обучающих данных цифры как 4-D данные массива. XTrain
28 28 1 5 000 массивов, где 28 высота, и 28 ширина изображений. 1 количество каналов, и 5000 количество синтетических изображений рукописных цифр. YTrain
категориальный вектор, содержащий метки для каждого наблюдения.
[XTrain,YTrain] = digitTrain4DArrayData; size(XTrain)
ans = 1×4
28 28 1 5000
Отобразите некоторые изображения в XTrain
.
figure; perm = randperm(size(XTrain,4),20); for i = 1:20 subplot(4,5,i); imshow(XTrain(:,:,:,perm(i))); end
Задайте архитектуру нейронной сети.
layers = [ imageInputLayer([28 28 1]) convolution2dLayer(3,8,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(2,'Stride',2) convolution2dLayer(3,16,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(2,'Stride',2) convolution2dLayer(3,32,'Padding','same') batchNormalizationLayer reluLayer averagePooling2dLayer(7) fullyConnectedLayer(10) softmaxLayer classificationLayer];
Задайте опции обучения для стохастического градиентного спуска с импульсом (SGDM) и задайте путь для сохранения сетей контрольной точки.
checkpointPath = pwd; options = trainingOptions('sgdm', ... 'InitialLearnRate',0.1, ... 'MaxEpochs',20, ... 'Verbose',false, ... 'Plots','training-progress', ... 'Shuffle','every-epoch', ... 'CheckpointPath',checkpointPath);
Обучите сеть. trainNetwork
использует графический процессор, если существует одно доступное. Если нет никакого доступного графического процессора, то это использует центральный процессор. trainNetwork
сохраняет одну сеть контрольной точки каждая эпоха и автоматически присваивает уникальные имена файлам контрольной точки.
net1 = trainNetwork(XTrain,YTrain,layers,options);
Предположим, что обучение было прервано и не завершалось. Вместо того, чтобы перезапускать обучение с начала, можно загрузить последнюю сеть контрольной точки и возобновить обучение от той точки. trainNetwork
сохраняет файлы контрольной точки с именами файлов на форме net_checkpoint__195__2018_07_13__11_59_10.mat
, где 195 номер итерации, 2018_07_13
дата и 11_59_10
время trainNetwork
сохраненный сеть. Сеть контрольной точки имеет имя переменной net
.
Загрузите сеть контрольной точки в рабочую область.
load('net_checkpoint__195__2018_07_13__11_59_10.mat','net')
Задайте опции обучения и сократите максимальное количество эпох. Можно также настроить другие опции обучения, такие как начальная скорость обучения.
options = trainingOptions('sgdm', ... 'InitialLearnRate',0.1, ... 'MaxEpochs',15, ... 'Verbose',false, ... 'Plots','training-progress', ... 'Shuffle','every-epoch', ... 'CheckpointPath',checkpointPath);
Возобновите обучение с помощью слоев сети контрольной точки, которую вы загрузили с новыми опциями обучения. Если сеть контрольной точки является сетью DAG, то используйте layerGraph(net)
в качестве аргумента вместо net.Layers
.
net2 = trainNetwork(XTrain,YTrain,net.Layers,options);
trainingOptions
| trainNetwork