Глубокое обучение Используя байесовую оптимизацию

В этом примере показано, как применить Байесовую оптимизацию к глубокому обучению и найти оптимальные сетевые гиперпараметры и опции обучения для сверточных нейронных сетей.

Чтобы обучить глубокую нейронную сеть, необходимо задать архитектуру нейронной сети, а также опции алгоритма настройки. Выбор и настройка этих гиперпараметров могут затруднить и занять время. Байесова оптимизация является алгоритмом, хорошо подходящим для оптимизации гиперпараметров моделей регрессии и классификации. Можно использовать Байесовую оптимизацию, чтобы оптимизировать функции, которые недифференцируемы, прерывисты, и длительны, чтобы оценить. Алгоритм внутренне обеспечивает Гауссову модель процесса целевой функции и использует оценки целевой функции, чтобы обучить эту модель.

В этом примере показано, как к:

  • Загрузите и подготовьте набор данных CIFAR-10 к сетевому обучению. Этот набор данных является одним из наиболее широко используемых наборов данных для тестирования моделей классификации изображений.

  • Задайте переменные, чтобы оптимизировать использующую Байесовую оптимизацию. Эти переменные являются опциями алгоритма настройки, а также параметрами самой сетевой архитектуры.

  • Задайте целевую функцию, которая принимает значения переменных оптимизации как входные параметры, задает сетевую архитектуру и опции обучения, обучает и проверяет сеть и сохраняет обучивший сеть на диск. Целевая функция задана в конце этого скрипта.

  • Выполните Байесовую оптимизацию путем минимизации ошибки классификации на наборе валидации.

  • Загрузите лучшую сеть от диска и оцените его на наборе тестов.

Как альтернатива, можно использовать Байесовую оптимизацию, чтобы найти оптимальные опции обучения в Experiment Manager. Для получения дополнительной информации смотрите Гиперпараметры Эксперимента Мелодии при помощи Байесовой Оптимизации.

Подготовка данных

Загрузите набор данных CIFAR-10 [1]. Этот набор данных содержит 60 000 изображений, и каждое изображение имеет размер 32 32 и три цветовых канала (RGB). Размер целого набора данных составляет 175 Мбайт. В зависимости от вашего интернет-соединения может занять время процесс загрузки.

datadir = tempdir;
downloadCIFARData(datadir);

Загрузите набор данных CIFAR-10, когда обучение отображает и помечает, и тестовые изображения и метки. Чтобы включить сетевую валидацию, используйте 5000 из тестовых изображений для валидации.

[XTrain,YTrain,XTest,YTest] = loadCIFARData(datadir);

idx = randperm(numel(YTest),5000);
XValidation = XTest(:,:,:,idx);
XTest(:,:,:,idx) = [];
YValidation = YTest(idx);
YTest(idx) = [];

Можно отобразить выборку учебных изображений с помощью следующего кода.

figure;
idx = randperm(numel(YTrain),20);
for i = 1:numel(idx)
    subplot(4,5,i);
    imshow(XTrain(:,:,:,idx(i)));
end

Выберите Variables to Optimize

Выберите который переменные оптимизировать использующую Байесовую оптимизацию и указать диапазоны, чтобы искать в. Кроме того, задайте, являются ли переменные целыми числами и искать ли интервал на логарифмическом пробеле. Оптимизируйте следующие переменные:

  • Сетевая глубина раздела. Этот параметр управляет глубиной сети. Сеть имеет три раздела, каждого с SectionDepth идентичные сверточные слои. Таким образом, общим количеством сверточных слоев является 3*SectionDepth. Целевая функция позже в скрипте берет количество сверточных фильтров в каждом слое, пропорциональном 1/sqrt(SectionDepth). В результате количество параметров и необходимое количество расчета для каждой итерации являются примерно тем же самым для различных глубин раздела.

  • Начальная скорость обучения. Лучшая скорость обучения может зависеть от ваших данных, а также сети, вы - обучение.

  • Стохастический импульс градиентного спуска. Импульс добавляет, что инерция к обновлениям параметра при наличии текущего обновления содержит вклад, пропорциональный обновлению в предыдущей итерации. Это приводит к более сглаженным обновлениям параметра и сокращению шума, свойственного к стохастическому градиентному спуску.

  • Сила регуляризации L2. Используйте регуляризацию, чтобы предотвратить сверхподбор кривой. Ищите пробел силы регуляризации, чтобы найти хорошее значение. Увеличение данных и нормализация партии. также помогают упорядочить сеть.

