Обучите сеть Используя циклический, изучают уровень для снимка состояния Ensembling

В этом примере показано, как обучить сеть, чтобы классифицировать изображения объектов с помощью циклического, изучают расписание уровня и создают снимки ensembling для лучшей тестовой точности. В примере вы изучаете, как использовать косинусную функцию в изучить расписании уровня, взять снимки состояния сети во время обучения создать ансамбль модели и добавить регуляризацию L2-нормы (затухание веса) к учебной потере.

Этот пример обучается, остаточная сеть [1] на наборе данных CIFAR-10 [2] с пользовательским циклическим изучают уровень: для каждой итерации решатель использует изучить уровень, данный переключенной косинусной функцией [3] alpha(t) = (alpha0/2)*cos(pi*mod(t-1,T/M)/(T/M)+1), где t номер итерации, T общее количество учебных итераций, alpha0 начальная буква, изучают уровень и M количество циклов/снимков состояния. Это узнает, что расписание уровня эффективно разделяет учебный процесс в M циклы. Каждый цикл начинается с большого темпа обучения, который затухает монотонно, обеспечивая сеть, чтобы исследовать различные локальные минимумы. В конце каждого учебного цикла вы берете снимок состояния сети (то есть, вы сохраняете модель в этой итерации), и более позднее среднее значение прогнозы всех моделей снимка состояния, также известных как ensembling [4] снимка состояния, чтобы улучшить точность завершающего испытания.

Подготовка данных

Загрузите набор данных CIFAR-10 [2]. Набор данных содержит 60 000 изображений. Каждое изображение находится 32 32 в размере и имеет три цветовых канала (RGB). Размер набора данных составляет 175 Мбайт. В зависимости от вашего интернет-соединения может занять время процесс загрузки.

datadir = tempdir; 
downloadCIFARData(datadir);

Загрузите обучение CIFAR-10 и протестируйте изображения как 4-D массивы. Набор обучающих данных содержит 50 000 изображений, и набор тестов содержит 10 000 изображений.

[XTrain,YTrain,XTest,YTest] = loadCIFARData(datadir);
classes = categories(YTrain);
numClasses = numel(classes);

Можно отобразить случайную выборку учебных изображений с помощью следующего кода.

figure;
idx = randperm(size(XTrain,4),20);
im = imtile(XTrain(:,:,:,idx),'ThumbnailSize',[96,96]);
imshow(im)

Создайте augmentedImageDatastore возразите, чтобы использовать в сетевом обучении. Во время обучения datastore случайным образом инвертирует учебные изображения вдоль вертикальной оси и случайным образом переводит их до четырех пикселей горизонтально и вертикально. Увеличение данных помогает препятствовать тому, чтобы сеть сверхсоответствовала и запомнила точные детали учебных изображений.

imageSize = [32 32 3];
pixelRange = [-4 4];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(imageSize,XTrain,YTrain, ...
    'DataAugmentation',imageAugmenter);
auimdsTest = augmentedImageDatastore(imageSize, XTest, YTest);

Архитектура сети Define

Создайте остаточную сеть [1] с шестью стандартными сверточными модулями (два модуля на этап) и ширина 16. Общая сетевая глубина 2*6+2 = 14. Кроме того, задайте среднее изображение с помощью 'Mean' опция в изображении ввела слой.

