Обучите сеть параллельно с пользовательским учебным циклом

В этом примере показано, как настроить пользовательский учебный цикл, чтобы обучить сеть параллельно. В этом примере параллельные рабочие обучаются на фрагментах полного мини-пакета. Если у вас есть графический процессор, то обучение происходит на графическом процессоре. Во время обучения, DataQueue объект передает информацию о процессе обучения обратно клиенту MATLAB.

Загрузите набор данных

Загрузите набор данных цифры и создайте datastore изображений для набора данных. Разделите datastore в обучение и протестируйте хранилища данных рандомизированным способом.

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

[imdsTrain,imdsTest] = splitEachLabel(imds,0.9,"randomized");

Определите различные классы в наборе обучающих данных.

classes = categories(imdsTrain.Labels);
numClasses = numel(classes);

Сеть Define

Задайте свою сетевую архитектуру и превратите ее в график слоя при помощи layerGraph функция.

layers = [
    imageInputLayer([28 28 1],'Name','input','Normalization','none')
    convolution2dLayer(5,20,'Name','conv1')
    reluLayer('Name','relu1')
    convolution2dLayer(3,20,'Padding',1,'Name','conv2')
    reluLayer('Name','relu2')
    convolution2dLayer(3,20,'Padding',1,'Name','conv3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(numClasses,'Name','fc')];

lgraph = layerGraph(layers);

Создайте dlnetwork объект из графика слоя. dlnetwork объекты допускают обучение с пользовательскими циклами.

dlnet = dlnetwork(lgraph)
dlnet = 
  dlnetwork with properties:

         Layers: [8×1 nnet.cnn.layer.Layer]
    Connections: [7×2 table]
     Learnables: [8×3 table]
          State: [0×0 table]

Настройте параллельную среду

Определите, доступны ли графические процессоры для MATLAB, чтобы использовать с canUseGPU функция.

  • Если существуют доступные графические процессоры, то обучаются на графических процессорах. Создайте параллельный пул со столькими же рабочих сколько графические процессоры.

  • Если нет никаких доступных графических процессоров, то не обучаются на центральных процессорах. Создайте параллельный пул с количеством по умолчанию рабочих.

if canUseGPU
    executionEnvironment = "gpu";
    numberOfGPUs = gpuDeviceCount;
    pool = parpool(numberOfGPUs);
else
    executionEnvironment = "cpu";
    pool = parpool;
end
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 6).

Получите количество рабочих в параллельном пуле. Позже в этом примере, вы делите рабочую нагрузку согласно этому номеру.

N = pool.NumWorkers;

Чтобы передать данные обратно от рабочих во время обучения, создайте DataQueue объект. Используйте afterEach настраивать функцию, displayTrainingProgress, чтобы вызвать каждый раз, рабочий отправляет данные. displayTrainingProgress функция поддержки, заданная в конце этого примера, который отображает информацию о процессе обучения, которая прибывает от рабочих.

Q = parallel.pool.DataQueue;
afterEach(Q,@displayTrainingProgress);

Обучите модель

Задайте опции обучения.

numEpochs = 20;
miniBatchSize = 128;
velocity = [];

Для обучения графического процессора методические рекомендации должны увеличить мини-пакетный размер линейно с количеством графических процессоров, для того, чтобы сохранить рабочую нагрузку на каждом графическом процессоре постоянной. Для более связанного совета смотрите Обучение с помощью Нескольких графических процессоров.

if executionEnvironment == "gpu"
    miniBatchSize = miniBatchSize .* N
end

Вычислите мини-пакетный размер для каждого рабочего путем деления полного мини-пакетного размера равномерно между рабочими. Распределите остаток на первых рабочих.

workerMiniBatchSize = floor(miniBatchSize ./ repmat(N,1,N));
remainder = miniBatchSize - sum(workerMiniBatchSize);
workerMiniBatchSize = workerMiniBatchSize + [ones(1,remainder) zeros(1,N-remainder)]
workerMiniBatchSize = 1×6

    22    22    21    21    21    21

Обучите модель с помощью пользовательского параллельного учебного цикла, как детализировано в следующих шагах. Чтобы выполнить код одновременно на всех рабочих, используйте spmd блок. В spmd блок, labindex дает индекс выполняющегося в данного момента рабочего код.

Перед обучением разделите datastore для каждого рабочего при помощи partition функция и набор ReadSize к мини-пакетному размеру рабочего.

В течение каждой эпохи, сброса и перестановки datastore с reset и shuffle функции.

Для каждой итерации в эпоху:

  • Убедитесь, что у всех рабочих есть доступные данные прежде, чем начать обрабатывать его параллельно путем выполнения глобального and операция (gop) на результате hasdata функция.

  • Считайте мини-пакет из datastore при помощи read функция, и конкатенирует полученные изображения в четырехмерный массив изображений. Нормируйте изображения так, чтобы пиксели приняли значения между 0 и 1.

  • Преобразуйте метки в матрицу фиктивных переменных, которая помещает метки против наблюдений. Фиктивные переменные содержат 1 для метки наблюдения и 0 в противном случае.

  • Преобразуйте мини-пакет данных к dlarray объект с базовым одним типом и указывает, что размерность маркирует 'SSCB' (пространственный, пространственный, канал, пакет). Для обучения графического процессора преобразуйте данные в gpuArray.

  • Вычислите градиенты и потерю сети на каждом рабочем путем вызова dlfeval на modelGradients функция. dlfeval функция выполняет функцию помощника modelGradients с автоматическим включенным дифференцированием, таким образом, modelGradients может вычислить градиенты относительно потери автоматическим способом. modelGradients задан в конце примера и возвращает потерю и градиенты, учитывая сеть, мини-пакет данных и истинные метки.

  • Чтобы получить полную потерю, агрегируйте потери на всех рабочих. Этим примером использование пересекает энтропию для функции потерь и агрегированную потерю, является сумма всех потерь. Перед агрегацией нормируйте каждую потерю путем умножения пропорцией полного мини-пакета, что рабочий продолжает работать. Используйте gplus добавить все потери вместе и реплицировать результаты через рабочих.

  • Чтобы агрегировать и обновить градиенты всех рабочих, используйте dlupdate функция на aggregateGradients функция. aggregateGradients функция, определяемая поддержки в конце примера. Эта функция использует gplus добавить вместе и реплицировать градиенты через рабочих, после нормализации согласно пропорции полного мини-пакета, что каждый рабочий продолжает работать.

  • После вычисления итоговых градиентов обновите сетевые learnable параметры с sgdmupdate функция.

  • Передайте информацию о процессе обучения обратно клиенту при помощи send функция с DataQueue. Используйте только одного рабочего, чтобы отправить данные, потому что у всех рабочих есть та же информация о потере. Гарантировать, что данные находятся на центральном процессоре, так, чтобы клиентская машина без графического процессора могла получить доступ к нему, gather использования на dlarray прежде, чем отправить его.

spmd
    % Partition datastore.
    workerImds = partition(imdsTrain,N,labindex);
    workerImds.ReadSize = workerMiniBatchSize(labindex);
    
    workerVelocity = velocity;
   
    iteration = 0;
    
    for epoch = 1:numEpochs
        % Reset and shuffle the datastore.
        reset(workerImds);
        workerImds = shuffle(workerImds);
        
        % Loop over mini-batches.
        while gop(@and,hasdata(workerImds))
            iteration = iteration + 1;
            
            % Read a mini-batch of data.
            [workerXBatch,workerTBatch] = read(workerImds);
            workerXBatch = cat(4,workerXBatch{:});
            workerNumObservations = numel(workerTBatch.Label);

            % Normalize the images.
            workerXBatch =  single(workerXBatch) ./ 255;
            
            % Convert the labels to dummy variables.
            workerY = zeros(numClasses,workerNumObservations,'single');
            for c = 1:numClasses
                workerY(c,workerTBatch.Label==classes(c)) = 1;
            end
            
            % Convert the mini-batch of data to dlarray.
            dlworkerX = dlarray(workerXBatch,'SSCB');
            
            % If training on GPU, then convert data to gpuArray.
            if executionEnvironment == "gpu"
                dlworkerX = gpuArray(dlworkerX);
            end
            
            % Evaluate the model gradients and loss on the worker.
            [workerGradients,dlworkerLoss] = dlfeval(@modelGradients,dlnet,dlworkerX,workerY);
            
            % Aggregate the losses on all workers.
            workerNormalizationFactor = workerMiniBatchSize(labindex)./miniBatchSize;
            loss = gplus(workerNormalizationFactor*extractdata(dlworkerLoss));
            
            % Aggregate the gradients on all workers.
            workerGradients.Value = dlupdate(@aggregateGradients,workerGradients.Value,{workerNormalizationFactor});
            
            % Update the network parameters using the SGDM optimizer.
            [dlnet.Learnables,workerVelocity] = sgdmupdate(dlnet.Learnables,workerGradients,workerVelocity);
        end
        
       % Display training progress information.
       if labindex == 1
           data = [epoch loss];
           send(Q,gather(data)); 
       end
    end
end
Analyzing and transferring files to the workers ...done.

Тестовая модель

После того, как вы обучите сеть, можно протестировать ее точность.

Загрузите тестовые изображения в память при помощи readall на тестовом datastore конкатенируйте их и нормируйте их.

XTest = readall(imdsTest);
XTest = cat(4,XTest{:});
XTest = single(XTest) ./ 255;
YTest = imdsTest.Labels;

После того, как обучение завершено, у всех рабочих есть завершенное то же самое, обучил сеть. Получите любого из них.

dlnetFinal = dlnet{1};

Классифицировать изображения с помощью dlnetwork объект, используйте predict функция на dlarray.

dlYPredScores = predict(dlnetFinal,dlarray(XTest,'SSCB'));

От предсказанных баллов найдите класс с самым высоким счетом с max функция. Прежде чем вы сделаете это, извлеките данные из dlarray с extractdata функция.

[~,idx] = max(extractdata(dlYPredScores),[],1);
YPred = classes(idx);

Чтобы получить точность классификации модели, сравните прогнозы на наборе тестов против истинных меток.

accuracy = mean(YPred==YTest)
accuracy = 0.9970

Определение функций помощника

Задайте функцию, modelGradients, вычислить градиенты потери относительно learnable параметров сети. Эта функция вычисляет сетевые выходные параметры для мини-пакетного X с forward и softmax и вычисляет потерю, учитывая истинные выходные параметры, с помощью перекрестной энтропии. Когда вы вызываете эту функцию с dlfeval, автоматическое дифференцирование включено, и dlgradient может вычислить градиенты потери относительно learnables автоматически.

function [dlgradients,dlloss] = modelGradients(dlnet,dlX,dlY)
dlYPred = forward(dlnet,dlX);
dlYPred = softmax(dlYPred);

dlloss = crossentropy(dlYPred,dlY);
dlgradients = dlgradient(dlloss,dlnet.Learnables);
end

Задайте функцию, чтобы отобразить информацию о процессе обучения, которая прибывает от рабочих. DataQueue в этом примере вызывает эту функцию каждый раз, когда рабочий отправляет данные.

function displayTrainingProgress (data)
disp("Epoch: " + data(1) + ", Loss: " + data(2));
end

Задайте функцию, которая агрегировала градиенты на всех рабочих путем добавления их вместе. gplus добавляет вместе и реплицирует все градиенты в рабочих. Прежде, чем добавить их вместе, нормируйте их путем умножения их на фактор, который представляет пропорцию полного мини-пакета, что рабочий продолжает работать. Получать содержимое dlarrayUse extractdata. Передача класса содержимого как второй входной параметр к gplus убеждается, если данными является gpuArray, та быстрая коммуникация включена где это возможно.

function gradients = aggregateGradients(dlgradients,factor)
gradients = extractdata(dlgradients);
gradients = gplus(factor*gradients,class(gradients));
end

Смотрите также

| | | | | | | | |

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте