Глубокое обучение с использованием байесовской оптимизации

Этот пример показывает, как применить байесовскую оптимизацию к глубокому обучению и найти оптимальные сетевые гиперпараметры и опции обучения для сверточных нейронных сетей.

Для обучения глубокой нейронной сети необходимо указать архитектуру нейронной сети, а также опции алгоритма настройки. Выбор и настройка этих гиперпараметров могут оказаться трудными и потребовать времени. Байесовская оптимизация является алгоритмом, хорошо подходящим для оптимизации гиперпараметров классификационных и регрессионых моделей. Можно использовать оптимизацию Байеса, чтобы оптимизировать функции, которые являются недифференцируемыми, прерывистыми и длительными для оценки. Алгоритм внутренне поддерживает Гауссову модель процесса целевой функции и использует вычисления целевой функции, чтобы обучить эту модель.

В этом примере показано, как:

  • Загрузите и подготовьте CIFAR-10 набор данных для сетевого обучения. Этот набор данных является одним из наиболее широко используемых наборов данных для проверки моделей классификации изображений.

  • Задайте переменные для оптимизации с помощью байесовской оптимизации. Эти переменные являются опциями алгоритма настройки, а также параметрами самой сетевой архитектуры.

  • Определите целевую функцию, которая принимает значения переменных оптимизации в качестве входов, задает архитектуру сети и опции обучения, обучает и проверяет сеть, и сохраняет обученную сеть на диске. Целевая функция определяется в конце этого скрипта.

  • Выполните байесовскую оптимизацию путем минимизации ошибки классификации на наборе валидации.

  • Загрузите лучшую сеть с диска и оцените ее на тестовом наборе.

В качестве альтернативы можно использовать байесовскую оптимизацию, чтобы найти оптимальные опции обучения в Experiment Manager. Для получения дополнительной информации см. «Настройка гиперпараметров эксперимента при помощи байесовской оптимизации».

Подготовка данных

Загрузите CIFAR-10 набор данных [1]. Этот набор данных содержит 60 000 изображений, и каждое изображение имеет размер 32 на 32 и три цветовых канала (RGB). Размер всего набора данных составляет 175 МБ. В зависимости от вашего подключения к Интернету процесс загрузки может занять некоторое время.

datadir = tempdir;
downloadCIFARData(datadir);

Загрузите набор CIFAR-10 данных в виде обучающих изображений и меток, а также тестовых изображений и меток. Чтобы включить валидацию сети, используйте для валидации 5000 тестовых изображений.

[XTrain,YTrain,XTest,YTest] = loadCIFARData(datadir);

idx = randperm(numel(YTest),5000);
XValidation = XTest(:,:,:,idx);
XTest(:,:,:,idx) = [];
YValidation = YTest(idx);
YTest(idx) = [];

Отобразить выборку обучающих изображений можно используя следующий код.

figure;
idx = randperm(numel(YTrain),20);
for i = 1:numel(idx)
    subplot(4,5,i);
    imshow(XTrain(:,:,:,idx(i)));
end

Выберите переменные для оптимизации

Выберите переменные для оптимизации с помощью байесовской оптимизации и укажите области значений для поиска. Кроме того, задайте, являются ли переменные целыми числами и нужно ли искать интервал в логарифмическом пространстве. Оптимизируйте следующие переменные:

  • Глубина разреза сети. Этот параметр управляет глубиной сети. Сеть имеет три участка, каждый с SectionDepth идентичные сверточные слои. Таким образом, общее количество сверточных слоев 3*SectionDepth. Целевая функция позже в скрипте принимает количество сверточных фильтров на каждом слое, пропорциональное 1/sqrt(SectionDepth). В результате количество параметров и необходимый объем расчетов для каждой итерации примерно одинаковы для различных глубин сечения.

  • Начальная скорость обучения. Лучшая скорость обучения может зависеть от ваших данных, а также от сети, которую вы обучаете.

  • Импульс стохастического градиентного спуска. Импульс добавляет инерцию обновлениям параметров, имея текущее обновление, содержащее вклад, пропорциональный обновлению в предыдущей итерации. Это приводит к более плавным обновлениям параметров и уменьшению шума, присущего стохастическому градиентному спуску.

  • L2 регуляризационную прочность. Используйте регуляризацию, чтобы предотвратить сверхподбор кривой. Поиск пространства силы регуляризации, чтобы найти хорошее значение. Увеличение количества данных и нормализация партии . также помогают упорядочить сеть.

