Настройте Выход во время обучения нейронной сети для глубокого обучения

В этом примере показано, как задать выходную функцию, которая запускается в каждой итерации во время обучения глубоких нейронных сетей. Если вы задаете выходные функции при помощи 'OutputFcn' аргумент пары "имя-значение" trainingOptions, затем trainNetwork вызывает эти функции однажды запуск обучения, после каждой учебной итерации, и однажды после того, как обучение закончилось. Каждый раз выходные функции называются, trainNetwork передает структуру, содержащую информацию, такую как текущий номер итерации, потеря и точность. Можно использовать выходные функции, чтобы отобразить или построить информацию о прогрессе или остановить обучение. Чтобы остановить обучение рано, заставьте свою выходную функцию возвратить true. Если какая-либо выходная функция возвращает true, затем обучение закончило and trainNetwork возвращает последнюю сеть.

Чтобы остановиться обучение, когда потеря на наборе валидации прекратит уменьшаться, просто задайте данные о валидации и терпение валидации с помощью 'ValidationData' и 'ValidationPatience' аргументы пары "имя-значение" trainingOptions, соответственно. Терпение валидации является числом раз, что потеря на наборе валидации может быть больше, чем или равняться ранее самой маленькой потере, прежде чем сетевое обучение остановится. Можно добавить дополнительный критерий остановки с помощью выходных функций. В этом примере показано, как создать выходную функцию, которая останавливает обучение, когда точность классификации на данных о валидации прекращает улучшаться. Выходная функция задана в конце скрипта.

Загрузите обучающие данные, который содержит 5 000 изображений цифр. Отложите 1000 из изображений для сетевой валидации.

[XTrain,YTrain] = digitTrain4DArrayData;

idx = randperm(size(XTrain,4),1000);
XValidation = XTrain(:,:,:,idx);
XTrain(:,:,:,idx) = [];
YValidation = YTrain(idx);
YTrain(idx) = [];

Создайте сеть, чтобы классифицировать данные изображения цифры.