netWidth = 16;
layers = [
    imageInputLayer(imageSize,'Name','input','Mean', mean(XTrain,4))
    convolution2dLayer(3,netWidth,'Padding','same','Name','convInp')
    batchNormalizationLayer('Name','BNInp')
    reluLayer('Name','reluInp')
    
    convolutionalUnit(netWidth,1,'S1U1')
    additionLayer(2,'Name','add11')
    reluLayer('Name','relu11')
    convolutionalUnit(netWidth,1,'S1U2')
    additionLayer(2,'Name','add12')
    reluLayer('Name','relu12')
    
    convolutionalUnit(2*netWidth,2,'S2U1')
    additionLayer(2,'Name','add21')
    reluLayer('Name','relu21')
    convolutionalUnit(2*netWidth,1,'S2U2')
    additionLayer(2,'Name','add22')
    reluLayer('Name','relu22')
    
    convolutionalUnit(4*netWidth,2,'S3U1')
    additionLayer(2,'Name','add31')
    reluLayer('Name','relu31')
    convolutionalUnit(4*netWidth,1,'S3U2')
    additionLayer(2,'Name','add32')
    reluLayer('Name','relu32')
    
    averagePooling2dLayer(8,'Name','globalPool')
    fullyConnectedLayer(10,'Name','fcFinal')
    ];

lgraph = layerGraph(layers);
lgraph = connectLayers(lgraph,'reluInp','add11/in2');
lgraph = connectLayers(lgraph,'relu11','add12/in2');
skip1 = [
    convolution2dLayer(1,2*netWidth,'Stride',2,'Name','skipConv1')
    batchNormalizationLayer('Name','skipBN1')];
lgraph = addLayers(lgraph,skip1);
lgraph = connectLayers(lgraph,'relu12','skipConv1');
lgraph = connectLayers(lgraph,'skipBN1','add21/in2');
lgraph = connectLayers(lgraph,'relu21','add22/in2');
skip2 = [
    convolution2dLayer(1,4*netWidth,'Stride',2,'Name','skipConv2')
    batchNormalizationLayer('Name','skipBN2')];
lgraph = addLayers(lgraph,skip2);
lgraph = connectLayers(lgraph,'relu22','skipConv2');
lgraph = connectLayers(lgraph,'skipBN2','add31/in2');
lgraph = connectLayers(lgraph,'relu31','add32/in2');

Постройте архитектуру ResNet.

figure;
plot(lgraph)

Создайте dlnetwork объект из графика слоя.

dlnet = dlnetwork(lgraph);

Функция градиентов модели Define

Создайте функцию помощника modelGradients, перечисленный в конце примера. Функция берет в dlnetwork объект dlnet and мини-пакет входных данных dlX с соответствием маркирует Y, и возвращает градиенты потери относительно learnable параметров в dlnet. Эта функция также возвращает потерю и состояние nonlearnable параметров сети в данной итерации.

Задайте опции обучения

Задайте опции обучения.

velocity = [];
numEpochs = 200;
miniBatchSize = 64;
augimdsTrain.MiniBatchSize = miniBatchSize;
numObservations = numel(YTrain);
numIterationsPerEpoch = floor(numObservations./miniBatchSize);
momentum = 0.9;
weightDecay = 1e-4;

Укажите, что опции обучения, характерные для циклического, изучают уровень. Alpha0 начальная буква, изучают уровень и numSnapshots количество циклов или создает снимки взятый во время обучения.

alpha0 = 0.1;
numSnapshots = 5;
epochsPerSnapshot = numEpochs./numSnapshots; 
iterationsPerSnapshot = ceil(numObservations./miniBatchSize)*numEpochs./numSnapshots;
modelPrefix = "SnapshotEpoch";

Обучайтесь на графическом процессоре, если вы доступны (требует Parallel Computing Toolbox™).

executionEnvironment = "auto";

Инициализируйте учебную фигуру.

[lossLine, learnRateLine] = plotLossAndLearnRate();

Обучите модель

Обучите модель с помощью пользовательского учебного цикла.

В течение каждой эпохи переставьте datastore, цикл по мини-пакетам данных, и сохраните модель (снимок состояния), если текущая эпоха является кратной epochsPerSnapshot. В конце каждой эпохи отобразите прогресс обучения.

