В этом примере показано, как сохранить сети контрольных точек во время обучения нейронной сети для глубокого обучения и возобновить обучение из ранее сохраненной сети.
Загрузите выборочные данные как 4-D массив. digitTrain4DArrayData загружает набор обучающих данных цифр как 4-D данные массива. XTrain массив 28 на 28 на 1 на 5000, где 28 высота, и 28 ширина изображений. 1 - количество каналов, а 5000 - количество синтетических изображений рукописных цифр. YTrain - категориальный вектор, содержащий метки для каждого наблюдения.
[XTrain,YTrain] = digitTrain4DArrayData; size(XTrain)
ans = 1×4
28 28 1 5000
Отобразите некоторые изображения в XTrain.
figure; perm = randperm(size(XTrain,4),20); for i = 1:20 subplot(4,5,i); imshow(XTrain(:,:,:,perm(i))); end

Определите архитектуру нейронной сети.
layers = [
imageInputLayer([28 28 1])
convolution2dLayer(3,8,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,16,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,32,'Padding','same')
batchNormalizationLayer
reluLayer
averagePooling2dLayer(7)
fullyConnectedLayer(10)
softmaxLayer
classificationLayer];Задайте опции обучения для стохастического градиентного спуска с импульсом (SGDM) и укажите путь для сохранения сетей контрольных точек.
checkpointPath = pwd; options = trainingOptions('sgdm', ... 'InitialLearnRate',0.1, ... 'MaxEpochs',20, ... 'Verbose',false, ... 'Plots','training-progress', ... 'Shuffle','every-epoch', ... 'CheckpointPath',checkpointPath);
Обучите сеть. trainNetwork использует графический процессор, если он доступен. Если доступного графический процессор нет, то он использует центральный процессор. trainNetwork сохраняет одну сеть контрольных точек каждую эпоху и автоматически присваивает уникальные имена файлам контрольных точек.
net1 = trainNetwork(XTrain,YTrain,layers,options);

Предположим, что обучение было прервано и не завершено. Вместо того, чтобы перезапускать обучение с самого начала, можно загрузить последнюю сеть контрольных точек и возобновить обучение с этой точки. trainNetwork сохраняет файлы контрольных точек с именами файлов в форме net_checkpoint__195__2018_07_13__11_59_10.mat, где 195 - число итерации, 2018_07_13 является датой и 11_59_10 является временем trainNetwork сохранена сеть. Сеть контрольных точек имеет имя переменной net.
Загрузите сеть контрольных точек в рабочую область.
load('net_checkpoint__195__2018_07_13__11_59_10.mat','net')
Задайте опции обучения и уменьшите максимальное количество эпох. Можно также настроить другие опции обучения, такие как начальная скорость обучения.
options = trainingOptions('sgdm', ... 'InitialLearnRate',0.1, ... 'MaxEpochs',15, ... 'Verbose',false, ... 'Plots','training-progress', ... 'Shuffle','every-epoch', ... 'CheckpointPath',checkpointPath);
Возобновите обучение с использованием слоев сети контрольных точек, которые вы загрузили, с новыми опциями обучения. Если сеть контрольных точек является сетью DAG, используйте layerGraph(net) в качестве аргумента вместо net.Layers.
net2 = trainNetwork(XTrain,YTrain,net.Layers,options);

trainingOptions | trainNetwork