layers = [
    imageInputLayer([28 28 1])
    
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

Задайте опции для сетевого обучения. Чтобы проверить сеть равномерно во время обучения, задайте данные о валидации. Выберите 'ValidationFrequency' значение так, чтобы сеть была проверена однажды в эпоху.

Чтобы остановить обучение, когда точность классификации на наборе валидации прекратит улучшаться, задайте stopIfAccuracyNotImproving как выходная функция. Второй входной параметр stopIfAccuracyNotImproving число раз, что точность на наборе валидации может быть меньшей, чем или равняться ранее самой высокой точности, прежде чем сетевое обучение остановится. Выберите любое большое значение для максимального количества эпох, чтобы обучаться. Обучение не должно достигать итоговой эпохи, потому что обучение останавливается автоматически.

miniBatchSize = 128;
validationFrequency = floor(numel(YTrain)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.01, ...
    'MaxEpochs',100, ...
    'MiniBatchSize',miniBatchSize, ...
    'VerboseFrequency',validationFrequency, ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',validationFrequency, ...
    'Plots','training-progress', ...
    'OutputFcn',@(info)stopIfAccuracyNotImproving(info,3));

Обучите сеть. Обучение останавливается, когда точность валидации прекращает увеличиваться.

net = trainNetwork(XTrain,YTrain,layers,options);
Training on single CPU.
Initializing input data normalization.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:03 |        7.81% |       12.70% |       2.7155 |       2.5169 |          0.0100 |
|       1 |          31 |       00:00:06 |       71.09% |       74.70% |       0.8805 |       0.8120 |          0.0100 |
|       2 |          62 |       00:00:08 |       87.50% |       87.90% |       0.3866 |       0.4448 |          0.0100 |
|       3 |          93 |       00:00:11 |       94.53% |       94.30% |       0.2178 |       0.2529 |          0.0100 |
|       4 |         124 |       00:00:13 |       96.09% |       96.60% |       0.1433 |       0.1759 |          0.0100 |
|       5 |         155 |       00:00:15 |      100.00% |       97.40% |       0.0994 |       0.1306 |          0.0100 |
|       6 |         186 |       00:00:18 |       99.22% |       97.90% |       0.0786 |       0.1126 |          0.0100 |
|       7 |         217 |       00:00:20 |       99.22% |       98.20% |       0.0552 |       0.0938 |          0.0100 |
|       8 |         248 |       00:00:23 |      100.00% |       97.60% |       0.0429 |       0.0871 |          0.0100 |
|       9 |         279 |       00:00:26 |      100.00% |       98.00% |       0.0338 |       0.0777 |          0.0100 |
|      10 |         310 |       00:00:28 |      100.00% |       98.50% |       0.0271 |       0.0681 |          0.0100 |
|      11 |         341 |       00:00:31 |      100.00% |       98.20% |       0.0237 |       0.0623 |          0.0100 |
|      12 |         372 |       00:00:33 |      100.00% |       98.60% |       0.0212 |       0.0570 |          0.0100 |
|      13 |         403 |       00:00:36 |      100.00% |       98.70% |       0.0186 |       0.0533 |          0.0100 |
|      14 |         434 |       00:00:38 |      100.00% |       98.70% |       0.0163 |       0.0507 |          0.0100 |
|      15 |         465 |       00:00:41 |      100.00% |       98.80% |       0.0143 |       0.0483 |          0.0100 |
|      16 |         496 |       00:00:43 |      100.00% |       99.00% |       0.0127 |       0.0457 |          0.0100 |
|      17 |         527 |       00:00:46 |      100.00% |       98.90% |       0.0113 |       0.0435 |          0.0100 |
|      18 |         558 |       00:00:48 |      100.00% |       99.00% |       0.0102 |       0.0416 |          0.0100 |
|      19 |         589 |       00:00:51 |      100.00% |       99.10% |       0.0093 |       0.0400 |          0.0100 |
|      20 |         620 |       00:00:53 |      100.00% |       99.10% |       0.0086 |       0.0387 |          0.0100 |
|      21 |         651 |       00:00:56 |      100.00% |       99.20% |       0.0081 |       0.0375 |          0.0100 |
|      22 |         682 |       00:00:58 |      100.00% |       99.10% |       0.0076 |       0.0364 |          0.0100 |
|      23 |         713 |       00:01:01 |      100.00% |       99.10% |       0.0073 |       0.0355 |          0.0100 |
|      24 |         744 |       00:01:03 |      100.00% |       99.10% |       0.0069 |       0.0346 |          0.0100 |
|======================================================================================================================|

Задайте выходную функцию

Задайте выходную функцию stopIfAccuracyNotImproving(info,N), который останавливает сетевое обучение, если лучшая точность классификации на данных о валидации не улучшается для N сетевые валидации подряд. Этот критерий похож на встроенный критерий остановки с помощью потери валидации, за исключением того, что это применяется к точности классификации вместо потери.

function stop = stopIfAccuracyNotImproving(info,N)

stop = false;

% Keep track of the best validation accuracy and the number of validations for which
% there has not been an improvement of the accuracy.
persistent bestValAccuracy
persistent valLag

% Clear the variables when training starts.
if info.State == "start"
    bestValAccuracy = 0;
    valLag = 0;
    
elseif ~isempty(info.ValidationLoss)
    
    % Compare the current validation accuracy to the best accuracy so far,
    % and either set the best accuracy to the current accuracy, or increase
    % the number of validations for which there has not been an improvement.
    if info.ValidationAccuracy > bestValAccuracy
        valLag = 0;
        bestValAccuracy = info.ValidationAccuracy;
    else
        valLag = valLag + 1;
    end
    
    % If the validation lag is at least N, that is, the validation accuracy
    % has not improved for at least N validations, then return true and
    % stop training.
    if valLag >= N
        stop = true;
    end
    
end

end

Смотрите также

|

Похожие темы