Для каждого мини-пакета:

  • Преобразуйте метки в фиктивные переменные.

  • Преобразуйте данные в dlarray объекты с базовым одним типом и указывают, что размерность маркирует 'SSCB' (пространственный, пространственный, канал, пакет).

  • Для обучения графического процессора преобразуйте мини-пакетные данные в gpuArray объекты.

  • Оцените градиенты модели и потерю с помощью dlfeval и modelGradients функция.

  • Обновите состояние nonlearnable параметров сети.

  • Решите, что изучить уровень для циклического изучает расписание уровня.

  • Обновите сетевые параметры с помощью sgdmupdate функция.

  • Постройте потерю и изучите уровень в каждой итерации.

В данном примере обучение взяло приблизительно 18-й на NVIDIA™ GeForce GTX 1080.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    % Reset image datastore.
    reset(augimdsTrain);
  
    % Shuffle data.
    augimdsTrain = shuffle(augimdsTrain);
    
    % Save snapshot model.
    if ~mod(epoch,epochsPerSnapshot)
        save(modelPrefix + epoch + ".mat",'dlnet');
    end
    
    % Loop over mini-batches.
    while hasdata(augimdsTrain)
        iteration = iteration + 1;
        
        % Read mini-batch of data.
        data = read(augimdsTrain);
        
        % Concatenate the inputs.
        Xdata = data{:,1};
        X = cat(4,Xdata{:});

        % Convert the labels to dummy variables.
        TrueClasses = data{:,2};
        Y = zeros(numClasses, numel(TrueClasses), 'single');
        for c = 1:numClasses
            Y(c,TrueClasses==classes(c)) = 1;
        end
        
        % Convert mini-batch of data to dlarray.
        dlX = dlarray(single(X),'SSCB');
        
        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end
        
        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients function.
        [gradients, loss, state] = dlfeval(@modelGradients,dlnet,dlX,Y,weightDecay);
        
        % Update the state of nonlearnable parameters.
        dlnet.State = state;
        
        % Determine learn rate for cyclical learn rate schedule.
        learnRate = 0.5*alpha0*(cos((pi*mod(iteration-1,iterationsPerSnapshot)./iterationsPerSnapshot))+1);
        
        % Update the network parameters using the SGDM optimizer.
        [dlnet.Learnables, velocity] = sgdmupdate(dlnet.Learnables, gradients, velocity, learnRate, momentum);
        
        % Plot loss and learn rate for current iteration.
        loss = double(gather(extractdata(loss)));
        addpoints(lossLine, iteration, loss);
        addpoints(learnRateLine, iteration, learnRate);
        drawnow
        
    end
    
    % Display the training progress.
    D = duration(0,0,toc(start),'Format','hh:mm:ss');
    disp( ...
        "Epoch: " + epoch + ", " + ...
        "Loss: " + num2str(loss) + ", " + ...
        "Elapsed: " + string(D))
