Возобновите обучение от сети контрольной точки

В этом примере показано, как сохранить сети контрольной точки, в то время как обучение нейронная сеть для глубокого обучения и возобновляет обучение от ранее сохраненной сети.

Загрузка демонстрационных данных

Загрузите выборочные данные как 4-D массив. digitTrain4DArrayData загружает набор обучающих данных цифры как 4-D данные массива. XTrain 28 28 1 5 000 массивов, где 28 высота, и 28 ширина изображений. 1 количество каналов, и 5000 количество синтетических изображений рукописных цифр. YTrain категориальный вектор, содержащий метки для каждого наблюдения.

[XTrain,YTrain] = digitTrain4DArrayData;
size(XTrain)
ans = 1×4

          28          28           1        5000

Отобразите некоторые изображения в XTrain.

figure;
perm = randperm(size(XTrain,4),20);
for i = 1:20
    subplot(4,5,i);
    imshow(XTrain(:,:,:,perm(i)));
end

Архитектура сети Define

Задайте архитектуру нейронной сети.

layers = [
    imageInputLayer([28 28 1])
    
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer    
    maxPooling2dLayer(2,'Stride',2) 
    
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    averagePooling2dLayer(7)  
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

Задайте опции обучения и обучите сеть

Задайте опции обучения для стохастического градиентного спуска с импульсом (SGDM) и задайте путь для сохранения сетей контрольной точки.

checkpointPath = pwd;
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.1, ...
    'MaxEpochs',20, ...
    'Verbose',false, ...
    'Plots','training-progress', ...
    'Shuffle','every-epoch', ...
    'CheckpointPath',checkpointPath);

Обучите сеть. trainNetwork использует графический процессор, если существует одно доступное. Если нет никакого доступного графического процессора, то это использует центральный процессор. trainNetwork сохраняет одну сеть контрольной точки каждая эпоха и автоматически присваивает уникальные имена файлам контрольной точки.

net1 = trainNetwork(XTrain,YTrain,layers,options);

Загрузите сеть контрольной точки и возобновите обучение

Предположим, что обучение было прервано и не завершалось. Вместо того, чтобы перезапускать обучение с начала, можно загрузить последнюю сеть контрольной точки и возобновить обучение от той точки. trainNetwork сохраняет файлы контрольной точки с именами файлов на форме net_checkpoint__195__2018_07_13__11_59_10.mat, где 195 номер итерации, 2018_07_13 дата и 11_59_10 время trainNetwork сохраненный сеть. Сеть контрольной точки имеет имя переменной net.

Загрузите сеть контрольной точки в рабочую область.

load('net_checkpoint__195__2018_07_13__11_59_10.mat','net')

Задайте опции обучения и сократите максимальное количество эпох. Можно также настроить другие опции обучения, такие как начальная скорость обучения.

options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.1, ...
    'MaxEpochs',15, ...
    'Verbose',false, ...
    'Plots','training-progress', ...
    'Shuffle','every-epoch', ...
    'CheckpointPath',checkpointPath);

Возобновите обучение с помощью слоев сети контрольной точки, которую вы загрузили с новыми опциями обучения. Если сеть контрольной точки является сетью DAG, то используйте layerGraph(net) в качестве аргумента вместо net.Layers.

net2 = trainNetwork(XTrain,YTrain,net.Layers,options);

Смотрите также

|

Связанные примеры

Больше о