Использование parfeval Обучение нескольких Нейронных сетей для глубокого обучения

В этом примере показано, как использовать parfeval выполнить зачистку параметра на глубине сетевой архитектуры для нейронной сети для глубокого обучения и извлечь данные во время обучения.

Обучение глубокому обучению часто занимает часы или дни, и поиск хорошей архитектуры может оказаться трудным. С помощью параллельных вычислений можно ускорить и автоматизировать поиск хороших моделей. Если у вас есть доступ к машине с несколькими графическими модулями (GPU), можно завершить этот пример на локальной копии набора данных с локальным параллельным пулом. Если вы хотите использовать больше ресурсов, можно расширить процесс глубокого обучения до облака. В этом примере показано, как использовать parfeval для выполнения протягивания параметра на глубине сетевой архитектуры в кластере в облаке. Использование parfeval позволяет вам обучаться в фоновом режиме, не блокируя MATLAB, и предоставляет опции, чтобы остановить раньше, если результаты удовлетворительны. Можно изменить скрипт, чтобы выполнить сдвиг параметра по любому другому параметру. Кроме того, этот пример показывает, как получить обратную связь от работников во время расчетов с помощью DataQueue.

Требования

Прежде чем вы сможете запустить этот пример, вам нужно сконфигурировать кластер и загрузить свои данные в облако. В MATLAB можно создавать кластеры в облаке непосредственно с рабочего стола MATLAB. На вкладке «Вкладке Home», в меню Parallel, выберите Create and Manage Clusters. В Диспетчере профилей кластеров щелкните Создать облако. Также можно использовать MathWorks Cloud Center для создания и доступа к вычислительным кластерам. Дополнительные сведения см. в разделе Начало работы с облачным центром. В данном примере убедитесь, что кластер установлен по умолчанию на вкладке MATLAB Home, в Parallel > Select a Default Cluster. После этого загрузите свои данные в блок S3 Amazon и используйте их непосредственно из MATLAB. Этот пример использует копию CIFAR-10 набора данных, который уже хранится в Amazon S3. Инструкции см. в разделе Загрузка данных глубокого обучения в облако (Deep Learning Toolbox).

Загрузка набора данных из облака

Загрузите наборы обучающих и тестовых данных из облака с помощью imageDatastore. Разделите набор обучающих данных на наборы обучения и валидации и сохраните набор тестовых данных, чтобы протестировать лучшую сеть от сдвига параметра. В этом примере вы используете копию CIFAR-10 набора данных, хранящегося в Amazon S3. Чтобы убедиться, что работники имеют доступ к datastore в облаке, убедитесь, что переменные окружения для учетных данных AWS заданы правильно. Смотрите Загрузку данных глубокого обучения в облако (Deep Learning Toolbox).

imds = imageDatastore('s3://cifar10cloud/cifar10/train', ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

imdsTest = imageDatastore('s3://cifar10cloud/cifar10/test', ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9);

Обучите сеть с дополненными данными об изображениях путем создания augmentedImageDatastore объект. Используйте случайные переводы и горизонтальные отражения. Увеличение количества данных помогает предотвратить сверхподбор кривой сети и запоминание точных деталей обучающих изображений.

imageSize = [32 32 3];
pixelRange = [-4 4];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);
augmentedImdsTrain = augmentedImageDatastore(imageSize,imdsTrain, ...
    'DataAugmentation',imageAugmenter, ...
    'OutputSizeMode','randcrop');

Обучите несколько сетей одновременно

Определите опции обучения. Установите размер мини-пакета и масштабируйте начальную скорость обучения линейно в соответствии с размером мини-пакета. Установите частоту валидации так, чтобы trainNetwork проверяет сеть один раз в эпоху.

miniBatchSize = 128;
initialLearnRate = 1e-1 * miniBatchSize/256;
validationFrequency = floor(numel(imdsTrain.Labels)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ... % Set the mini-batch size
    'Verbose',false, ... % Do not send command line output.
    'InitialLearnRate',initialLearnRate, ... % Set the scaled learning rate.
    'L2Regularization',1e-10, ...
    'MaxEpochs',30, ...
    'Shuffle','every-epoch', ...
    'ValidationData',imdsValidation, ...
    'ValidationFrequency', validationFrequency);

Задайте глубины для сетевой архитектуры, на которых необходимо выполнить сдвиг параметра. Выполните параллельное тестирование параметров нескольких сетей одновременно с помощью parfeval. Используйте цикл, чтобы выполнить итерацию через различные сетевые архитектуры в сдвиге. Создайте вспомогательную функцию createNetworkArchitecture в конце скрипта, который принимает входной параметр, чтобы контролировать глубину сети и создает архитектуру для CIFAR-10. Использование parfeval для разгрузки расчетов, выполненных trainNetwork работнику в кластере. parfeval возвращает будущую переменную для хранения обученных сетей и обучающей информации при выполнении расчетов.

netDepths = 1:4;
for idx = 1:numel(netDepths)
    networksFuture(idx) = parfeval(@trainNetwork,2, ...
        augmentedImdsTrain,createNetworkArchitecture(netDepths(idx)),options);
end
Starting parallel pool (parpool) using the 'MyCluster' profile ...
Connected to the parallel pool (number of workers: 4).

parfeval не блокирует MATLAB, что означает, что можно продолжить выполнение команд. В этом случае получайте обученные сети и их обучающую информацию при помощи fetchOutputs на networksFuture. The fetchOutputs функция ожидает, пока будущие переменные не закончат.

[trainedNetworks,trainingInfo] = fetchOutputs(networksFuture);

Получите окончательные точности валидации сетей путем доступа к trainingInfo структура.