end
Epoch: 1, Loss: 1.784, Elapsed: 00:05:59
Epoch: 2, Loss: 1.1388, Elapsed: 00:11:32
Epoch: 3, Loss: 1.0594, Elapsed: 00:17:02
Epoch: 4, Loss: 1.3154, Elapsed: 00:22:34
Epoch: 5, Loss: 0.7968, Elapsed: 00:28:08
Epoch: 6, Loss: 0.85319, Elapsed: 00:33:47
Epoch: 7, Loss: 0.5477, Elapsed: 00:39:26
Epoch: 8, Loss: 1.1158, Elapsed: 00:45:14
Epoch: 9, Loss: 0.39734, Elapsed: 00:50:47
Epoch: 10, Loss: 0.47909, Elapsed: 00:56:26
Epoch: 11, Loss: 0.68743, Elapsed: 01:02:08
Epoch: 12, Loss: 0.45266, Elapsed: 01:07:51
Epoch: 13, Loss: 1.0058, Elapsed: 01:13:27
Epoch: 14, Loss: 0.64816, Elapsed: 01:19:14
Epoch: 15, Loss: 1.1039, Elapsed: 01:25:05
Epoch: 16, Loss: 1.2072, Elapsed: 01:30:41
Epoch: 17, Loss: 0.43167, Elapsed: 01:36:24
Epoch: 18, Loss: 1.1376, Elapsed: 01:42:40
Epoch: 19, Loss: 0.76857, Elapsed: 01:48:40
Epoch: 20, Loss: 0.77434, Elapsed: 01:54:22
Epoch: 21, Loss: 0.98595, Elapsed: 01:59:57
Epoch: 22, Loss: 0.78628, Elapsed: 02:05:32
Epoch: 23, Loss: 0.55069, Elapsed: 02:11:07
Epoch: 24, Loss: 0.52066, Elapsed: 02:16:43
Epoch: 25, Loss: 0.44842, Elapsed: 02:22:18
Epoch: 26, Loss: 0.40094, Elapsed: 02:27:55
Epoch: 27, Loss: 0.78839, Elapsed: 02:33:30
Epoch: 28, Loss: 0.47829, Elapsed: 02:39:05
Epoch: 29, Loss: 0.21833, Elapsed: 02:44:41
Epoch: 30, Loss: 0.5759, Elapsed: 02:50:16
Epoch: 31, Loss: 1.1089, Elapsed: 02:55:50
Epoch: 32, Loss: 0.37353, Elapsed: 03:01:25
Epoch: 33, Loss: 0.30851, Elapsed: 03:07:01
Epoch: 34, Loss: 0.34735, Elapsed: 03:12:36
Epoch: 35, Loss: 0.28772, Elapsed: 03:18:11
Epoch: 36, Loss: 0.31045, Elapsed: 03:23:47
Epoch: 37, Loss: 0.28555, Elapsed: 03:29:22
Epoch: 38, Loss: 0.897, Elapsed: 03:34:58
Epoch: 39, Loss: 0.69014, Elapsed: 03:40:33
Epoch: 40, Loss: 0.26282, Elapsed: 03:46:17
Epoch: 41, Loss: 1.0086, Elapsed: 03:51:53
Epoch: 42, Loss: 0.47303, Elapsed: 03:57:27
Epoch: 43, Loss: 1.3765, Elapsed: 04:03:02
Epoch: 44, Loss: 0.54884, Elapsed: 04:08:39
Epoch: 45, Loss: 0.38778, Elapsed: 04:14:14
Epoch: 46, Loss: 0.74121, Elapsed: 04:19:49
Epoch: 47, Loss: 0.78481, Elapsed: 04:25:25
Epoch: 48, Loss: 0.44624, Elapsed: 04:31:01
Epoch: 49, Loss: 0.81747, Elapsed: 04:36:38
Epoch: 50, Loss: 0.40319, Elapsed: 04:42:14
Epoch: 51, Loss: 0.87757, Elapsed: 04:47:51
Epoch: 52, Loss: 1.0567, Elapsed: 04:53:27
Epoch: 53, Loss: 0.29019, Elapsed: 04:59:03
Epoch: 54, Loss: 0.92056, Elapsed: 05:04:40
Epoch: 55, Loss: 0.45776, Elapsed: 05:10:16
Epoch: 56, Loss: 1.0265, Elapsed: 05:15:52
Epoch: 57, Loss: 0.55256, Elapsed: 05:21:29
Epoch: 58, Loss: 1.0822, Elapsed: 05:27:06
Epoch: 59, Loss: 0.78332, Elapsed: 05:32:44
Epoch: 60, Loss: 0.48247, Elapsed: 05:38:20
Epoch: 61, Loss: 0.86749, Elapsed: 05:43:58
Epoch: 62, Loss: 0.64667, Elapsed: 05:49:34
Epoch: 63, Loss: 0.64563, Elapsed: 05:55:10
Epoch: 64, Loss: 0.58239, Elapsed: 06:00:46
Epoch: 65, Loss: 0.29219, Elapsed: 06:06:23
Epoch: 66, Loss: 0.37627, Elapsed: 06:11:59
Epoch: 67, Loss: 0.34035, Elapsed: 06:17:35
Epoch: 68, Loss: 0.34809, Elapsed: 06:23:11
Epoch: 69, Loss: 0.61085, Elapsed: 06:28:47
Epoch: 70, Loss: 0.42018, Elapsed: 06:34:24
Epoch: 71, Loss: 0.3739, Elapsed: 06:40:00
Epoch: 72, Loss: 0.23083, Elapsed: 06:45:37
Epoch: 73, Loss: 0.21324, Elapsed: 06:51:14
Epoch: 74, Loss: 0.18931, Elapsed: 06:56:55
Epoch: 75, Loss: 0.88882, Elapsed: 07:02:31
Epoch: 76, Loss: 0.36844, Elapsed: 07:08:09
Epoch: 77, Loss: 0.76548, Elapsed: 07:13:46
Epoch: 78, Loss: 0.42548, Elapsed: 07:19:24
Epoch: 79, Loss: 0.29112, Elapsed: 07:25:01
Epoch: 80, Loss: 0.17333, Elapsed: 07:30:45
Epoch: 81, Loss: 0.50322, Elapsed: 07:36:22
Epoch: 82, Loss: 0.40387, Elapsed: 07:41:58
Epoch: 83, Loss: 0.3939, Elapsed: 07:47:34
Epoch: 84, Loss: 0.79005, Elapsed: 07:53:11
Epoch: 85, Loss: 0.51953, Elapsed: 07:58:46
Epoch: 86, Loss: 0.65925, Elapsed: 08:04:24
Epoch: 87, Loss: 0.49915, Elapsed: 08:10:01
Epoch: 88, Loss: 0.58721, Elapsed: 08:15:38
Epoch: 89, Loss: 0.57397, Elapsed: 08:21:15
Epoch: 90, Loss: 0.51315, Elapsed: 08:26:53
Epoch: 91, Loss: 0.42037, Elapsed: 08:32:30
Epoch: 92, Loss: 0.41111, Elapsed: 08:38:06
Epoch: 93, Loss: 0.71338, Elapsed: 08:43:43
Epoch: 94, Loss: 0.31452, Elapsed: 08:49:21
Epoch: 95, Loss: 0.35696, Elapsed: 08:54:58
Epoch: 96, Loss: 0.56142, Elapsed: 09:00:36
Epoch: 97, Loss: 0.69246, Elapsed: 09:06:15
Epoch: 98, Loss: 0.40288, Elapsed: 09:11:53
Epoch: 99, Loss: 0.67491, Elapsed: 09:17:31
Epoch: 100, Loss: 0.70555, Elapsed: 09:23:08
Epoch: 101, Loss: 0.45978, Elapsed: 09:28:47
Epoch: 102, Loss: 0.3963, Elapsed: 09:34:27
Epoch: 103, Loss: 0.60798, Elapsed: 09:40:05
Epoch: 104, Loss: 0.41759, Elapsed: 09:45:45
Epoch: 105, Loss: 0.45068, Elapsed: 09:51:23
Epoch: 106, Loss: 1.103, Elapsed: 09:57:02
Epoch: 107, Loss: 0.29916, Elapsed: 10:02:41
Epoch: 108, Loss: 0.64019, Elapsed: 10:08:21
Epoch: 109, Loss: 0.26558, Elapsed: 10:13:59
Epoch: 110, Loss: 0.41303, Elapsed: 10:19:38
Epoch: 111, Loss: 0.74221, Elapsed: 10:25:18
Epoch: 112, Loss: 0.48748, Elapsed: 10:30:56
Epoch: 113, Loss: 0.27348, Elapsed: 10:36:35
Epoch: 114, Loss: 0.51661, Elapsed: 10:42:14
Epoch: 115, Loss: 0.27831, Elapsed: 10:47:54
Epoch: 116, Loss: 0.35103, Elapsed: 10:53:33
Epoch: 117, Loss: 0.19571, Elapsed: 10:59:11
Epoch: 118, Loss: 0.37368, Elapsed: 11:04:50
Epoch: 119, Loss: 0.18644, Elapsed: 11:10:29
Epoch: 120, Loss: 0.48589, Elapsed: 11:16:16
Epoch: 121, Loss: 0.74257, Elapsed: 11:21:57
Epoch: 122, Loss: 0.65423, Elapsed: 11:27:37
Epoch: 123, Loss: 0.35185, Elapsed: 11:33:17
Epoch: 124, Loss: 0.81636, Elapsed: 11:38:55
Epoch: 125, Loss: 0.49292, Elapsed: 11:44:34
Epoch: 126, Loss: 0.9133, Elapsed: 11:50:14
Epoch: 127, Loss: 0.80498, Elapsed: 11:55:53
Epoch: 128, Loss: 0.59473, Elapsed: 12:01:33
Epoch: 129, Loss: 0.60313, Elapsed: 12:07:12
Epoch: 130, Loss: 0.5426, Elapsed: 12:12:50
Epoch: 131, Loss: 1.3471, Elapsed: 12:18:29
Epoch: 132, Loss: 0.35591, Elapsed: 12:24:08
Epoch: 133, Loss: 0.75186, Elapsed: 12:29:49
Epoch: 134, Loss: 0.98765, Elapsed: 12:35:29
Epoch: 135, Loss: 0.65345, Elapsed: 12:41:08
Epoch: 136, Loss: 0.78963, Elapsed: 12:46:48
Epoch: 137, Loss: 0.38269, Elapsed: 12:52:27
Epoch: 138, Loss: 0.5309, Elapsed: 12:58:06
Epoch: 139, Loss: 0.4119, Elapsed: 13:03:45
Epoch: 140, Loss: 0.93898, Elapsed: 13:09:26
Epoch: 141, Loss: 0.45791, Elapsed: 13:15:04
Epoch: 142, Loss: 0.70093, Elapsed: 13:20:43
Epoch: 143, Loss: 0.84997, Elapsed: 13:26:23
Epoch: 144, Loss: 0.27732, Elapsed: 13:32:05
Epoch: 145, Loss: 0.51171, Elapsed: 13:37:44
Epoch: 146, Loss: 0.81123, Elapsed: 13:43:24
Epoch: 147, Loss: 0.5678, Elapsed: 13:49:04
Epoch: 148, Loss: 0.58568, Elapsed: 13:54:44
Epoch: 149, Loss: 0.3952, Elapsed: 14:00:23
Epoch: 150, Loss: 0.31967, Elapsed: 14:06:03
Epoch: 151, Loss: 0.44051, Elapsed: 14:11:46
Epoch: 152, Loss: 0.99278, Elapsed: 14:17:27
Epoch: 153, Loss: 0.87306, Elapsed: 14:23:07
Epoch: 154, Loss: 0.34008, Elapsed: 14:28:47
Epoch: 155, Loss: 0.4687, Elapsed: 14:34:27
Epoch: 156, Loss: 0.22836, Elapsed: 14:40:07
Epoch: 157, Loss: 0.23204, Elapsed: 14:45:48
Epoch: 158, Loss: 0.36854, Elapsed: 14:51:28
Epoch: 159, Loss: 0.35363, Elapsed: 14:57:08
Epoch: 160, Loss: 0.37937, Elapsed: 15:02:55
Epoch: 161, Loss: 0.7725, Elapsed: 15:08:36
Epoch: 162, Loss: 0.59353, Elapsed: 15:14:15
Epoch: 163, Loss: 0.57963, Elapsed: 15:19:54
Epoch: 164, Loss: 0.54625, Elapsed: 15:25:35
Epoch: 165, Loss: 0.65612, Elapsed: 15:31:15
Epoch: 166, Loss: 0.73254, Elapsed: 15:36:56
Epoch: 167, Loss: 0.4483, Elapsed: 15:42:37
Epoch: 168, Loss: 0.36817, Elapsed: 15:48:17
Epoch: 169, Loss: 0.57539, Elapsed: 15:53:57
Epoch: 170, Loss: 1.0026, Elapsed: 15:59:37
Epoch: 171, Loss: 0.95288, Elapsed: 16:05:17
Epoch: 172, Loss: 0.83053, Elapsed: 16:10:59
Epoch: 173, Loss: 0.41976, Elapsed: 16:16:39
Epoch: 174, Loss: 0.44098, Elapsed: 16:22:19
Epoch: 175, Loss: 0.58823, Elapsed: 16:28:00
Epoch: 176, Loss: 0.67325, Elapsed: 16:33:41
Epoch: 177, Loss: 0.27045, Elapsed: 16:39:21
Epoch: 178, Loss: 0.66652, Elapsed: 16:45:02
Epoch: 179, Loss: 1.0097, Elapsed: 16:50:43
Epoch: 180, Loss: 0.40372, Elapsed: 16:56:23
Epoch: 181, Loss: 0.39175, Elapsed: 17:02:04
Epoch: 182, Loss: 0.40741, Elapsed: 17:07:45
Epoch: 183, Loss: 0.35398, Elapsed: 17:13:25
Epoch: 184, Loss: 0.63228, Elapsed: 17:19:05
Epoch: 185, Loss: 0.35308, Elapsed: 17:24:45
Epoch: 186, Loss: 0.46854, Elapsed: 17:30:27
Epoch: 187, Loss: 0.51346, Elapsed: 17:36:08
Epoch: 188, Loss: 0.71886, Elapsed: 17:41:48
Epoch: 189, Loss: 0.73986, Elapsed: 17:47:29
Epoch: 190, Loss: 0.46669, Elapsed: 17:53:10
Epoch: 191, Loss: 0.40962, Elapsed: 17:58:51
Epoch: 192, Loss: 0.25007, Elapsed: 18:04:31
Epoch: 193, Loss: 0.45651, Elapsed: 18:10:12
Epoch: 194, Loss: 0.20788, Elapsed: 18:15:52
Epoch: 195, Loss: 0.32097, Elapsed: 18:21:32
Epoch: 196, Loss: 0.28159, Elapsed: 18:27:15
Epoch: 197, Loss: 0.20396, Elapsed: 18:32:56
Epoch: 198, Loss: 0.30823, Elapsed: 18:38:37
Epoch: 199, Loss: 0.28583, Elapsed: 18:44:18

Epoch: 200, Loss: 0.32877, Elapsed: 18:50:07

Создайте ансамбль снимка состояния

Объедините снимки состояния M сети, взятой во время обучения сформировать итоговый ансамбль. Прогнозы ансамбля соответствуют среднему значению выхода полносвязного слоя из всех отдельных моделей M.

YPredictions = zeros(numClasses,numel(YTest),numSnapshots);
modelAccuracy = zeros(numSnapshots+1,1);
modelName = cell(numSnapshots+1,1);
for m = 1:numSnapshots
    modelName{m} = modelPrefix + m*epochsPerSnapshot;
    load(modelName{m} + ".mat");
    YPredictions(:,:,m) = gather(extractdata(predict(dlnet, dlarray(single(XTest),'SSCB'))));
    modelAccuracy(m) = computeAccuracy(YPredictions(:,:,m), YTest, classes);
    disp(modelName{m} + " accuracy: " + modelAccuracy(m) + "%")
end
SnapshotEpoch40 accuracy: 88.04%
SnapshotEpoch80 accuracy: 86.78%
SnapshotEpoch120 accuracy: 87.53%
SnapshotEpoch160 accuracy: 87.07%
SnapshotEpoch200 accuracy: 88.39%
modelAccuracy(end) = computeAccuracy(mean(YPredictions,3), YTest, classes);
modelName{end} = "Ensemble model";
disp("Ensemble accuracy: " + modelAccuracy(end) + "%")
Ensemble accuracy: 91.13%