optimVars = [
    optimizableVariable('SectionDepth',[1 3],'Type','integer')
    optimizableVariable('InitialLearnRate',[1e-2 1],'Transform','log')
    optimizableVariable('Momentum',[0.8 0.98])
    optimizableVariable('L2Regularization',[1e-10 1e-2],'Transform','log')];

Выполните байесовую оптимизацию

Создайте целевую функцию для Байесового оптимизатора, с помощью данных об обучении и валидации в качестве входных параметров. Целевая функция обучает сверточную нейронную сеть и возвращает ошибку классификации на наборе валидации. Эта функция задана в конце этого скрипта. Поскольку bayesopt использует коэффициент ошибок набора валидации, чтобы выбрать лучшую модель, возможно, что итоговая сеть сверхсоответствует на наборе валидации. Финал выбранная модель затем тестируется на независимом наборе тестов, чтобы оценить ошибку обобщения.

ObjFcn = makeObjFcn(XTrain,YTrain,XValidation,YValidation);

Выполните Байесовую оптимизацию путем минимизации ошибки классификации на наборе валидации. Задайте общее время оптимизации в секундах. Чтобы лучше всего использовать энергию Байесовой оптимизации, необходимо выполнить по крайней мере 30 оценок целевой функции. Чтобы обучить нейронные сети параллельно на нескольких графических процессорах, установите 'UseParallel' значение к true. Если вы имеете один графический процессор и устанавливаете 'UseParallel' значение к true, затем все рабочие совместно используют тот графический процессор, и вы не получаете учебного ускорения и увеличиваете возможности графического процессора, исчерпывающего память.

После того, как каждая сеть закончила обучение, bayesopt распечатывает результаты к командному окну. bayesopt функционируйте затем возвращает имена файлов в BayesObject.UserDataTrace. Целевая функция сохраняет обучивший нейронные сети на диск и возвращает имена файлов в bayesopt.

