В этом примере показано, как обучить сеть, чтобы классифицировать изображения объектов с помощью циклического, изучают расписание уровня и создают снимки ensembling для лучшей тестовой точности. В примере вы изучаете, как использовать косинусную функцию в изучить расписании уровня, взять снимки состояния сети во время обучения создать ансамбль модели и добавить регуляризацию L2-нормы (затухание веса) к учебной потере.
Этот пример обучается, остаточная сеть [1] на наборе данных CIFAR-10 [2] с пользовательским циклическим изучают уровень: для каждой итерации решатель использует изучить уровень, данный переключенной косинусной функцией [3] alpha(t) = (alpha0/2)*cos(pi*mod(t-1,T/M)/(T/M)+1)
, где t
номер итерации, T
общее количество учебных итераций, alpha0
начальная буква, изучают уровень и M
количество циклов/снимков состояния. Это узнает, что расписание уровня эффективно разделяет учебный процесс в M
циклы. Каждый цикл начинается с большой скорости обучения, которая затухает монотонно, обеспечивая сеть, чтобы исследовать различные локальные минимумы. В конце каждого учебного цикла вы берете снимок состояния сети (то есть, вы сохраняете модель в этой итерации), и более позднее среднее значение предсказания всех моделей снимка состояния, также известных как ensembling [4] снимка состояния, чтобы улучшить точность завершающего испытания.
Загрузите набор данных CIFAR-10 [2]. Набор данных содержит 60 000 изображений. Каждое изображение находится 32 32 в размере и имеет три цветовых канала (RGB). Размер набора данных составляет 175 Мбайт. В зависимости от вашего интернет-соединения может занять время процесс загрузки.
datadir = tempdir; downloadCIFARData(datadir);
Загрузите обучение CIFAR-10 и протестируйте изображения как 4-D массивы. Набор обучающих данных содержит 50 000 изображений, и набор тестов содержит 10 000 изображений.
[XTrain,YTrain,XTest,YTest] = loadCIFARData(datadir); classes = categories(YTrain); numClasses = numel(classes);
Можно отобразить случайную выборку учебных изображений с помощью следующего кода.
figure;
idx = randperm(size(XTrain,4),20);
im = imtile(XTrain(:,:,:,idx),'ThumbnailSize',[96,96]);
imshow(im)
Создайте augmentedImageDatastore
возразите, чтобы использовать в сетевом обучении. Во время обучения datastore случайным образом инвертирует учебные изображения вдоль вертикальной оси и случайным образом переводит их до четырех пикселей горизонтально и вертикально. Увеличение данных помогает препятствовать тому, чтобы сеть сверхсоответствовала и запомнила точные детали учебных изображений.
imageSize = [32 32 3]; pixelRange = [-4 4]; imageAugmenter = imageDataAugmenter( ... 'RandXReflection',true, ... 'RandXTranslation',pixelRange, ... 'RandYTranslation',pixelRange); augimdsTrain = augmentedImageDatastore(imageSize,XTrain,YTrain, ... 'DataAugmentation',imageAugmenter); auimdsTest = augmentedImageDatastore(imageSize, XTest, YTest);
Создайте остаточную сеть [1] с шестью стандартными сверточными модулями (два модуля на этап) и ширина 16. Общая сетевая глубина 2*6+2 = 14. Кроме того, задайте среднее изображение с помощью 'Mean'
опция в изображении ввела слой.
netWidth = 16; layers = [ imageInputLayer(imageSize,'Name','input','Mean', mean(XTrain,4)) convolution2dLayer(3,netWidth,'Padding','same','Name','convInp') batchNormalizationLayer('Name','BNInp') reluLayer('Name','reluInp') convolutionalUnit(netWidth,1,'S1U1') additionLayer(2,'Name','add11') reluLayer('Name','relu11') convolutionalUnit(netWidth,1,'S1U2') additionLayer(2,'Name','add12') reluLayer('Name','relu12') convolutionalUnit(2*netWidth,2,'S2U1') additionLayer(2,'Name','add21') reluLayer('Name','relu21') convolutionalUnit(2*netWidth,1,'S2U2') additionLayer(2,'Name','add22') reluLayer('Name','relu22') convolutionalUnit(4*netWidth,2,'S3U1') additionLayer(2,'Name','add31') reluLayer('Name','relu31') convolutionalUnit(4*netWidth,1,'S3U2') additionLayer(2,'Name','add32') reluLayer('Name','relu32') averagePooling2dLayer(8,'Name','globalPool') fullyConnectedLayer(10,'Name','fcFinal') ]; lgraph = layerGraph(layers); lgraph = connectLayers(lgraph,'reluInp','add11/in2'); lgraph = connectLayers(lgraph,'relu11','add12/in2'); skip1 = [ convolution2dLayer(1,2*netWidth,'Stride',2,'Name','skipConv1') batchNormalizationLayer('Name','skipBN1')]; lgraph = addLayers(lgraph,skip1); lgraph = connectLayers(lgraph,'relu12','skipConv1'); lgraph = connectLayers(lgraph,'skipBN1','add21/in2'); lgraph = connectLayers(lgraph,'relu21','add22/in2'); skip2 = [ convolution2dLayer(1,4*netWidth,'Stride',2,'Name','skipConv2') batchNormalizationLayer('Name','skipBN2')]; lgraph = addLayers(lgraph,skip2); lgraph = connectLayers(lgraph,'relu22','skipConv2'); lgraph = connectLayers(lgraph,'skipBN2','add31/in2'); lgraph = connectLayers(lgraph,'relu31','add32/in2');
Постройте архитектуру ResNet.
figure; plot(lgraph)
Создайте dlnetwork
объект от графика слоев.
dlnet = dlnetwork(lgraph);
Создайте функцию помощника modelGradients
, перечисленный в конце примера. Функция берет в dlnetwork
объект dlnet
and
мини-пакет входных данных dlX
с соответствием маркирует Y,
и возвращает градиенты потери относительно настраиваемых параметров в dlnet
. Эта функция также возвращает потерю и состояние nonlearnable параметров сети в данной итерации.
Задайте опции обучения.
velocity = []; numEpochs = 200; miniBatchSize = 64; augimdsTrain.MiniBatchSize = miniBatchSize; numObservations = numel(YTrain); numIterationsPerEpoch = floor(numObservations./miniBatchSize); momentum = 0.9; weightDecay = 1e-4;
Укажите, что опции обучения, характерные для циклического, изучают уровень. Alpha0
начальная буква, изучают уровень и numSnapshots
количество циклов или создает снимки взятый во время обучения.
alpha0 = 0.1;
numSnapshots = 5;
epochsPerSnapshot = numEpochs./numSnapshots;
iterationsPerSnapshot = ceil(numObservations./miniBatchSize)*numEpochs./numSnapshots;
modelPrefix = "SnapshotEpoch";
Обучайтесь на графическом процессоре, если вы доступны (требует Parallel Computing Toolbox™).
executionEnvironment = "auto";
Инициализируйте учебную фигуру.
[lossLine, learnRateLine] = plotLossAndLearnRate();
Обучите модель с помощью пользовательского учебного цикла.
В течение каждой эпохи переставьте datastore, цикл по мини-пакетам данных, и сохраните модель (снимок состояния), если текущая эпоха является кратной epochsPerSnapshot
. В конце каждой эпохи отобразите прогресс обучения.
Для каждого мини-пакета:
Преобразуйте метки в фиктивные переменные.
Преобразуйте данные в dlarray
объекты с базовым одним типом и указывают, что размерность маркирует 'SSCB'
(пространственный, пространственный, канал, пакет).
Для обучения графического процессора преобразуйте мини-пакетные данные в gpuArray
объекты.
Оцените градиенты модели и потерю с помощью dlfeval
и modelGradients
функция.
Обновите состояние nonlearnable параметров сети.
Решите, что изучить уровень для циклического изучает расписание уровня.
Обновите сетевые параметры с помощью sgdmupdate
функция.
Постройте потерю и изучите уровень в каждой итерации.
В данном примере обучение взяло приблизительно 18-й на NVIDIA™ GeForce GTX 1080.
iteration = 0; start = tic; % Loop over epochs. for epoch = 1:numEpochs % Reset image datastore. reset(augimdsTrain); % Shuffle data. augimdsTrain = shuffle(augimdsTrain); % Save snapshot model. if ~mod(epoch,epochsPerSnapshot) save(modelPrefix + epoch + ".mat",'dlnet'); end % Loop over mini-batches. while hasdata(augimdsTrain) iteration = iteration + 1; % Read mini-batch of data. data = read(augimdsTrain); % Concatenate the inputs. Xdata = data{:,1}; X = cat(4,Xdata{:}); % Convert the labels to dummy variables. TrueClasses = data{:,2}; Y = zeros(numClasses, numel(TrueClasses), 'single'); for c = 1:numClasses Y(c,TrueClasses==classes(c)) = 1; end % Convert mini-batch of data to dlarray. dlX = dlarray(single(X),'SSCB'); % If training on a GPU, then convert data to gpuArray. if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlX = gpuArray(dlX); end % Evaluate the model gradients and loss using dlfeval and the % modelGradients function. [gradients, loss, state] = dlfeval(@modelGradients,dlnet,dlX,Y,weightDecay); % Update the state of nonlearnable parameters. dlnet.State = state; % Determine learn rate for cyclical learn rate schedule. learnRate = 0.5*alpha0*(cos((pi*mod(iteration-1,iterationsPerSnapshot)./iterationsPerSnapshot))+1); % Update the network parameters using the SGDM optimizer. [dlnet.Learnables, velocity] = sgdmupdate(dlnet.Learnables, gradients, velocity, learnRate, momentum); % Plot loss and learn rate for current iteration. loss = double(gather(extractdata(loss))); addpoints(lossLine, iteration, loss); addpoints(learnRateLine, iteration, learnRate); drawnow end % Display the training progress. D = duration(0,0,toc(start),'Format','hh:mm:ss'); disp( ... "Epoch: " + epoch + ", " + ... "Loss: " + num2str(loss) + ", " + ... "Elapsed: " + string(D)) end
Epoch: 1, Loss: 1.784, Elapsed: 00:05:59
Epoch: 2, Loss: 1.1388, Elapsed: 00:11:32
Epoch: 3, Loss: 1.0594, Elapsed: 00:17:02
Epoch: 4, Loss: 1.3154, Elapsed: 00:22:34
Epoch: 5, Loss: 0.7968, Elapsed: 00:28:08
Epoch: 6, Loss: 0.85319, Elapsed: 00:33:47
Epoch: 7, Loss: 0.5477, Elapsed: 00:39:26
Epoch: 8, Loss: 1.1158, Elapsed: 00:45:14
Epoch: 9, Loss: 0.39734, Elapsed: 00:50:47
Epoch: 10, Loss: 0.47909, Elapsed: 00:56:26
Epoch: 11, Loss: 0.68743, Elapsed: 01:02:08
Epoch: 12, Loss: 0.45266, Elapsed: 01:07:51
Epoch: 13, Loss: 1.0058, Elapsed: 01:13:27
Epoch: 14, Loss: 0.64816, Elapsed: 01:19:14
Epoch: 15, Loss: 1.1039, Elapsed: 01:25:05
Epoch: 16, Loss: 1.2072, Elapsed: 01:30:41
Epoch: 17, Loss: 0.43167, Elapsed: 01:36:24
Epoch: 18, Loss: 1.1376, Elapsed: 01:42:40
Epoch: 19, Loss: 0.76857, Elapsed: 01:48:40
Epoch: 20, Loss: 0.77434, Elapsed: 01:54:22
Epoch: 21, Loss: 0.98595, Elapsed: 01:59:57
Epoch: 22, Loss: 0.78628, Elapsed: 02:05:32
Epoch: 23, Loss: 0.55069, Elapsed: 02:11:07
Epoch: 24, Loss: 0.52066, Elapsed: 02:16:43
Epoch: 25, Loss: 0.44842, Elapsed: 02:22:18
Epoch: 26, Loss: 0.40094, Elapsed: 02:27:55
Epoch: 27, Loss: 0.78839, Elapsed: 02:33:30
Epoch: 28, Loss: 0.47829, Elapsed: 02:39:05
Epoch: 29, Loss: 0.21833, Elapsed: 02:44:41
Epoch: 30, Loss: 0.5759, Elapsed: 02:50:16
Epoch: 31, Loss: 1.1089, Elapsed: 02:55:50
Epoch: 32, Loss: 0.37353, Elapsed: 03:01:25
Epoch: 33, Loss: 0.30851, Elapsed: 03:07:01
Epoch: 34, Loss: 0.34735, Elapsed: 03:12:36
Epoch: 35, Loss: 0.28772, Elapsed: 03:18:11
Epoch: 36, Loss: 0.31045, Elapsed: 03:23:47
Epoch: 37, Loss: 0.28555, Elapsed: 03:29:22
Epoch: 38, Loss: 0.897, Elapsed: 03:34:58
Epoch: 39, Loss: 0.69014, Elapsed: 03:40:33
Epoch: 40, Loss: 0.26282, Elapsed: 03:46:17
Epoch: 41, Loss: 1.0086, Elapsed: 03:51:53
Epoch: 42, Loss: 0.47303, Elapsed: 03:57:27
Epoch: 43, Loss: 1.3765, Elapsed: 04:03:02
Epoch: 44, Loss: 0.54884, Elapsed: 04:08:39
Epoch: 45, Loss: 0.38778, Elapsed: 04:14:14
Epoch: 46, Loss: 0.74121, Elapsed: 04:19:49
Epoch: 47, Loss: 0.78481, Elapsed: 04:25:25
Epoch: 48, Loss: 0.44624, Elapsed: 04:31:01
Epoch: 49, Loss: 0.81747, Elapsed: 04:36:38
Epoch: 50, Loss: 0.40319, Elapsed: 04:42:14
Epoch: 51, Loss: 0.87757, Elapsed: 04:47:51
Epoch: 52, Loss: 1.0567, Elapsed: 04:53:27
Epoch: 53, Loss: 0.29019, Elapsed: 04:59:03
Epoch: 54, Loss: 0.92056, Elapsed: 05:04:40
Epoch: 55, Loss: 0.45776, Elapsed: 05:10:16
Epoch: 56, Loss: 1.0265, Elapsed: 05:15:52
Epoch: 57, Loss: 0.55256, Elapsed: 05:21:29
Epoch: 58, Loss: 1.0822, Elapsed: 05:27:06
Epoch: 59, Loss: 0.78332, Elapsed: 05:32:44
Epoch: 60, Loss: 0.48247, Elapsed: 05:38:20
Epoch: 61, Loss: 0.86749, Elapsed: 05:43:58
Epoch: 62, Loss: 0.64667, Elapsed: 05:49:34
Epoch: 63, Loss: 0.64563, Elapsed: 05:55:10
Epoch: 64, Loss: 0.58239, Elapsed: 06:00:46
Epoch: 65, Loss: 0.29219, Elapsed: 06:06:23
Epoch: 66, Loss: 0.37627, Elapsed: 06:11:59
Epoch: 67, Loss: 0.34035, Elapsed: 06:17:35
Epoch: 68, Loss: 0.34809, Elapsed: 06:23:11
Epoch: 69, Loss: 0.61085, Elapsed: 06:28:47
Epoch: 70, Loss: 0.42018, Elapsed: 06:34:24
Epoch: 71, Loss: 0.3739, Elapsed: 06:40:00
Epoch: 72, Loss: 0.23083, Elapsed: 06:45:37
Epoch: 73, Loss: 0.21324, Elapsed: 06:51:14
Epoch: 74, Loss: 0.18931, Elapsed: 06:56:55
Epoch: 75, Loss: 0.88882, Elapsed: 07:02:31
Epoch: 76, Loss: 0.36844, Elapsed: 07:08:09
Epoch: 77, Loss: 0.76548, Elapsed: 07:13:46
Epoch: 78, Loss: 0.42548, Elapsed: 07:19:24
Epoch: 79, Loss: 0.29112, Elapsed: 07:25:01
Epoch: 80, Loss: 0.17333, Elapsed: 07:30:45
Epoch: 81, Loss: 0.50322, Elapsed: 07:36:22
Epoch: 82, Loss: 0.40387, Elapsed: 07:41:58
Epoch: 83, Loss: 0.3939, Elapsed: 07:47:34
Epoch: 84, Loss: 0.79005, Elapsed: 07:53:11
Epoch: 85, Loss: 0.51953, Elapsed: 07:58:46
Epoch: 86, Loss: 0.65925, Elapsed: 08:04:24
Epoch: 87, Loss: 0.49915, Elapsed: 08:10:01
Epoch: 88, Loss: 0.58721, Elapsed: 08:15:38
Epoch: 89, Loss: 0.57397, Elapsed: 08:21:15
Epoch: 90, Loss: 0.51315, Elapsed: 08:26:53
Epoch: 91, Loss: 0.42037, Elapsed: 08:32:30
Epoch: 92, Loss: 0.41111, Elapsed: 08:38:06
Epoch: 93, Loss: 0.71338, Elapsed: 08:43:43
Epoch: 94, Loss: 0.31452, Elapsed: 08:49:21
Epoch: 95, Loss: 0.35696, Elapsed: 08:54:58
Epoch: 96, Loss: 0.56142, Elapsed: 09:00:36
Epoch: 97, Loss: 0.69246, Elapsed: 09:06:15
Epoch: 98, Loss: 0.40288, Elapsed: 09:11:53
Epoch: 99, Loss: 0.67491, Elapsed: 09:17:31
Epoch: 100, Loss: 0.70555, Elapsed: 09:23:08
Epoch: 101, Loss: 0.45978, Elapsed: 09:28:47
Epoch: 102, Loss: 0.3963, Elapsed: 09:34:27
Epoch: 103, Loss: 0.60798, Elapsed: 09:40:05
Epoch: 104, Loss: 0.41759, Elapsed: 09:45:45
Epoch: 105, Loss: 0.45068, Elapsed: 09:51:23
Epoch: 106, Loss: 1.103, Elapsed: 09:57:02
Epoch: 107, Loss: 0.29916, Elapsed: 10:02:41
Epoch: 108, Loss: 0.64019, Elapsed: 10:08:21
Epoch: 109, Loss: 0.26558, Elapsed: 10:13:59
Epoch: 110, Loss: 0.41303, Elapsed: 10:19:38
Epoch: 111, Loss: 0.74221, Elapsed: 10:25:18
Epoch: 112, Loss: 0.48748, Elapsed: 10:30:56
Epoch: 113, Loss: 0.27348, Elapsed: 10:36:35
Epoch: 114, Loss: 0.51661, Elapsed: 10:42:14
Epoch: 115, Loss: 0.27831, Elapsed: 10:47:54
Epoch: 116, Loss: 0.35103, Elapsed: 10:53:33
Epoch: 117, Loss: 0.19571, Elapsed: 10:59:11
Epoch: 118, Loss: 0.37368, Elapsed: 11:04:50
Epoch: 119, Loss: 0.18644, Elapsed: 11:10:29
Epoch: 120, Loss: 0.48589, Elapsed: 11:16:16
Epoch: 121, Loss: 0.74257, Elapsed: 11:21:57
Epoch: 122, Loss: 0.65423, Elapsed: 11:27:37
Epoch: 123, Loss: 0.35185, Elapsed: 11:33:17
Epoch: 124, Loss: 0.81636, Elapsed: 11:38:55
Epoch: 125, Loss: 0.49292, Elapsed: 11:44:34
Epoch: 126, Loss: 0.9133, Elapsed: 11:50:14
Epoch: 127, Loss: 0.80498, Elapsed: 11:55:53
Epoch: 128, Loss: 0.59473, Elapsed: 12:01:33
Epoch: 129, Loss: 0.60313, Elapsed: 12:07:12
Epoch: 130, Loss: 0.5426, Elapsed: 12:12:50
Epoch: 131, Loss: 1.3471, Elapsed: 12:18:29
Epoch: 132, Loss: 0.35591, Elapsed: 12:24:08
Epoch: 133, Loss: 0.75186, Elapsed: 12:29:49
Epoch: 134, Loss: 0.98765, Elapsed: 12:35:29
Epoch: 135, Loss: 0.65345, Elapsed: 12:41:08
Epoch: 136, Loss: 0.78963, Elapsed: 12:46:48
Epoch: 137, Loss: 0.38269, Elapsed: 12:52:27
Epoch: 138, Loss: 0.5309, Elapsed: 12:58:06
Epoch: 139, Loss: 0.4119, Elapsed: 13:03:45
Epoch: 140, Loss: 0.93898, Elapsed: 13:09:26
Epoch: 141, Loss: 0.45791, Elapsed: 13:15:04
Epoch: 142, Loss: 0.70093, Elapsed: 13:20:43
Epoch: 143, Loss: 0.84997, Elapsed: 13:26:23
Epoch: 144, Loss: 0.27732, Elapsed: 13:32:05
Epoch: 145, Loss: 0.51171, Elapsed: 13:37:44
Epoch: 146, Loss: 0.81123, Elapsed: 13:43:24
Epoch: 147, Loss: 0.5678, Elapsed: 13:49:04
Epoch: 148, Loss: 0.58568, Elapsed: 13:54:44
Epoch: 149, Loss: 0.3952, Elapsed: 14:00:23
Epoch: 150, Loss: 0.31967, Elapsed: 14:06:03
Epoch: 151, Loss: 0.44051, Elapsed: 14:11:46
Epoch: 152, Loss: 0.99278, Elapsed: 14:17:27
Epoch: 153, Loss: 0.87306, Elapsed: 14:23:07
Epoch: 154, Loss: 0.34008, Elapsed: 14:28:47
Epoch: 155, Loss: 0.4687, Elapsed: 14:34:27
Epoch: 156, Loss: 0.22836, Elapsed: 14:40:07
Epoch: 157, Loss: 0.23204, Elapsed: 14:45:48
Epoch: 158, Loss: 0.36854, Elapsed: 14:51:28
Epoch: 159, Loss: 0.35363, Elapsed: 14:57:08
Epoch: 160, Loss: 0.37937, Elapsed: 15:02:55
Epoch: 161, Loss: 0.7725, Elapsed: 15:08:36
Epoch: 162, Loss: 0.59353, Elapsed: 15:14:15
Epoch: 163, Loss: 0.57963, Elapsed: 15:19:54
Epoch: 164, Loss: 0.54625, Elapsed: 15:25:35
Epoch: 165, Loss: 0.65612, Elapsed: 15:31:15
Epoch: 166, Loss: 0.73254, Elapsed: 15:36:56
Epoch: 167, Loss: 0.4483, Elapsed: 15:42:37
Epoch: 168, Loss: 0.36817, Elapsed: 15:48:17
Epoch: 169, Loss: 0.57539, Elapsed: 15:53:57
Epoch: 170, Loss: 1.0026, Elapsed: 15:59:37
Epoch: 171, Loss: 0.95288, Elapsed: 16:05:17
Epoch: 172, Loss: 0.83053, Elapsed: 16:10:59
Epoch: 173, Loss: 0.41976, Elapsed: 16:16:39
Epoch: 174, Loss: 0.44098, Elapsed: 16:22:19
Epoch: 175, Loss: 0.58823, Elapsed: 16:28:00
Epoch: 176, Loss: 0.67325, Elapsed: 16:33:41
Epoch: 177, Loss: 0.27045, Elapsed: 16:39:21
Epoch: 178, Loss: 0.66652, Elapsed: 16:45:02
Epoch: 179, Loss: 1.0097, Elapsed: 16:50:43
Epoch: 180, Loss: 0.40372, Elapsed: 16:56:23
Epoch: 181, Loss: 0.39175, Elapsed: 17:02:04
Epoch: 182, Loss: 0.40741, Elapsed: 17:07:45
Epoch: 183, Loss: 0.35398, Elapsed: 17:13:25
Epoch: 184, Loss: 0.63228, Elapsed: 17:19:05
Epoch: 185, Loss: 0.35308, Elapsed: 17:24:45
Epoch: 186, Loss: 0.46854, Elapsed: 17:30:27
Epoch: 187, Loss: 0.51346, Elapsed: 17:36:08
Epoch: 188, Loss: 0.71886, Elapsed: 17:41:48
Epoch: 189, Loss: 0.73986, Elapsed: 17:47:29
Epoch: 190, Loss: 0.46669, Elapsed: 17:53:10
Epoch: 191, Loss: 0.40962, Elapsed: 17:58:51
Epoch: 192, Loss: 0.25007, Elapsed: 18:04:31
Epoch: 193, Loss: 0.45651, Elapsed: 18:10:12
Epoch: 194, Loss: 0.20788, Elapsed: 18:15:52
Epoch: 195, Loss: 0.32097, Elapsed: 18:21:32
Epoch: 196, Loss: 0.28159, Elapsed: 18:27:15
Epoch: 197, Loss: 0.20396, Elapsed: 18:32:56
Epoch: 198, Loss: 0.30823, Elapsed: 18:38:37
Epoch: 199, Loss: 0.28583, Elapsed: 18:44:18
Epoch: 200, Loss: 0.32877, Elapsed: 18:50:07
Объедините снимки состояния M сети, взятой во время обучения сформировать итоговый ансамбль. Предсказания ансамбля соответствуют среднему значению выхода полносвязного слоя из всех отдельных моделей M.
YPredictions = zeros(numClasses,numel(YTest),numSnapshots); modelAccuracy = zeros(numSnapshots+1,1); modelName = cell(numSnapshots+1,1); for m = 1:numSnapshots modelName{m} = modelPrefix + m*epochsPerSnapshot; load(modelName{m} + ".mat"); YPredictions(:,:,m) = gather(extractdata(predict(dlnet, dlarray(single(XTest),'SSCB')))); modelAccuracy(m) = computeAccuracy(YPredictions(:,:,m), YTest, classes); disp(modelName{m} + " accuracy: " + modelAccuracy(m) + "%") end
SnapshotEpoch40 accuracy: 88.04% SnapshotEpoch80 accuracy: 86.78% SnapshotEpoch120 accuracy: 87.53% SnapshotEpoch160 accuracy: 87.07% SnapshotEpoch200 accuracy: 88.39%
modelAccuracy(end) = computeAccuracy(mean(YPredictions,3), YTest, classes); modelName{end} = "Ensemble model"; disp("Ensemble accuracy: " + modelAccuracy(end) + "%")
Ensemble accuracy: 91.13%
Постройте точность на наборе тестовых данных для всех моделей снимка состояния и модели ансамбля.
figure;bar(modelAccuracy); ylabel('Accuracy (%)'); xticklabels(modelName) xtickangle(45) title('Model accuracy')
modelGradients
функционируйте берет в dlnetwork
объект dlnet
, мини-пакет входных данных dlX
, маркирует Y,
и параметр для затухания веса. Функция возвращает градиенты, потерю и состояние nonlearnable параметров. Чтобы вычислить градиенты автоматически, используйте dlgradient
функция.
function [gradients, loss, state] = modelGradients(dlnet, dlX, Y, weightDecay) [dlYPred, state] = forward(dlnet, dlX); dlYPred = softmax(dlYPred); loss = crossentropy(dlYPred, Y); % L2-regularization (weight decay) allParams = dlnet.Learnables(dlnet.Learnables.Parameter == "Weights" | dlnet.Learnables.Parameter == "Scale",:).Value; l2Norm = cellfun(@(x) sum(x.^2,'All'), allParams, 'UniformOutput', false); l2Norm = sum(cat(1, l2Norm{:})); loss = loss + weightDecay*0.5*l2Norm; gradients = dlgradient(loss, dlnet.Learnables); end
computeAccuracy
функционируйте использует сетевые предсказания, истинные метки и количество классов, чтобы вычислить точность.
function accuracy = computeAccuracy(YPredictions, YTest, classes) [~,I] = max(YPredictions,[],1); C = classes(I); accuracy = 100*(sum(C==YTest)/numel(C)); end
plotLossAndLearnRate
графики функций потеря и изучают уровень в каждой итерации во время обучения.
function [lossLine, learnRateLine] = plotLossAndLearnRate() figure('Name','Training Progress'); clf subplot(2,1,1); lossLine = animatedline; title('Loss'); xlabel('Iteration') ylabel('Loss') grid on subplot(2,1,2); learnRateLine = animatedline; title('Learning rate'); xlabel('Iteration') ylabel('Learning rate') grid on end
convolutionalUnit(numF,stride,tag)
создает массив слоев с двумя сверточными слоями и соответствующей нормализацией партии. и слоями ReLU. numF
количество сверточных фильтров, stride
шаг первого сверточного слоя и tag
тег, который предварительно ожидается ко всем именам слоя.
function layers = convolutionalUnit(numF,stride,tag) layers = [ convolution2dLayer(3,numF,'Padding','same','Stride',stride,'Name',[tag,'conv1']) batchNormalizationLayer('Name',[tag,'BN1']) reluLayer('Name',[tag,'relu1']) convolution2dLayer(3,numF,'Padding','same','Name',[tag,'conv2']) batchNormalizationLayer('Name',[tag,'BN2'])]; end
[1] Он, Kaiming, Сянюй Чжан, Шаоцин Жэнь и Цзянь Сунь. "Глубокая невязка, учащаяся для распознавания изображений". В Продолжениях конференции по IEEE по компьютерному зрению и распознаванию образов, стр 770-778. 2016.
[2] Krizhevsky, Алекс. "Изучая несколько слоев функций от крошечных изображений". (2009). https://www.cs.toronto.edu / ~ kriz/learning-features-2009-TR.pdf
[3] Лощилов, Илья и Франк Хуттер. "Sgdr: Стохастический градиентный спуск с горячими перезапусками". (2016). arXiv предварительно распечатывают arXiv:1608.03983.
[4] Хуан, Гао, Йиксуэн Ли, Джефф Плейсс, Чжуан Лю, Джон Э. Хопкрофт и Килиан К. Вайнбергер. "Создайте снимки ансамбли: Обучите 1, получите m бесплатно". (2017). arXiv предварительно распечатывают arXiv:1704.00109.
dlarray
| dlfeval
| dlgradient
| dlnetwork
| layerGraph
| sgdmupdate
| sigmoid