exponenta event banner

Возобновить обучение из сети контрольных точек

В этом примере показано, как сохранить сети контрольных точек во время обучения сети глубокого обучения и возобновить обучение из ранее сохраненной сети.

Загрузить данные образца

Загрузите образец данных в виде массива 4-D. digitTrain4DArrayData загружает набор обучающих цифр 4-D виде данных массива. XTrain множество 28 на 28 на 1 на 5000, где 28 высота, и 28 ширина изображений. 1 - количество каналов, а 5000 - количество синтетических изображений рукописных цифр. YTrain - категориальный вектор, содержащий метки для каждого наблюдения.

[XTrain,YTrain] = digitTrain4DArrayData;
size(XTrain)
ans = 1×4

          28          28           1        5000

Отображение некоторых изображений в XTrain.

figure;
perm = randperm(size(XTrain,4),20);
for i = 1:20
    subplot(4,5,i);
    imshow(XTrain(:,:,:,perm(i)));
end

Определение сетевой архитектуры

Определите архитектуру нейронной сети.

layers = [
    imageInputLayer([28 28 1])
    
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer    
    maxPooling2dLayer(2,'Stride',2) 
    
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    averagePooling2dLayer(7)  
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

Укажите параметры обучения и сеть обучения

Укажите параметры обучения для стохастического градиентного спуска с импульсом (SGDM) и укажите путь для сохранения сетей контрольных точек.

checkpointPath = pwd;
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.1, ...
    'MaxEpochs',20, ...
    'Verbose',false, ...
    'Plots','training-progress', ...
    'Shuffle','every-epoch', ...
    'CheckpointPath',checkpointPath);

Обучение сети. trainNetwork использует графический процессор, если он доступен. Если доступный графический процессор отсутствует, он использует CPU. trainNetwork сохраняет одну сеть контрольных точек в течение каждой эпохи и автоматически присваивает уникальные имена файлам контрольных точек.

net1 = trainNetwork(XTrain,YTrain,layers,options);

Загрузить сеть контрольных точек и возобновить обучение

Предположим, что тренировки были прерваны и не завершились. Вместо перезапуска обучения с самого начала можно загрузить последнюю сеть контрольных точек и возобновить обучение с этой точки. trainNetwork сохранение файлов контрольных точек с именами файлов в форме net_checkpoint__195__2018_07_13__11_59_10.mat, где 195 - номер итерации, 2018_07_13 - дата, и 11_59_10 время trainNetwork сохранил сеть. Сеть контрольных точек имеет имя переменной net.

Загрузите сеть контрольных точек в рабочую область.

load('net_checkpoint__195__2018_07_13__11_59_10.mat','net')

Укажите варианты обучения и сократите максимальное количество периодов. Можно также настроить другие варианты обучения, такие как начальная скорость обучения.

options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.1, ...
    'MaxEpochs',15, ...
    'Verbose',false, ...
    'Plots','training-progress', ...
    'Shuffle','every-epoch', ...
    'CheckpointPath',checkpointPath);

Возобновите обучение с использованием уровней сети контрольных точек, загруженных с новыми параметрами обучения. Если сеть контрольных точек является сетью DAG, используйте layerGraph(net) в качестве аргумента вместо net.Layers.

net2 = trainNetwork(XTrain,YTrain,net.Layers,options);

См. также

|

Связанные примеры

Подробнее