Постройте точность

Постройте точность на наборе тестовых данных для всех моделей снимка состояния и модели ансамбля.

figure;bar(modelAccuracy);
ylabel('Accuracy (%)');
xticklabels(modelName)
xtickangle(45)
title('Model accuracy')

Функции помощника

Функция modelGradients

modelGradients функционируйте берет в dlnetwork объект dlnet, мини-пакет входных данных dlX, маркирует Y, и параметр для затухания веса. Функция возвращает градиенты, потерю и состояние nonlearnable параметров. Чтобы вычислить градиенты автоматически, используйте dlgradient функция.

function [gradients, loss, state] = modelGradients(dlnet, dlX, Y, weightDecay)

[dlYPred, state] = forward(dlnet, dlX);
dlYPred = softmax(dlYPred);

loss = crossentropy(dlYPred, Y);

% L2-regularization (weight decay)
allParams = dlnet.Learnables(dlnet.Learnables.Parameter == "Weights" | dlnet.Learnables.Parameter == "Scale",:).Value;
l2Norm = cellfun(@(x) sum(x.^2,'All'), allParams, 'UniformOutput', false);
l2Norm = sum(cat(1, l2Norm{:}));
loss = loss + weightDecay*0.5*l2Norm;

gradients = dlgradient(loss, dlnet.Learnables);
end