accuracies = [trainingInfo.FinalValidationAccuracy]
accuracies = 1×4

   72.5600   77.2600   79.4000   78.6800

Выберите лучшую сеть с точки зрения точности. Проверяйте его эффективность на соответствие тестовых данных набору.

[~, I] = max(accuracies);
bestNetwork = trainedNetworks(I(1));
YPredicted = classify(bestNetwork,imdsTest);
accuracy = sum(YPredicted == imdsTest.Labels)/numel(imdsTest.Labels)
accuracy = 0.7840

Вычислите матрицу неточностей для тестовых данных.

figure('Units','normalized','Position',[0.2 0.2 0.4 0.4]);
confusionchart(imdsTest.Labels,YPredicted,'RowSummary','row-normalized','ColumnSummary','column-normalized');

Отправка данных обратной связи во время обучения

Подготовка и инициализация графиков, показывающих процесс обучения каждого из работников. Использование animatedLine для удобного способа показать меняющиеся данные.

f = figure;
f.Visible = true;
for i=1:4
    subplot(2,2,i)
    xlabel('Iteration');
    ylabel('Training accuracy');
    lines(i) = animatedline;
end

Отправка данных о процессе обучения от работников клиенту при помощи DataQueue, а затем постройте график данных. Обновляйте графики каждый раз, когда работники отправляют обратную связь о процессе обучения при помощи afterEach. Значение параметра opts содержит информацию о работнике, итерации обучения и точности обучения.

D = parallel.pool.DataQueue;
afterEach(D, @(opts) updatePlot(lines, opts{:}));

Задайте глубины для сетевой архитектуры, на которых можно выполнить сдвиг параметра, и выполните сдвиг параллельного параметра с помощью parfeval. Разрешить работникам доступ к любой вспомогательной функции в этом скрипте путем добавления скрипта к текущему пулу в качестве вложенного файла. Определите выходную функцию в опциях обучения, чтобы отправить процесс обучения от работников клиенту. Опции обучения зависят от индекса работника и должны быть включены в for цикл.

netDepths = 1:4;
addAttachedFiles(gcp,mfilename);
for idx = 1:numel(netDepths)
    
    miniBatchSize = 128;
    initialLearnRate = 1e-1 * miniBatchSize/256; % Scale the learning rate according to the mini-batch size.
    validationFrequency = floor(numel(imdsTrain.Labels)/miniBatchSize);
    
    options = trainingOptions('sgdm', ...
        'OutputFcn',@(state) sendTrainingProgress(D,idx,state), ... % Set an output function to send intermediate results to the client.
        'MiniBatchSize',miniBatchSize, ... % Set the corresponding MiniBatchSize in the sweep.
        'Verbose',false, ... % Do not send command line output.
        'InitialLearnRate',initialLearnRate, ... % Set the scaled learning rate.
        'L2Regularization',1e-10, ...
        'MaxEpochs',30, ...
        'Shuffle','every-epoch', ...
        'ValidationData',imdsValidation, ...
        'ValidationFrequency', validationFrequency);
    
    networksFuture(idx) = parfeval(@trainNetwork,2, ...
        augmentedImdsTrain,createNetworkArchitecture(netDepths(idx)),options);
end

parfeval вызывает trainNetwork на рабочем месте в кластере. Расчеты происходят на фоне, поэтому можно продолжить работу в MATLAB. Если вы хотите остановить parfeval расчеты, можно вызвать cancel о соответствующей переменной будущего. Для примера, если вы наблюдаете, что сеть не работает, можно отменить ее будущее. Когда вы это делаете, следующая поставленная в очередь будущая переменная начинает свои расчеты.

В этом случае извлеките обученные сети и их обучающую информацию путем вызова fetchOutputs о будущих переменных.

[trainedNetworks,trainingInfo] = fetchOutputs(networksFuture);

Получите окончательную точность валидации для каждой сети.

accuracies = [trainingInfo.FinalValidationAccuracy]
accuracies = 1×4

   72.9200   77.4800   76.9200   77.0400

Вспомогательные функции

Задайте сетевую архитектуру для CIFAR-10 набора данных с функцией и используйте входной параметр, чтобы настроить глубину сети. Чтобы упростить код, используйте сверточные блоки, которые свертывают вход. Слои объединения понижают пространственные размерности.

function layers = createNetworkArchitecture(netDepth)
imageSize = [32 32 3];
netWidth = round(16/sqrt(netDepth)); % netWidth controls the number of filters in a convolutional block

layers = [
    imageInputLayer(imageSize)
    
    convolutionalBlock(netWidth,netDepth)
    maxPooling2dLayer(2,'Stride',2)
    convolutionalBlock(2*netWidth,netDepth)
    maxPooling2dLayer(2,'Stride',2)
    convolutionalBlock(4*netWidth,netDepth)
    averagePooling2dLayer(8)
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer
    ];
end

Задайте функцию для создания сверточного блока в сетевой архитектуре.

function layers = convolutionalBlock(numFilters,numConvLayers)
layers = [
    convolution2dLayer(3,numFilters,'Padding','same')
    batchNormalizationLayer
    reluLayer
    ];

layers = repmat(layers,numConvLayers,1);
end

Определите функцию для отправки процесса обучения клиенту через DataQueue.

function sendTrainingProgress(D,idx,info)
if info.State == "iteration"
    send(D,{idx,info.Iteration,info.TrainingAccuracy});
end
end

Задайте функцию обновления, чтобы обновить графики, когда рабочий отправляет промежуточный результат.

function updatePlot(lines,idx,iter,acc)
addpoints(lines(idx),iter,acc);
drawnow limitrate nocallbacks
end

См. также

| | | (Deep Learning Toolbox) | (Deep Learning Toolbox)

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте