Настройте Вывод во время обучения нейронной сети для глубокого обучения

Этот пример показывает, как задать выходную функцию, которая запускается в каждой итерации во время обучения глубоких нейронных сетей. Если вы задаете выходные функции при помощи аргумента пары "имя-значение" 'OutputFcn' trainingOptions, то trainNetwork вызывает эти функции однажды запуск обучения после каждой учебной итерации, и однажды после того, как обучение закончилось. Каждый раз, когда выходные функции называются, trainNetwork передает структуру, содержащую информацию, такую как текущий номер итерации, потеря и точность. Можно использовать выходные функции, чтобы отобразить или построить информацию о прогрессе или остановить обучение. Чтобы остановить обучение рано, заставьте свою выходную функцию возвратить true. Если какая-либо выходная функция возвращает true, то обучение закончило and trainNetwork returns последняя сеть.

Чтобы остановиться обучение, когда потеря на наборе валидации прекратит уменьшаться, просто задайте данные о валидации и терпение валидации с помощью 'ValidationData' и аргументов пары "имя-значение" 'ValidationPatience' trainingOptions, соответственно. Терпение валидации является числом раз, что потеря на наборе валидации может быть больше, чем или равняться ранее самой маленькой потере, прежде чем сетевое обучение остановится. Можно добавить дополнительный критерий остановки с помощью выходных функций. Этот пример показывает, как создать выходную функцию, которая останавливает обучение, когда точность классификации на данных о валидации прекращает улучшаться. Выходная функция задана в конце скрипта.

Загрузите данные тренировки, который содержит 5 000 изображений цифр. Отложите 1000 из изображений для сетевой валидации.

[XTrain,YTrain] = digitTrain4DArrayData;

idx = randperm(size(XTrain,4),1000);
XValidation = XTrain(:,:,:,idx);
XTrain(:,:,:,idx) = [];
YValidation = YTrain(idx);
YTrain(idx) = [];

Создайте сеть, чтобы классифицировать данные изображения цифры.

layers = [
    imageInputLayer([28 28 1])
    
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

Задайте опции для сетевого обучения. Чтобы подтвердить сеть равномерно во время обучения, задайте данные о валидации. Выберите значение 'ValidationFrequency' так, чтобы сеть была подтверждена однажды в эпоху.

Чтобы остановить обучение, когда точность классификации на наборе валидации прекратит улучшаться, задайте stopIfAccuracyNotImproving как выходную функцию. Второй входной параметр stopIfAccuracyNotImproving является числом раз, что точность на наборе валидации может быть меньшей, чем или равняться ранее самой высокой точности, прежде чем сетевое обучение остановится. Выберите любое большое значение для максимального количества эпох, чтобы обучаться. Обучение не должно достигать итоговой эпохи, потому что обучение останавливается автоматически.

miniBatchSize = 128;
validationFrequency = floor(numel(YTrain)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.01, ...
    'MaxEpochs',100, ...
    'MiniBatchSize',miniBatchSize, ...
    'VerboseFrequency',validationFrequency, ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',validationFrequency, ...
    'Plots','training-progress', ...
    'OutputFcn',@(info)stopIfAccuracyNotImproving(info,3));

Обучите сеть. Обучение останавливается, когда точность валидации прекращает увеличиваться.

net = trainNetwork(XTrain,YTrain,layers,options);
Training on single CPU.
Initializing input data normalization.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:00 |       10.94% |       13.20% |       2.9520 |       2.5400 |          0.0100 |
|       1 |          31 |       00:00:04 |       71.09% |       75.40% |       0.8453 |       0.8356 |          0.0100 |
|       2 |          62 |       00:00:07 |       91.41% |       89.20% |       0.3514 |       0.4304 |          0.0100 |
|       3 |          93 |       00:00:10 |       96.88% |       94.20% |       0.1887 |       0.2572 |          0.0100 |
|       4 |         124 |       00:00:12 |       99.22% |       96.20% |       0.1189 |       0.1927 |          0.0100 |
|       5 |         155 |       00:00:15 |      100.00% |       96.80% |       0.0880 |       0.1566 |          0.0100 |
|       6 |         186 |       00:00:18 |      100.00% |       97.10% |       0.0614 |       0.1226 |          0.0100 |
|       7 |         217 |       00:00:21 |       99.22% |       97.90% |       0.0566 |       0.1017 |          0.0100 |
|       8 |         248 |       00:00:24 |       99.22% |       98.20% |       0.0476 |       0.0863 |          0.0100 |
|       9 |         279 |       00:00:27 |      100.00% |       98.60% |       0.0334 |       0.0740 |          0.0100 |
|      10 |         310 |       00:00:30 |      100.00% |       98.80% |       0.0267 |       0.0645 |          0.0100 |
|      11 |         341 |       00:00:32 |      100.00% |       98.80% |       0.0226 |       0.0567 |          0.0100 |
|      12 |         372 |       00:00:36 |      100.00% |       99.20% |       0.0195 |       0.0503 |          0.0100 |
|      13 |         403 |       00:00:38 |      100.00% |       99.30% |       0.0171 |       0.0453 |          0.0100 |
|      14 |         434 |       00:00:41 |      100.00% |       99.40% |       0.0154 |       0.0417 |          0.0100 |
|      15 |         465 |       00:00:44 |      100.00% |       99.50% |       0.0142 |       0.0391 |          0.0100 |
|      16 |         496 |       00:00:47 |      100.00% |       99.50% |       0.0131 |       0.0371 |          0.0100 |
|      17 |         527 |       00:00:49 |      100.00% |       99.50% |       0.0122 |       0.0355 |          0.0100 |
|      18 |         558 |       00:00:52 |      100.00% |       99.50% |       0.0114 |       0.0343 |          0.0100 |
|======================================================================================================================|

Задайте выходную функцию

Задайте выходную функцию stopIfAccuracyNotImproving(info,N), , который останавливает сетевое обучение, если лучшая точность классификации на данных о валидации не улучшается для валидаций сети N подряд. Этот критерий подобен встроенному критерию остановки с помощью потери валидации, за исключением того, что это применяется к точности классификации вместо потери.

function stop = stopIfAccuracyNotImproving(info,N)

stop = false;

% Keep track of the best validation accuracy and the number of validations for which
% there has not been an improvement of the accuracy.
persistent bestValAccuracy
persistent valLag

% Clear the variables when training starts.
if info.State == "start"
    bestValAccuracy = 0;
    valLag = 0;
    
elseif ~isempty(info.ValidationLoss)
    
    % Compare the current validation accuracy to the best accuracy so far,
    % and either set the best accuracy to the current accuracy, or increase
    % the number of validations for which there has not been an improvement.
    if info.ValidationAccuracy > bestValAccuracy
        valLag = 0;
        bestValAccuracy = info.ValidationAccuracy;
    else
        valLag = valLag + 1;
    end
    
    % If the validation lag is at least N, that is, the validation accuracy
    % has not improved for at least N validations, then return true and
    % stop training.
    if valLag >= N
        stop = true;
    end
    
end

end

Смотрите также

|

Похожие темы