BayesObject = bayesopt(ObjFcn,optimVars, ...
    'MaxTime',14*60*60, ...
    'IsObjectiveDeterministic',false, ...
    'UseParallel',false);
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | SectionDepth | InitialLearn-|     Momentum | L2Regulariza-|
|      | result |             | runtime     | (observed)  | (estim.)    |              | Rate         |              | tion         |
|===================================================================================================================================|
|    1 | Best   |       0.197 |      955.69 |       0.197 |       0.197 |            3 |      0.61856 |      0.80624 |   0.00035179 |
|    2 | Best   |      0.1918 |      790.38 |      0.1918 |     0.19293 |            2 |     0.074118 |      0.91031 |   2.7229e-09 |
|    3 | Accept |      0.2438 |      660.29 |      0.1918 |     0.19344 |            1 |     0.051153 |      0.90911 |   0.00043113 |
|    4 | Accept |       0.208 |      672.81 |      0.1918 |      0.1918 |            1 |      0.70138 |      0.81923 |   3.7783e-08 |
|    5 | Best   |      0.1792 |      844.07 |      0.1792 |     0.17921 |            2 |      0.65156 |      0.93783 |   3.3663e-10 |
|    6 | Best   |      0.1776 |      851.49 |      0.1776 |     0.17759 |            2 |      0.23619 |      0.91932 |   1.0007e-10 |
|    7 | Accept |      0.2232 |       883.5 |      0.1776 |     0.17759 |            2 |     0.011147 |      0.91526 |    0.0099842 |
|    8 | Accept |      0.2508 |      822.65 |      0.1776 |     0.17762 |            1 |     0.023919 |      0.91048 |   1.0002e-10 |
|    9 | Accept |      0.1974 |      1947.6 |      0.1776 |     0.17761 |            3 |     0.010017 |      0.97683 |   5.4603e-10 |
|   10 | Best   |       0.176 |      1938.4 |       0.176 |     0.17608 |            2 |       0.3526 |      0.82381 |   1.4244e-07 |
|   11 | Accept |      0.1914 |      2874.4 |       0.176 |     0.17608 |            3 |     0.079847 |      0.86801 |   9.7335e-07 |
|   12 | Accept |       0.181 |        2578 |       0.176 |     0.17809 |            2 |      0.35141 |      0.80202 |   4.5634e-08 |
|   13 | Accept |      0.1838 |      2410.8 |       0.176 |     0.17946 |            2 |      0.39508 |      0.95968 |   9.3856e-06 |
|   14 | Accept |      0.1786 |      2490.6 |       0.176 |     0.17737 |            2 |      0.44857 |      0.91827 |   1.0939e-10 |
|   15 | Accept |      0.1776 |        2668 |       0.176 |     0.17751 |            2 |      0.95793 |      0.85503 |   1.0222e-05 |
|   16 | Accept |      0.1824 |      3059.8 |       0.176 |     0.17812 |            2 |      0.41142 |      0.86931 |    1.447e-06 |
|   17 | Accept |      0.1894 |      3091.5 |       0.176 |     0.17982 |            2 |      0.97051 |      0.80284 |   1.5836e-10 |
|   18 | Accept |       0.217 |      2794.5 |       0.176 |     0.17989 |            1 |       0.2464 |      0.84428 |   4.4938e-06 |
|   19 | Accept |      0.2358 |      4054.2 |       0.176 |     0.17601 |            3 |      0.22843 |       0.9454 |   0.00098248 |
|   20 | Accept |      0.2216 |      4411.7 |       0.176 |     0.17601 |            3 |     0.010847 |      0.82288 |   2.4756e-08 |
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | SectionDepth | InitialLearn-|     Momentum | L2Regulariza-|
|      | result |             | runtime     | (observed)  | (estim.)    |              | Rate         |              | tion         |
|===================================================================================================================================|
|   21 | Accept |      0.2038 |      3906.4 |       0.176 |     0.17601 |            2 |      0.09885 |      0.81541 |    0.0021184 |
|   22 | Accept |      0.2492 |      4103.4 |       0.176 |     0.17601 |            2 |      0.52313 |      0.83139 |    0.0016269 |
|   23 | Accept |      0.1814 |      4240.5 |       0.176 |     0.17601 |            2 |      0.29506 |      0.84061 |   6.0203e-10 |

__________________________________________________________
Optimization completed.
MaxTime of 50400 seconds reached.
Total function evaluations: 23
Total elapsed time: 53088.5123 seconds
Total objective function evaluation time: 53050.7026

Best observed feasible point:
    SectionDepth    InitialLearnRate    Momentum    L2Regularization
    ____________    ________________    ________    ________________

         2               0.3526         0.82381        1.4244e-07   

Observed objective function value = 0.176
Estimated objective function value = 0.17601
Function evaluation time = 1938.4483

Best estimated feasible point (according to models):
    SectionDepth    InitialLearnRate    Momentum    L2Regularization
    ____________    ________________    ________    ________________

         2               0.3526         0.82381        1.4244e-07   

Estimated objective function value = 0.17601
Estimated function evaluation time = 1898.2641

Оцените итоговую сеть

Загрузите лучшую сеть, найденную в оптимизации и ее точности валидации.

bestIdx = BayesObject.IndexOfMinimumTrace(end);
fileName = BayesObject.UserDataTrace{bestIdx};
savedStruct = load(fileName);
valError = savedStruct.valError
valError = 0.1760

Предскажите метки набора тестов и вычислите тестовую ошибку. Обработайте классификацию каждого изображения в наборе тестов как независимые события с определенной вероятностью успеха, что означает, что количество неправильно классифицированных изображений следует за биномиальным распределением. Используйте это, чтобы вычислить стандартную погрешность (testErrorSE) и аппроксимированный 95%-й доверительный интервал (testError95CI) из коэффициента ошибок обобщения. Этот метод часто называется Вальдовым методом. bayesopt определяет лучшую сеть с помощью набора валидации, не подвергая сеть набору тестов. Затем возможно, что тестовая ошибка выше, чем ошибка валидации.