optimVars = [
    optimizableVariable('SectionDepth',[1 3],'Type','integer')
    optimizableVariable('InitialLearnRate',[1e-2 1],'Transform','log')
    optimizableVariable('Momentum',[0.8 0.98])
    optimizableVariable('L2Regularization',[1e-10 1e-2],'Transform','log')];

Выполните байесову оптимизацию

Создайте целевую функцию для байесовского оптимизатора, используя данные обучения и валидации в качестве входов. Целевая функция обучает сверточную нейронную сеть и возвращает ошибку классификации на наборе валидации. Эта функция определяется в конце этого скрипта. Потому что bayesopt использует частоту ошибок на наборе валидации, чтобы выбрать лучшую модель, возможно, что конечная сеть перегружается на наборе валидации. Окончательная выбранная модель затем проверяется на независимом наборе тестов, чтобы оценить ошибку обобщения.

ObjFcn = makeObjFcn(XTrain,YTrain,XValidation,YValidation);

Выполните байесовскую оптимизацию путем минимизации ошибки классификации на наборе валидации. Задайте общее время оптимизации в секундах. Чтобы наилучшим образом использовать степень байесовской оптимизации, вы должны выполнить не менее 30 вычислений целевой функции. Чтобы обучать сети параллельно на нескольких графических процессорах, установите 'UseParallel' значение в true. Если у вас есть один графический процессор и вы устанавливаете 'UseParallel' значение в true, затем все рабочие разделяют этот графический процессор, и вы не получаете никакой скорости обучения и увеличиваете шансы на то, что у GPU закончится память.

После того, как каждая сеть закончит обучение, bayesopt печать результатов в командное окно. The bayesopt затем функция возвращает имена файлов в BayesObject.UserDataTrace. Целевая функция сохраняет обученные сети на диск и возвращает имена файлов в bayesopt.