Функция computeAccuracy

computeAccuracy функционируйте использует сетевые прогнозы, истинные метки и количество классов, чтобы вычислить точность.

function accuracy = computeAccuracy(YPredictions, YTest, classes)
[~,I] = max(YPredictions,[],1);
C = classes(I);
accuracy = 100*(sum(C==YTest)/numel(C));
end

Функция plotLossAndLearnRate

plotLossAndLearnRate графики функций потеря и изучают уровень в каждой итерации во время обучения.

function [lossLine, learnRateLine] = plotLossAndLearnRate()
figure('Name','Training Progress');
clf
subplot(2,1,1); lossLine = animatedline;
title('Loss');
xlabel('Iteration')
ylabel('Loss')
grid on
subplot(2,1,2); learnRateLine = animatedline;
title('Learning rate');
xlabel('Iteration')
ylabel('Learning rate')
grid on
end

Функция convolutionalUnit

convolutionalUnit(numF,stride,tag) создает массив слоев с двумя сверточными слоями и соответствующей пакетной нормализацией и слоями ReLU. numF количество сверточных фильтров, stride шаг первого сверточного слоя и tag тег, который предварительно ожидается ко всем именам слоя.

function layers = convolutionalUnit(numF,stride,tag)
layers = [
    convolution2dLayer(3,numF,'Padding','same','Stride',stride,'Name',[tag,'conv1'])
    batchNormalizationLayer('Name',[tag,'BN1'])
    reluLayer('Name',[tag,'relu1'])
    convolution2dLayer(3,numF,'Padding','same','Name',[tag,'conv2'])
    batchNormalizationLayer('Name',[tag,'BN2'])];
end

Ссылки

[1] Он, Kaiming, Сянюй Чжан, Шаоцин Жэнь и Цзянь Сунь. "Глубокая невязка, учащаяся для распознавания изображений". В Продолжениях конференции по IEEE по компьютерному зрению и распознаванию образов, стр 770-778. 2016.

[2] Krizhevsky, Алекс. "Изучая несколько слоев функций от крошечных изображений". (2009). https://www.cs.toronto.edu / ~ kriz/learning-features-2009-TR.pdf

[3] Лощилов, Илья и Франк Хуттер. "Sgdr: Стохастический градиентный спуск с горячими перезапусками". (2016). arXiv предварительно распечатывают arXiv:1608.03983.

[4] Хуан, Гао, Йиксуэн Ли, Джефф Плейсс, Чжуан Лю, Джон Э. Хопкрофт и Килиан К. Вайнбергер. "Создайте снимки ансамбли: Обучите 1, получите m бесплатно". (2017). arXiv предварительно распечатывают arXiv:1704.00109.

Смотрите также

| | | | | |

Похожие темы