В этом примере показано, как задать выходную функцию, которая запускается в каждой итерации во время обучения глубоких нейронных сетей. Если вы задаете выходные функции при помощи 'OutputFcn'
аргумент пары "имя-значение" trainingOptions
, затем trainNetwork
вызывает эти функции однажды запуск обучения, после каждой учебной итерации, и однажды после того, как обучение закончилось. Каждый раз выходные функции называются, trainNetwork
передает структуру, содержащую информацию, такую как текущий номер итерации, потеря и точность. Можно использовать выходные функции, чтобы отобразить или построить информацию о прогрессе или остановить обучение. Чтобы остановить обучение рано, заставьте свою выходную функцию возвратить true
. Если какая-либо выходная функция возвращает true
, затем обучение закончило and trainNetwork
возвращает последнюю сеть.
Чтобы остановиться обучение, когда потеря на наборе валидации прекратит уменьшаться, просто задайте данные о валидации и терпение валидации с помощью 'ValidationData'
и 'ValidationPatience'
аргументы пары "имя-значение" trainingOptions
, соответственно. Терпение валидации является числом раз, что потеря на наборе валидации может быть больше, чем или равняться ранее самой маленькой потере, прежде чем сетевое обучение остановится. Можно добавить дополнительный критерий остановки с помощью выходных функций. В этом примере показано, как создать выходную функцию, которая останавливает обучение, когда точность классификации на данных о валидации прекращает улучшаться. Выходная функция задана в конце скрипта.
Загрузите обучающие данные, который содержит 5 000 изображений цифр. Отложите 1000 из изображений для сетевой валидации.
[XTrain,YTrain] = digitTrain4DArrayData; idx = randperm(size(XTrain,4),1000); XValidation = XTrain(:,:,:,idx); XTrain(:,:,:,idx) = []; YValidation = YTrain(idx); YTrain(idx) = [];
Создайте сеть, чтобы классифицировать данные изображения цифры.
layers = [ imageInputLayer([28 28 1]) convolution2dLayer(3,8,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(2,'Stride',2) convolution2dLayer(3,16,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(2,'Stride',2) convolution2dLayer(3,32,'Padding','same') batchNormalizationLayer reluLayer fullyConnectedLayer(10) softmaxLayer classificationLayer];
Задайте опции для сетевого обучения. Чтобы подтвердить сеть равномерно во время обучения, задайте данные о валидации. Выберите 'ValidationFrequency'
значение так, чтобы сеть была подтверждена однажды в эпоху.
Чтобы остановить обучение, когда точность классификации на наборе валидации прекратит улучшаться, задайте stopIfAccuracyNotImproving
как выходная функция. Второй входной параметр stopIfAccuracyNotImproving
число раз, что точность на наборе валидации может быть меньшей, чем или равняться ранее самой высокой точности, прежде чем сетевое обучение остановится. Выберите любое большое значение для максимального количества эпох, чтобы обучаться. Обучение не должно достигать итоговой эпохи, потому что обучение останавливается автоматически.
miniBatchSize = 128; validationFrequency = floor(numel(YTrain)/miniBatchSize); options = trainingOptions('sgdm', ... 'InitialLearnRate',0.01, ... 'MaxEpochs',100, ... 'MiniBatchSize',miniBatchSize, ... 'VerboseFrequency',validationFrequency, ... 'ValidationData',{XValidation,YValidation}, ... 'ValidationFrequency',validationFrequency, ... 'Plots','training-progress', ... 'OutputFcn',@(info)stopIfAccuracyNotImproving(info,3));
Обучите сеть. Обучение останавливается, когда точность валидации прекращает увеличиваться.
net = trainNetwork(XTrain,YTrain,layers,options);
Training on single GPU.
Initializing image normalization. |======================================================================================================================| | Epoch | Iteration | Time Elapsed | Mini-batch | Validation | Mini-batch | Validation | Base Learning | | | | (hh:mm:ss) | Accuracy | Accuracy | Loss | Loss | Rate | |======================================================================================================================|
| 1 | 1 | 00:00:00 | 5.47% | 14.10% | 2.3462 | 2.3031 | 0.0100 |
| 1 | 31 | 00:00:01 | 87.50% | 84.60% | 0.4605 | 0.4979 | 0.0100 |
| 2 | 62 | 00:00:03 | 95.31% | 95.40% | 0.1853 | 0.1946 | 0.0100 |
| 3 | 93 | 00:00:04 | 98.44% | 98.20% | 0.0926 | 0.1090 | 0.0100 |
| 4 | 124 | 00:00:05 | 99.22% | 98.80% | 0.0636 | 0.0808 | 0.0100 |
| 5 | 155 | 00:00:06 | 99.22% | 99.40% | 0.0412 | 0.0512 | 0.0100 |
| 6 | 186 | 00:00:07 | 99.22% | 99.80% | 0.0355 | 0.0374 | 0.0100 |
| 7 | 217 | 00:00:08 | 99.22% | 99.70% | 0.0269 | 0.0331 | 0.0100 |
| 8 | 248 | 00:00:09 | 100.00% | 99.80% | 0.0156 | 0.0269 | 0.0100 |
| 9 | 279 | 00:00:10 | 100.00% | 99.80% | 0.0149 | 0.0240 | 0.0100 |
|======================================================================================================================|
Задайте выходную функцию stopIfAccuracyNotImproving(info,N)
,
который останавливает сетевое обучение, если лучшая точность классификации на данных о валидации не улучшается для N
сетевые валидации подряд. Этот критерий похож на встроенный критерий остановки с помощью потери валидации, за исключением того, что это применяется к точности классификации вместо потери.
function stop = stopIfAccuracyNotImproving(info,N) stop = false; % Keep track of the best validation accuracy and the number of validations for which % there has not been an improvement of the accuracy. persistent bestValAccuracy persistent valLag % Clear the variables when training starts. if info.State == "start" bestValAccuracy = 0; valLag = 0; elseif ~isempty(info.ValidationLoss) % Compare the current validation accuracy to the best accuracy so far, % and either set the best accuracy to the current accuracy, or increase % the number of validations for which there has not been an improvement. if info.ValidationAccuracy > bestValAccuracy valLag = 0; bestValAccuracy = info.ValidationAccuracy; else valLag = valLag + 1; end % If the validation lag is at least N, that is, the validation accuracy % has not improved for at least N validations, then return true and % stop training. if valLag >= N stop = true; end end end
trainNetwork
| trainingOptions