BayesObject = bayesopt(ObjFcn,optimVars, ...
    'MaxTime',14*60*60, ...
    'IsObjectiveDeterministic',false, ...
    'UseParallel',false);
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | SectionDepth | InitialLearn-|     Momentum | L2Regulariza-|
|      | result |             | runtime     | (observed)  | (estim.)    |              | Rate         |              | tion         |
|===================================================================================================================================|
|    1 | Best   |       0.197 |      955.69 |       0.197 |       0.197 |            3 |      0.61856 |      0.80624 |   0.00035179 |
|    2 | Best   |      0.1918 |      790.38 |      0.1918 |     0.19293 |            2 |     0.074118 |      0.91031 |   2.7229e-09 |
|    3 | Accept |      0.2438 |      660.29 |      0.1918 |     0.19344 |            1 |     0.051153 |      0.90911 |   0.00043113 |
|    4 | Accept |       0.208 |      672.81 |      0.1918 |      0.1918 |            1 |      0.70138 |      0.81923 |   3.7783e-08 |
|    5 | Best   |      0.1792 |      844.07 |      0.1792 |     0.17921 |            2 |      0.65156 |      0.93783 |   3.3663e-10 |
|    6 | Best   |      0.1776 |      851.49 |      0.1776 |     0.17759 |            2 |      0.23619 |      0.91932 |   1.0007e-10 |
|    7 | Accept |      0.2232 |       883.5 |      0.1776 |     0.17759 |            2 |     0.011147 |      0.91526 |    0.0099842 |
|    8 | Accept |      0.2508 |      822.65 |      0.1776 |     0.17762 |            1 |     0.023919 |      0.91048 |   1.0002e-10 |
|    9 | Accept |      0.1974 |      1947.6 |      0.1776 |     0.17761 |            3 |     0.010017 |      0.97683 |   5.4603e-10 |
|   10 | Best   |       0.176 |      1938.4 |       0.176 |     0.17608 |            2 |       0.3526 |      0.82381 |   1.4244e-07 |
|   11 | Accept |      0.1914 |      2874.4 |       0.176 |     0.17608 |            3 |     0.079847 |      0.86801 |   9.7335e-07 |
|   12 | Accept |       0.181 |        2578 |       0.176 |     0.17809 |            2 |      0.35141 |      0.80202 |   4.5634e-08 |
|   13 | Accept |      0.1838 |      2410.8 |       0.176 |     0.17946 |            2 |      0.39508 |      0.95968 |   9.3856e-06 |
|   14 | Accept |      0.1786 |      2490.6 |       0.176 |     0.17737 |            2 |      0.44857 |      0.91827 |   1.0939e-10 |
|   15 | Accept |      0.1776 |        2668 |       0.176 |     0.17751 |            2 |      0.95793 |      0.85503 |   1.0222e-05 |
|   16 | Accept |      0.1824 |      3059.8 |       0.176 |     0.17812 |            2 |      0.41142 |      0.86931 |    1.447e-06 |
|   17 | Accept |      0.1894 |      3091.5 |       0.176 |     0.17982 |            2 |      0.97051 |      0.80284 |   1.5836e-10 |
|   18 | Accept |       0.217 |      2794.5 |       0.176 |     0.17989 |            1 |       0.2464 |      0.84428 |   4.4938e-06 |
|   19 | Accept |      0.2358 |      4054.2 |       0.176 |     0.17601 |            3 |      0.22843 |       0.9454 |   0.00098248 |
|   20 | Accept |      0.2216 |      4411.7 |       0.176 |     0.17601 |            3 |     0.010847 |      0.82288 |   2.4756e-08 |
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | SectionDepth | InitialLearn-|     Momentum | L2Regulariza-|
|      | result |             | runtime     | (observed)  | (estim.)    |              | Rate         |              | tion         |
|===================================================================================================================================|
|   21 | Accept |      0.2038 |      3906.4 |       0.176 |     0.17601 |            2 |      0.09885 |      0.81541 |    0.0021184 |
|   22 | Accept |      0.2492 |      4103.4 |       0.176 |     0.17601 |            2 |      0.52313 |      0.83139 |    0.0016269 |
|   23 | Accept |      0.1814 |      4240.5 |       0.176 |     0.17601 |            2 |      0.29506 |      0.84061 |   6.0203e-10 |

__________________________________________________________
Optimization completed.
MaxTime of 50400 seconds reached.
Total function evaluations: 23
Total elapsed time: 53088.5123 seconds
Total objective function evaluation time: 53050.7026

Best observed feasible point:
    SectionDepth    InitialLearnRate    Momentum    L2Regularization
    ____________    ________________    ________    ________________

         2               0.3526         0.82381        1.4244e-07   

Observed objective function value = 0.176
Estimated objective function value = 0.17601
Function evaluation time = 1938.4483

Best estimated feasible point (according to models):
    SectionDepth    InitialLearnRate    Momentum    L2Regularization
    ____________    ________________    ________    ________________

         2               0.3526         0.82381        1.4244e-07   

Estimated objective function value = 0.17601
Estimated function evaluation time = 1898.2641

Оценка конечной сети

Загрузите лучшую сеть, найденную в оптимизации и ее точности валидации.

bestIdx = BayesObject.IndexOfMinimumTrace(end);
fileName = BayesObject.UserDataTrace{bestIdx};
savedStruct = load(fileName);
valError = savedStruct.valError
valError = 0.1760

Спрогнозируйте метки тестового набора и вычислите ошибку тестирования. Относитесь к классификации каждого изображения в тестовом наборе как к независимым событиям с определенной вероятностью успеха, что означает, что количество неправильно классифицированных изображений следует биномиальному распределению. Используйте это для вычисления стандартной ошибки (testErrorSE) и приблизительно 95% доверительный интервал (testError95CI) частоты ошибок обобщения. Этот метод часто называют методом Вальда. bayesopt определяет лучшую сеть, используя набор валидации, не подвергая сеть тестовому набору. Тогда возможно, что ошибка тестирования выше, чем ошибка валидации.