[YPredicted,probs] = classify(savedStruct.trainedNet,XTest);
testError = 1 - mean(YPredicted == YTest)
testError = 0.1910
NTest = numel(YTest);
testErrorSE = sqrt(testError*(1-testError)/NTest);
testError95CI = [testError - 1.96*testErrorSE, testError + 1.96*testErrorSE]
testError95CI = 1×2

    0.1801    0.2019

Постройте матрицу беспорядка для тестовых данных. Отобразите точность и отзыв для каждого класса при помощи сводных данных строки и столбца.

figure('Units','normalized','Position',[0.2 0.2 0.4 0.4]);
cm = confusionchart(YTest,YPredicted);
cm.Title = 'Confusion Matrix for Test Data';
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

Можно отобразить некоторые тестовые изображения вместе с их предсказанными классами и вероятностями тех классов с помощью следующего кода.

figure
idx = randperm(numel(YTest),9);
for i = 1:numel(idx)
    subplot(3,3,i)
    imshow(XTest(:,:,:,idx(i)));
    prob = num2str(100*max(probs(idx(i),:)),3);
    predClass = char(YPredicted(idx(i)));
    label = [predClass,', ',prob,'%'];
    title(label)
end

Целевая функция для оптимизации

Задайте целевую функцию для оптимизации. Эта функция выполняет следующие шаги:

  1. Принимает значения переменных оптимизации как входные параметры. bayesopt вызывает целевую функцию с текущими значениями переменных оптимизации в таблице с каждым именем столбца, равным имени переменной. Например, текущим значением сетевой глубины раздела является optVars.SectionDepth.

  2. Задает сетевую архитектуру и опции обучения.

  3. Обучает и проверяет сеть.

  4. Сохраняет обучивший сеть, ошибку валидации и опции обучения к диску.

  5. Возвращает ошибку валидации и имя файла сохраненной сети.

function ObjFcn = makeObjFcn(XTrain,YTrain,XValidation,YValidation)
ObjFcn = @valErrorFun;
    function [valError,cons,fileName] = valErrorFun(optVars)

Задайте архитектуру сверточной нейронной сети.

  • Добавьте дополнение в сверточные слои так, чтобы пространственный выходной размер всегда был тем же самым как входным размером.

  • Каждый раз вы прореживаете пространственные размерности на коэффициент двух использований, макс. объединяющих слои, увеличиваете число фильтров на коэффициент два. Выполнение так гарантирует, что объем расчета, требуемого в каждом сверточном слое, является примерно тем же самым.

  • Выберите количество фильтров, пропорциональных 1/sqrt(SectionDepth), так, чтобы сети различных глубин имели примерно то же количество параметров и потребовали о то же самом значении расчета на итерацию. Чтобы увеличить число сетевых параметров и полной сетевой гибкости, увеличьте numF. Чтобы обучить еще более глубокие сети, измените область значений SectionDepth переменная.

  • Используйте convBlock(filterSize,numFilters,numConvLayers) создать блок numConvLayers сверточные слои, каждый с заданным filterSize и numFilters фильтры и каждый сопровождаемый слоем нормализации партии. и слоем ReLU. convBlock функция задана в конце этого примера.

        imageSize = [32 32 3];
        numClasses = numel(unique(YTrain));
        numF = round(16/sqrt(optVars.SectionDepth));
        layers = [
            imageInputLayer(imageSize)
            
            % The spatial input and output sizes of these convolutional
            % layers are 32-by-32, and the following max pooling layer
            % reduces this to 16-by-16.
            convBlock(3,numF,optVars.SectionDepth)
            maxPooling2dLayer(3,'Stride',2,'Padding','same')
            
            % The spatial input and output sizes of these convolutional
            % layers are 16-by-16, and the following max pooling layer
            % reduces this to 8-by-8.
            convBlock(3,2*numF,optVars.SectionDepth)
            maxPooling2dLayer(3,'Stride',2,'Padding','same')
            
            % The spatial input and output sizes of these convolutional
            % layers are 8-by-8. The global average pooling layer averages
            % over the 8-by-8 inputs, giving an output of size
            % 1-by-1-by-4*initialNumFilters. With a global average
            % pooling layer, the final classification output is only
            % sensitive to the total amount of each feature present in the
            % input image, but insensitive to the spatial positions of the
            % features.
            convBlock(3,4*numF,optVars.SectionDepth)
            averagePooling2dLayer(8)
            
            % Add the fully connected layer and the final softmax and
            % classification layers.
            fullyConnectedLayer(numClasses)
            softmaxLayer
            classificationLayer];

Задайте опции для сетевого обучения. Оптимизируйте начальную скорость обучения, импульс SGD и силу регуляризации L2.

Задайте данные о валидации и выберите 'ValidationFrequency' оцените таким образом что trainNetwork проверяет сеть однажды в эпоху. Обучайтесь для постоянного числа эпох и понизьте скорость обучения на коэффициент 10 в течение прошлых эпох. Это уменьшает шум обновлений параметра и позволяет сетевым параметрам успокоиться ближе к минимуму функции потерь.

        miniBatchSize = 256;
        validationFrequency = floor(numel(YTrain)/miniBatchSize);
        options = trainingOptions('sgdm', ...
            'InitialLearnRate',optVars.InitialLearnRate, ...
            'Momentum',optVars.Momentum, ...
            'MaxEpochs',60, ...
            'LearnRateSchedule','piecewise', ...
            'LearnRateDropPeriod',40, ...
            'LearnRateDropFactor',0.1, ...
            'MiniBatchSize',miniBatchSize, ...
            'L2Regularization',optVars.L2Regularization, ...
            'Shuffle','every-epoch', ...
            'Verbose',false, ...
            'Plots','training-progress', ...
            'ValidationData',{XValidation,YValidation}, ...
            'ValidationFrequency',validationFrequency);

Используйте увеличение данных, чтобы случайным образом инвертировать учебные изображения вдоль вертикальной оси, и случайным образом перевести их до четырех пикселей горизонтально и вертикально. Увеличение данных помогает препятствовать тому, чтобы сеть сверхсоответствовала и запомнила точные детали учебных изображений.

        pixelRange = [-4 4];
        imageAugmenter = imageDataAugmenter( ...
            'RandXReflection',true, ...
            'RandXTranslation',pixelRange, ...
            'RandYTranslation',pixelRange);
        datasource = augmentedImageDatastore(imageSize,XTrain,YTrain,'DataAugmentation',imageAugmenter);

Обучите сеть и постройте процесс обучения во время обучения. Закройте все учебные графики после того, как обучение закончится.

        trainedNet = trainNetwork(datasource,layers,options);
        close(findall(groot,'Tag','NNET_CNN_TRAININGPLOT_UIFIGURE'))

Оцените обучивший сеть на наборе валидации, вычислите предсказанные метки изображения и вычислите коэффициент ошибок данных о валидации.

        YPredicted = classify(trainedNet,XValidation);
        valError = 1 - mean(YPredicted == YValidation);

Создайте имя файла, содержащее ошибку валидации, и сохраните сеть, ошибку валидации и опции обучения к диску. Целевая функция возвращает fileName как выходной аргумент и bayesopt возвращает все имена файлов в BayesObject.UserDataTrace. Дополнительный необходимый выходной аргумент cons задает ограничения среди переменных. Нет никаких переменных ограничений.

        fileName = num2str(valError) + ".mat";
        save(fileName,'trainedNet','valError','options')
        cons = [];
        
    end
end

convBlock функция создает блок numConvLayers сверточные слои, каждый с заданным filterSize и numFilters фильтры и каждый сопровождаемый слоем нормализации партии. и слоем ReLU.

function layers = convBlock(filterSize,numFilters,numConvLayers)
layers = [
    convolution2dLayer(filterSize,numFilters,'Padding','same')
    batchNormalizationLayer
    reluLayer];
layers = repmat(layers,numConvLayers,1);
end

Ссылки

[1] Krizhevsky, Алекс. "Изучая несколько слоев функций от крошечных изображений". (2009). https://www.cs.toronto.edu / ~ kriz/learning-features-2009-TR.pdf

Смотрите также

| | | (Statistics and Machine Learning Toolbox)

Похожие темы