Настройка Выхода во время Нейронной сети для глубокого обучения обучения

В этом примере показов, как задать выходную функцию, которая запускается при каждой итерации во время обучения глубоких нейронных сетей. Если вы задаете выходные функции при помощи 'OutputFcn' Аргумент пары "имя-значение" из trainingOptions, затем trainNetwork вызывает эти функции один раз до начала обучения, после каждой итерации обучения и один раз после завершения обучения. Каждый раз, когда вызываются выходные функции, trainNetwork передает структуру, содержащую информацию, такую как текущее число итераций, потери и точность. Можно использовать выходные функции для отображения или построения графика информации о прогрессе или для остановки обучения. Чтобы остановить обучение раньше, заставьте вашу выходную функцию вернуться true. Если какая-либо выходная функция возвращается true, затем обучение концов и trainNetworkвозвращает последнюю сеть.

Чтобы остановить обучение, когда потеря на наборе валидации перестает уменьшаться, просто укажите данные валидации и терпение валидации, используя 'ValidationData' и 'ValidationPatience' Аргументы пары "имя-значение" из trainingOptions, соответственно. Терпение валидации - это количество раз, когда потеря на наборе валидации может быть больше или равной ранее наименьшей потере до остановки сетевого обучения. Вы можете добавить дополнительные критерии остановки с помощью выходных функций. В этом примере показано, как создать выходную функцию, которая останавливает обучение, когда точность классификации данных валидации перестает улучшаться. Выходная функция определяется в конце скрипта.

Загрузите обучающие данные, которые содержат 5000 изображений цифр. Выделите 1000 изображений для валидации сети.

[XTrain,YTrain] = digitTrain4DArrayData;

idx = randperm(size(XTrain,4),1000);
XValidation = XTrain(:,:,:,idx);
XTrain(:,:,:,idx) = [];
YValidation = YTrain(idx);
YTrain(idx) = [];

Создайте сеть, чтобы классифицировать цифровые данные изображения.

layers = [
    imageInputLayer([28 28 1])
    
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer   
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

Укажите опции сетевого обучения. Чтобы проверить сеть на регулярных интервалах во время обучения, укажите данные валидации. Выберите 'ValidationFrequency' значение таким образом, чтобы сеть проверялась один раз в эпоху.

Чтобы остановить обучение, когда точность классификации на наборе валидации перестает улучшаться, задайте stopIfAccuracyNotImproving как выходная функция. Второй входной параметр stopIfAccuracyNotImproving количество раз, когда точность на наборе валидации может быть меньше или равной ранее самой высокой точности, прежде чем обучение сети остановится. Выберите любое большое значение для максимального количества эпох для обучения. Обучение не должно достигать последней эпохи, потому что обучение останавливается автоматически.

miniBatchSize = 128;
validationFrequency = floor(numel(YTrain)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.01, ...
    'MaxEpochs',100, ...
    'MiniBatchSize',miniBatchSize, ...
    'VerboseFrequency',validationFrequency, ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',validationFrequency, ...
    'Plots','training-progress', ...
    'OutputFcn',@(info)stopIfAccuracyNotImproving(info,3));

Обучите сеть. Обучение останавливается, когда точность валидации перестает увеличиваться.

net = trainNetwork(XTrain,YTrain,layers,options);
Training on single CPU.
Initializing input data normalization.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:03 |        7.81% |       12.70% |       2.7155 |       2.5169 |          0.0100 |
|       1 |          31 |       00:00:06 |       71.09% |       74.70% |       0.8805 |       0.8120 |          0.0100 |
|       2 |          62 |       00:00:08 |       87.50% |       87.90% |       0.3866 |       0.4448 |          0.0100 |
|       3 |          93 |       00:00:11 |       94.53% |       94.30% |       0.2178 |       0.2529 |          0.0100 |
|       4 |         124 |       00:00:13 |       96.09% |       96.60% |       0.1433 |       0.1759 |          0.0100 |
|       5 |         155 |       00:00:15 |      100.00% |       97.40% |       0.0994 |       0.1306 |          0.0100 |
|       6 |         186 |       00:00:18 |       99.22% |       97.90% |       0.0786 |       0.1126 |          0.0100 |
|       7 |         217 |       00:00:20 |       99.22% |       98.20% |       0.0552 |       0.0938 |          0.0100 |
|       8 |         248 |       00:00:23 |      100.00% |       97.60% |       0.0429 |       0.0871 |          0.0100 |
|       9 |         279 |       00:00:26 |      100.00% |       98.00% |       0.0338 |       0.0777 |          0.0100 |
|      10 |         310 |       00:00:28 |      100.00% |       98.50% |       0.0271 |       0.0681 |          0.0100 |
|      11 |         341 |       00:00:31 |      100.00% |       98.20% |       0.0237 |       0.0623 |          0.0100 |
|      12 |         372 |       00:00:33 |      100.00% |       98.60% |       0.0212 |       0.0570 |          0.0100 |
|      13 |         403 |       00:00:36 |      100.00% |       98.70% |       0.0186 |       0.0533 |          0.0100 |
|      14 |         434 |       00:00:38 |      100.00% |       98.70% |       0.0163 |       0.0507 |          0.0100 |
|      15 |         465 |       00:00:41 |      100.00% |       98.80% |       0.0143 |       0.0483 |          0.0100 |
|      16 |         496 |       00:00:43 |      100.00% |       99.00% |       0.0127 |       0.0457 |          0.0100 |
|      17 |         527 |       00:00:46 |      100.00% |       98.90% |       0.0113 |       0.0435 |          0.0100 |
|      18 |         558 |       00:00:48 |      100.00% |       99.00% |       0.0102 |       0.0416 |          0.0100 |
|      19 |         589 |       00:00:51 |      100.00% |       99.10% |       0.0093 |       0.0400 |          0.0100 |
|      20 |         620 |       00:00:53 |      100.00% |       99.10% |       0.0086 |       0.0387 |          0.0100 |
|      21 |         651 |       00:00:56 |      100.00% |       99.20% |       0.0081 |       0.0375 |          0.0100 |
|      22 |         682 |       00:00:58 |      100.00% |       99.10% |       0.0076 |       0.0364 |          0.0100 |
|      23 |         713 |       00:01:01 |      100.00% |       99.10% |       0.0073 |       0.0355 |          0.0100 |
|      24 |         744 |       00:01:03 |      100.00% |       99.10% |       0.0069 |       0.0346 |          0.0100 |
|======================================================================================================================|

Задайте выходную функцию

Определите выходную функцию stopIfAccuracyNotImproving(info,N), который останавливает сетевое обучение, если лучшая точность классификации на данных валидации не улучшается для N валидации сети в строке. Этот критерий аналогичен встроенному критерию остановки с помощью потерь валидации, за исключением того, что он применяется к точности классификации вместо потерь.

function stop = stopIfAccuracyNotImproving(info,N)

stop = false;

% Keep track of the best validation accuracy and the number of validations for which
% there has not been an improvement of the accuracy.
persistent bestValAccuracy
persistent valLag

% Clear the variables when training starts.
if info.State == "start"
    bestValAccuracy = 0;
    valLag = 0;
    
elseif ~isempty(info.ValidationLoss)
    
    % Compare the current validation accuracy to the best accuracy so far,
    % and either set the best accuracy to the current accuracy, or increase
    % the number of validations for which there has not been an improvement.
    if info.ValidationAccuracy > bestValAccuracy
        valLag = 0;
        bestValAccuracy = info.ValidationAccuracy;
    else
        valLag = valLag + 1;
    end
    
    % If the validation lag is at least N, that is, the validation accuracy
    % has not improved for at least N validations, then return true and
    % stop training.
    if valLag >= N
        stop = true;
    end
    
end

end

См. также

|

Похожие темы