[YPredicted,probs] = classify(savedStruct.trainedNet,XTest);
testError = 1 - mean(YPredicted == YTest)
testError = 0.1910
NTest = numel(YTest);
testErrorSE = sqrt(testError*(1-testError)/NTest);
testError95CI = [testError - 1.96*testErrorSE, testError + 1.96*testErrorSE]
testError95CI = 1×2

    0.1801    0.2019

Постройте график матрицы неточностей для тестовых данных. Отображение точности и отзыва для каждого класса с помощью сводных данных по столбцам и строкам.

figure('Units','normalized','Position',[0.2 0.2 0.4 0.4]);
cm = confusionchart(YTest,YPredicted);
cm.Title = 'Confusion Matrix for Test Data';
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

Вы можете отобразить некоторые тестовые изображения вместе с их предсказанными классами и вероятностями этих классов, используя следующий код.

figure
idx = randperm(numel(YTest),9);
for i = 1:numel(idx)
    subplot(3,3,i)
    imshow(XTest(:,:,:,idx(i)));
    prob = num2str(100*max(probs(idx(i),:)),3);
    predClass = char(YPredicted(idx(i)));
    label = [predClass,', ',prob,'%'];
    title(label)
end

Целевая функция для оптимизации

Определите целевую функцию для оптимизации. Эта функция выполняет следующие шаги:

  1. Принимает значения переменных оптимизации как входы. bayesopt вызывает целевую функцию с текущими значениями переменных оптимизации в таблице с именем каждого столбца, равным имени переменной. Например, текущее значение глубины сечения сети optVars.SectionDepth.

  2. Определяет сетевую архитектуру и опции обучения.

  3. Обучает и проверяет сеть.

  4. Сохраняет обученную сеть, ошибку валидации и опции обучения на диске.

  5. Возвращает валидацию ошибку и имя файла сохраненной сети.

function ObjFcn = makeObjFcn(XTrain,YTrain,XValidation,YValidation)
ObjFcn = @valErrorFun;
    function [valError,cons,fileName] = valErrorFun(optVars)

Определите архитектуру сверточной нейронной сети.

  • Добавьте заполнение к сверточным слоям так, чтобы пространственный выходной размер всегда совпадал с размером входа.

  • Каждый раз, когда вы понижаете-дискретизируете пространственные размерности в два раза с использованием максимальных слоев объединения, увеличивайте количество фильтров в два раза. Это гарантирует, что количество расчетов, требуемых в каждом сверточном слое, примерно одинаково.

  • Выберите количество фильтров, пропорциональных 1/sqrt(SectionDepth), так что сети различной глубины имеют примерно одинаковое количество параметров и требуют примерно одинакового количества расчета на итерацию. Чтобы увеличить количество параметров сети и общую гибкость сети, увеличьте numF. Чтобы обучить еще более глубокие сети, измените область значений SectionDepth переменная.

  • Использование convBlock(filterSize,numFilters,numConvLayers) чтобы создать блок numConvLayers сверточные слои, каждый с заданным filterSize и numFilters фильтры, каждый из которых сопровождается слоем нормализации партии . и слоем ReLU. The convBlock функция определяется в конце этого примера.

        imageSize = [32 32 3];
        numClasses = numel(unique(YTrain));
        numF = round(16/sqrt(optVars.SectionDepth));
        layers = [
            imageInputLayer(imageSize)
            
            % The spatial input and output sizes of these convolutional
            % layers are 32-by-32, and the following max pooling layer
            % reduces this to 16-by-16.
            convBlock(3,numF,optVars.SectionDepth)
            maxPooling2dLayer(3,'Stride',2,'Padding','same')
            
            % The spatial input and output sizes of these convolutional
            % layers are 16-by-16, and the following max pooling layer
            % reduces this to 8-by-8.
            convBlock(3,2*numF,optVars.SectionDepth)
            maxPooling2dLayer(3,'Stride',2,'Padding','same')
            
            % The spatial input and output sizes of these convolutional
            % layers are 8-by-8. The global average pooling layer averages
            % over the 8-by-8 inputs, giving an output of size
            % 1-by-1-by-4*initialNumFilters. With a global average
            % pooling layer, the final classification output is only
            % sensitive to the total amount of each feature present in the
            % input image, but insensitive to the spatial positions of the
            % features.
            convBlock(3,4*numF,optVars.SectionDepth)
            averagePooling2dLayer(8)
            
            % Add the fully connected layer and the final softmax and
            % classification layers.
            fullyConnectedLayer(numClasses)
            softmaxLayer
            classificationLayer];

Укажите опции сетевого обучения. Оптимизируйте начальную скорость обучения, импульс SGD и L2 силу регуляризации.

Укажите данные валидации и выберите 'ValidationFrequency' значение, такое что trainNetwork проверяет сеть один раз в эпоху. Train для фиксированного числа эпох и опустите скорость обучения в 10 раз в течение последних эпох. Это уменьшает шум обновлений параметров и позволяет сетевым параметрам остановиться ближе к минимуму функции потерь.

        miniBatchSize = 256;
        validationFrequency = floor(numel(YTrain)/miniBatchSize);
        options = trainingOptions('sgdm', ...
            'InitialLearnRate',optVars.InitialLearnRate, ...
            'Momentum',optVars.Momentum, ...
            'MaxEpochs',60, ...
            'LearnRateSchedule','piecewise', ...
            'LearnRateDropPeriod',40, ...
            'LearnRateDropFactor',0.1, ...
            'MiniBatchSize',miniBatchSize, ...
            'L2Regularization',optVars.L2Regularization, ...
            'Shuffle','every-epoch', ...
            'Verbose',false, ...
            'Plots','training-progress', ...
            'ValidationData',{XValidation,YValidation}, ...
            'ValidationFrequency',validationFrequency);

Используйте увеличение данных, чтобы случайным образом развернуть обучающие изображения вдоль вертикальной оси и случайным образом переместить их до четырех пикселей горизонтально и вертикально. Увеличение количества данных помогает предотвратить сверхподбор кривой сети и запоминание точных деталей обучающих изображений.

        pixelRange = [-4 4];
        imageAugmenter = imageDataAugmenter( ...
            'RandXReflection',true, ...
            'RandXTranslation',pixelRange, ...
            'RandYTranslation',pixelRange);
        datasource = augmentedImageDatastore(imageSize,XTrain,YTrain,'DataAugmentation',imageAugmenter);

Обучите сеть и постройте график процесса обучения во время обучения. Закройте все обучающие графики после окончания обучения.

        trainedNet = trainNetwork(datasource,layers,options);
        close(findall(groot,'Tag','NNET_CNN_TRAININGPLOT_UIFIGURE'))

Оцените обученную сеть на наборе валидации, вычислите предсказанные метки изображений и вычислите вероятность ошибок на данных валидации.

        YPredicted = classify(trainedNet,XValidation);
        valError = 1 - mean(YPredicted == YValidation);

Создайте имя файла, содержащее ошибку валидации, и сохраните сеть, ошибку валидации и опции обучения на диске. Целевая функция возвращается fileName в качестве выходного аргумента и bayesopt возвращает все имена файлов в BayesObject.UserDataTrace. Дополнительный необходимый выходной аргумент cons задает ограничения среди переменных. Переменных ограничений нет.

        fileName = num2str(valError) + ".mat";
        save(fileName,'trainedNet','valError','options')
        cons = [];
        
    end
end

The convBlock функция создает блок numConvLayers сверточные слои, каждый с заданным filterSize и numFilters фильтры, каждый из которых сопровождается слоем нормализации партии . и слоем ReLU.

function layers = convBlock(filterSize,numFilters,numConvLayers)
layers = [
    convolution2dLayer(filterSize,numFilters,'Padding','same')
    batchNormalizationLayer
    reluLayer];
layers = repmat(layers,numConvLayers,1);
end

Ссылки

[1] Крижевский, Алекс. «Изучение нескольких слоев функций из крошечных изображений». (2009). https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf

См. также

| | | (Statistics and Machine Learning Toolbox)

Похожие темы