В этом примере показано, как определить функцию вывода, которая выполняется на каждой итерации во время обучения нейронных сетей глубокого обучения. При указании функций вывода с помощью 'OutputFcn' аргумент пары имя-значение trainingOptions, то trainNetwork вызывает эти функции один раз до начала обучения, после каждой итерации обучения и один раз после окончания обучения. Каждый раз при вызове функций вывода: trainNetwork передает структуру, содержащую такую информацию, как текущее число итераций, потери и точность. Функции вывода можно использовать для просмотра или печати информации о ходе выполнения или для прекращения обучения. Чтобы остановить обучение раньше, вернитесь к функции вывода true. Если возвращается какая-либо выходная функция true, затем обучение заканчивает и trainNetwork возвращает последнюю сеть.
Чтобы прекратить обучение, когда потеря в наборе проверки перестает уменьшаться, просто укажите данные проверки и терпение проверки с помощью 'ValidationData' и 'ValidationPatience' аргументы пары имя-значение trainingOptionsсоответственно. Терпение проверки - это количество раз, когда потери в наборе проверки могут быть больше или равны ранее наименьшим потерям перед остановкой обучения сети. С помощью функций вывода можно добавить дополнительные критерии остановки. В этом примере показано, как создать функцию вывода, которая прекращает обучение, когда точность классификации в данных проверки перестает улучшаться. Функция вывода определяется в конце сценария.
Загрузите обучающие данные, содержащие 5000 изображений цифр. Отложите 1000 изображений для проверки сети.
[XTrain,YTrain] = digitTrain4DArrayData; idx = randperm(size(XTrain,4),1000); XValidation = XTrain(:,:,:,idx); XTrain(:,:,:,idx) = []; YValidation = YTrain(idx); YTrain(idx) = [];
Создайте сеть для классификации цифровых данных изображения.
layers = [
imageInputLayer([28 28 1])
convolution2dLayer(3,8,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,16,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,32,'Padding','same')
batchNormalizationLayer
reluLayer
fullyConnectedLayer(10)
softmaxLayer
classificationLayer];
Укажите параметры сетевого обучения. Для проверки сети с регулярными интервалами во время обучения укажите данные проверки. Выберите 'ValidationFrequency' значение, чтобы сеть проверялась один раз в эпоху.
Чтобы прекратить обучение, когда точность классификации в наборе проверки перестает улучшаться, укажите stopIfAccuracyNotImproving в качестве функции вывода. Второй входной аргумент stopIfAccuracyNotImproving количество раз, когда точность в наборе проверки может быть меньше или равна ранее наивысшей точности перед остановкой сетевого обучения. Выберите любое большое значение максимального количества эпох для обучения. Обучение не должно дойти до финальной эпохи, потому что обучение прекращается автоматически.
miniBatchSize = 128; validationFrequency = floor(numel(YTrain)/miniBatchSize); options = trainingOptions('sgdm', ... 'InitialLearnRate',0.01, ... 'MaxEpochs',100, ... 'MiniBatchSize',miniBatchSize, ... 'VerboseFrequency',validationFrequency, ... 'ValidationData',{XValidation,YValidation}, ... 'ValidationFrequency',validationFrequency, ... 'Plots','training-progress', ... 'OutputFcn',@(info)stopIfAccuracyNotImproving(info,3));
Обучение сети. Обучение прекращается, когда точность проверки перестает увеличиваться.
net = trainNetwork(XTrain,YTrain,layers,options);
Training on single CPU. Initializing input data normalization.
|======================================================================================================================| | Epoch | Iteration | Time Elapsed | Mini-batch | Validation | Mini-batch | Validation | Base Learning | | | | (hh:mm:ss) | Accuracy | Accuracy | Loss | Loss | Rate | |======================================================================================================================|
| 1 | 1 | 00:00:03 | 7.81% | 12.70% | 2.7155 | 2.5169 | 0.0100 |
| 1 | 31 | 00:00:06 | 71.09% | 74.70% | 0.8805 | 0.8120 | 0.0100 |
| 2 | 62 | 00:00:08 | 87.50% | 87.90% | 0.3866 | 0.4448 | 0.0100 |
| 3 | 93 | 00:00:11 | 94.53% | 94.30% | 0.2178 | 0.2529 | 0.0100 |
| 4 | 124 | 00:00:13 | 96.09% | 96.60% | 0.1433 | 0.1759 | 0.0100 |
| 5 | 155 | 00:00:15 | 100.00% | 97.40% | 0.0994 | 0.1306 | 0.0100 |
| 6 | 186 | 00:00:18 | 99.22% | 97.90% | 0.0786 | 0.1126 | 0.0100 |
| 7 | 217 | 00:00:20 | 99.22% | 98.20% | 0.0552 | 0.0938 | 0.0100 |
| 8 | 248 | 00:00:23 | 100.00% | 97.60% | 0.0429 | 0.0871 | 0.0100 |
| 9 | 279 | 00:00:26 | 100.00% | 98.00% | 0.0338 | 0.0777 | 0.0100 |
| 10 | 310 | 00:00:28 | 100.00% | 98.50% | 0.0271 | 0.0681 | 0.0100 |
| 11 | 341 | 00:00:31 | 100.00% | 98.20% | 0.0237 | 0.0623 | 0.0100 |
| 12 | 372 | 00:00:33 | 100.00% | 98.60% | 0.0212 | 0.0570 | 0.0100 |
| 13 | 403 | 00:00:36 | 100.00% | 98.70% | 0.0186 | 0.0533 | 0.0100 |
| 14 | 434 | 00:00:38 | 100.00% | 98.70% | 0.0163 | 0.0507 | 0.0100 |
| 15 | 465 | 00:00:41 | 100.00% | 98.80% | 0.0143 | 0.0483 | 0.0100 |
| 16 | 496 | 00:00:43 | 100.00% | 99.00% | 0.0127 | 0.0457 | 0.0100 |
| 17 | 527 | 00:00:46 | 100.00% | 98.90% | 0.0113 | 0.0435 | 0.0100 |
| 18 | 558 | 00:00:48 | 100.00% | 99.00% | 0.0102 | 0.0416 | 0.0100 |
| 19 | 589 | 00:00:51 | 100.00% | 99.10% | 0.0093 | 0.0400 | 0.0100 |
| 20 | 620 | 00:00:53 | 100.00% | 99.10% | 0.0086 | 0.0387 | 0.0100 |
| 21 | 651 | 00:00:56 | 100.00% | 99.20% | 0.0081 | 0.0375 | 0.0100 |
| 22 | 682 | 00:00:58 | 100.00% | 99.10% | 0.0076 | 0.0364 | 0.0100 |
| 23 | 713 | 00:01:01 | 100.00% | 99.10% | 0.0073 | 0.0355 | 0.0100 |
| 24 | 744 | 00:01:03 | 100.00% | 99.10% | 0.0069 | 0.0346 | 0.0100 |
|======================================================================================================================|

Определение функции вывода stopIfAccuracyNotImproving(info,N), который прекращает обучение сети, если наилучшая точность классификации в данных проверки не улучшается для N проверки сети в строке. Этот критерий аналогичен встроенному критерию остановки, использующему потери проверки, за исключением того, что он применяется к точности классификации вместо потерь.
function stop = stopIfAccuracyNotImproving(info,N) stop = false; % Keep track of the best validation accuracy and the number of validations for which % there has not been an improvement of the accuracy. persistent bestValAccuracy persistent valLag % Clear the variables when training starts. if info.State == "start" bestValAccuracy = 0; valLag = 0; elseif ~isempty(info.ValidationLoss) % Compare the current validation accuracy to the best accuracy so far, % and either set the best accuracy to the current accuracy, or increase % the number of validations for which there has not been an improvement. if info.ValidationAccuracy > bestValAccuracy valLag = 0; bestValAccuracy = info.ValidationAccuracy; else valLag = valLag + 1; end % If the validation lag is at least N, that is, the validation accuracy % has not improved for at least N validations, then return true and % stop training. if valLag >= N stop = true; end end end
trainingOptions | trainNetwork