Используйте parfeval, чтобы обучить несколько нейронных сетей для глубокого обучения

Этот пример показывает, как использовать parfeval, чтобы выполнить развертку параметра на глубине сетевой архитектуры для нейронной сети для глубокого обучения и получить данные во время обучения.

Обучение глубокому обучению часто занимает часы или дни, и поиск хорошей архитектуры может быть трудным. С параллельными вычислениями можно убыстриться и автоматизировать поиск хороших моделей. Если у вас есть доступ к машине с несколькими графическими блоками обработки (графические процессоры), можно завершить этот пример на локальной копии набора данных с локальным параллельным пулом. Если вы хотите использовать больше ресурсов, можно увеличить обучение глубокому обучению к облаку. Этот пример показывает, как использовать parfeval, чтобы выполнить развертку параметра на глубине сетевой архитектуры в кластере в облаке. Используя parfeval позволяет вам обучаться в фоновом режиме, не блокируя MATLAB и предоставляет возможности останавливаться рано, если результаты являются удовлетворительными. Можно изменить скрипт, чтобы сделать развертку параметра на любом другом параметре. Кроме того, этот пример показывает, как получить обратную связь от рабочих во время вычисления при помощи DataQueue.

Требования

Прежде чем можно будет запустить этот пример, необходимо сконфигурировать кластер и загрузить данные на Облако. В MATLAB можно создать кластеры в облаке непосредственно с Рабочего стола MATLAB. На вкладке Home, в меню Parallel, выбирают Create и Manage Clusters. В Кластерном менеджере по Профилю нажмите Create Cloud Cluster. Также можно использовать MathWorks Cloud Center, чтобы создать и получить доступ, вычисляют кластеры. Для получения дополнительной информации смотрите Начало работы с Центром Облака. В данном примере гарантируйте, что ваш кластер установлен по умолчанию на вкладке MATLAB Home, параллельно> Выбирают Default Cluster. После этого загрузите свои данные на блок Amazon S3 и используйте их непосредственно из MATLAB. Этот пример использует копию набора данных CIFAR-10, который уже хранится в Amazon S3. Для инструкций смотрите Данные о Глубоком обучении Загрузки к Облаку (Deep Learning Toolbox).

Загрузите набор данных от облака

Загрузите обучение и наборы тестовых данных от облака с помощью imageDatastore. Разделите обучающий набор данных в наборы обучения и валидации и сохраните набор тестовых данных, чтобы протестировать лучшую сеть от развертки параметра. В этом примере вы используете копию набора данных CIFAR-10, сохраненного в Amazon S3. Чтобы гарантировать, что у рабочих есть доступ к datastore в облаке, убедитесь, что переменные окружения для учетных данных AWS установлены правильно. Смотрите Данные о Глубоком обучении Загрузки к Облаку (Deep Learning Toolbox).

imds = imageDatastore('s3://cifar10cloud/cifar10/train', ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

imdsTest = imageDatastore('s3://cifar10cloud/cifar10/test', ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9);

Обучите сеть с увеличенными данными изображения путем создания объекта augmentedImageDatastore. Используйте случайные переводы и горизонтальные отражения. Увеличение данных помогает препятствовать тому, чтобы сеть сверхсоответствовала и запомнила точные детали учебных изображений.

imageSize = [32 32 3];
pixelRange = [-4 4];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);
augmentedImdsTrain = augmentedImageDatastore(imageSize,imdsTrain, ...
    'DataAugmentation',imageAugmenter, ...
    'OutputSizeMode','randcrop');

Обучите несколько сетей одновременно

Задайте опции обучения. Установите мини-пакетный размер и масштабируйте начальный темп обучения линейно согласно мини-пакетному размеру. Установите частоту валидации так, чтобы trainNetwork подтвердил сеть однажды в эпоху.

miniBatchSize = 128;
initialLearnRate = 1e-1 * miniBatchSize/256;
validationFrequency = floor(numel(imdsTrain.Labels)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ... % Set the mini-batch size
    'Verbose',false, ... % Do not send command line output.
    'InitialLearnRate',initialLearnRate, ... % Set the scaled learning rate.
    'L2Regularization',1e-10, ...
    'MaxEpochs',30, ...
    'Shuffle','every-epoch', ...
    'ValidationData',imdsValidation, ...
    'ValidationFrequency', validationFrequency);

Задайте глубины для сетевой архитектуры, на которой можно сделать развертку параметра. Выполните параллельную развертку параметра, обучающую несколько сетей одновременно с помощью parfeval. Используйте цикл, чтобы выполнить итерации через различную сетевую архитектуру в развертке. Создайте функцию помощника createNetworkArchitecture в конце скрипта, который берет входной параметр, чтобы управлять глубиной сети и создает архитектуру для CIFAR-10. Используйте parfeval, чтобы разгрузить вычисления, выполняемые trainNetwork рабочему в кластере. parfeval возвращает будущую переменную, чтобы содержать обучившую нейронные сети и учебную информацию, когда вычисления сделаны.

netDepths = 1:4;
for idx = 1:numel(netDepths)
    networksFuture(idx) = parfeval(@trainNetwork,2, ...
        augmentedImdsTrain,createNetworkArchitecture(netDepths(idx)),options);
end
Starting parallel pool (parpool) using the 'MyCluster' profile ...
Connected to the parallel pool (number of workers: 4).

parfeval не делает блока MATLAB, что означает, что можно продолжить выполнять команды. В этом случае получите обучивший нейронные сети и их учебную информацию при помощи fetchOutputs на networksFuture. Функция fetchOutputs ожидает, пока будущие переменные не заканчиваются.

[trainedNetworks,trainingInfo] = fetchOutputs(networksFuture);

Получите итоговую точность валидации конкретной сети путем доступа к структуре trainingInfo. Например, получите точность первой сети.

accuracy = trainingInfo(1).ValidationAccuracy(end)
accuracy = 72.7600

Чтобы получить всю итоговую точность валидации, используйте cellfun.

accuracies = cellfun(@(x) x(end),{trainingInfo.ValidationAccuracy})
accuracies = 1×4

   72.7600   77.7000   77.5000   76.1200

Выберите лучшую сеть с точки зрения точности. Проверьте его производительность против набора тестовых данных.

[~, I] = max(accuracies);
bestNetwork = trainedNetworks(I(1));
YPredicted = classify(bestNetwork,imdsTest);
accuracy = sum(YPredicted == imdsTest.Labels)/numel(imdsTest.Labels)
accuracy = 0.7732

Вычислите матрицу беспорядка для тестовых данных.

figure('Units','normalized','Position',[0.2 0.2 0.4 0.4]);
confusionchart(imdsTest.Labels,YPredicted,'RowSummary','row-normalized','ColumnSummary','column-normalized');

Отправьте данные об отклике во время обучения

Подготовьте и инициализируйте графики, которые показывают учебный прогресс каждого из рабочих. Используйте animatedLine для удобного способа показать изменение данных.

f = figure;
f.Visible = true;
for i=1:4
    subplot(2,2,i)
    xlabel('Iteration');
    ylabel('Training accuracy');
    lines(i) = animatedline;
end

Отправьте учебные данные о прогрессе от рабочих клиенту при помощи DataQueue, и затем отобразите данные на графике. Обновите графики каждый раз, когда рабочие отправляют учебный отклик прогресса при помощи afterEach. Параметр opts содержит информацию о рабочем, учебной итерации и учебной точности.

D = parallel.pool.DataQueue;
afterEach(D, @(opts) updatePlot(lines, opts{:}));

Задайте глубины для сетевой архитектуры, на которой можно сделать развертку параметра и выполнить параллельную развертку параметра с помощью parfeval. Позвольте рабочим получать доступ к любой функции помощника в этом скрипте путем добавления скрипта в текущий пул как прикрепленный файл. Задайте выходную функцию в опциях обучения, чтобы отправить учебный прогресс от рабочих клиенту. Опции обучения зависят от индекса рабочего и должны быть включены в цикле for.

netDepths = 1:4;
addAttachedFiles(gcp,mfilename);
for idx = 1:numel(netDepths)
    
    miniBatchSize = 128;
    initialLearnRate = 1e-1 * miniBatchSize/256; % Scale the learning rate according to the mini-batch size.
    validationFrequency = floor(numel(imdsTrain.Labels)/miniBatchSize);
    
    options = trainingOptions('sgdm', ...
        'OutputFcn',@(state) sendTrainingProgress(D,idx,state), ... % Set an output function to send intermediate results to the client.
        'MiniBatchSize',miniBatchSize, ... % Set the corresponding MiniBatchSize in the sweep.
        'Verbose',false, ... % Do not send command line output.
        'InitialLearnRate',initialLearnRate, ... % Set the scaled learning rate.
        'L2Regularization',1e-10, ...
        'MaxEpochs',30, ...
        'Shuffle','every-epoch', ...
        'ValidationData',imdsValidation, ...
        'ValidationFrequency', validationFrequency);
    
    networksFuture(idx) = parfeval(@trainNetwork,2, ...
        augmentedImdsTrain,createNetworkArchitecture(netDepths(idx)),options);
end

parfeval вызывает trainNetwork на рабочего в кластере. Вычисления происходят на фоне, таким образом, можно продолжить работать в MATLAB. Если вы хотите остановить вычисление parfeval, можно вызвать cancel на его соответствующей будущей переменной. Например, если вы замечаете, что сеть показывает низкие результаты, можно отменить ее будущее. Когда вы делаете так, следующая будущая переменная с очередями запускает свои вычисления.

В этом случае выберите обучивший нейронные сети и их учебную информацию путем вызова fetchOutputs на будущие переменные.

[trainedNetworks,trainingInfo] = fetchOutputs(networksFuture);

Получите итоговую точность валидации для каждой сети при помощи cellfun.

accuracies = cellfun(@(x) x(end),{trainingInfo.ValidationAccuracy})
accuracies = 1×4

   72.9200   77.4800   76.9200   77.0400

Функции помощника

Задайте сетевую архитектуру для набора данных CIFAR-10 с функцией и используйте входной параметр, чтобы настроить глубину сети. Чтобы упростить код, используйте сверточные блоки, которые применяют операцию свертки к входу. Слои объединения субдискретизируют пространственные размерности.

function layers = createNetworkArchitecture(netDepth)
imageSize = [32 32 3];
netWidth = round(16/sqrt(netDepth)); % netWidth controls the number of filters in a convolutional block

layers = [
    imageInputLayer(imageSize)
    
    convolutionalBlock(netWidth,netDepth)
    maxPooling2dLayer(2,'Stride',2)
    convolutionalBlock(2*netWidth,netDepth)
    maxPooling2dLayer(2,'Stride',2)
    convolutionalBlock(4*netWidth,netDepth)
    averagePooling2dLayer(8)
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer
    ];
end

Задайте функцию, чтобы создать сверточный блок в сетевой архитектуре.

function layers = convolutionalBlock(numFilters,numConvLayers)
layers = [
    convolution2dLayer(3,numFilters,'Padding','same')
    batchNormalizationLayer
    reluLayer
    ];

layers = repmat(layers,numConvLayers,1);
end

Задайте функцию, чтобы отправить учебный прогресс клиенту через DataQueue.

function sendTrainingProgress(D,idx,info)
if info.State == "iteration"
    send(D,{idx,info.Iteration,info.TrainingAccuracy});
end
end

Задайте функцию обновления, чтобы обновить графики, когда рабочий отправит промежуточный результат.

function updatePlot(lines,idx,iter,acc)
addpoints(lines(idx),iter,acc);
drawnow limitrate nocallbacks
end

Смотрите также

| | | |

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте