Обучите сеть параллельно с пользовательским учебным циклом

В этом примере показано, как настроить пользовательский учебный цикл, чтобы обучить сеть параллельно. В этом примере параллельные рабочие обучаются на фрагментах полного мини-пакета. Если у вас есть графический процессор, то обучение происходит на графическом процессоре. Во время обучения, DataQueue объект передает информацию о процессе обучения обратно клиенту MATLAB.

Загрузите набор данных

Загрузите набор данных цифры и создайте datastore изображений для набора данных. Разделите datastore в обучение и протестируйте хранилища данных рандомизированным способом. Создайте augmentedImageDatastore содержа обучающие данные.

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

[imdsTrain,imdsTest] = splitEachLabel(imds,0.9,"randomized");

inputSize = [28 28 1];
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain);

Определите различные классы в наборе обучающих данных.

classes = categories(imdsTrain.Labels);
numClasses = numel(classes);

Сеть Define

Задайте свою сетевую архитектуру и превратите ее в график слоев при помощи layerGraph функция. Эта сетевая архитектура включает слои нормализации партии., которые отслеживают среднее значение и статистику отклонения набора данных. Когда обучение параллельно, объедините статистику от всех рабочих в конце каждого шага итерации, чтобы гарантировать, что сетевое состояние отражает целый мини-пакет. В противном случае сетевое состояние может отличаться через рабочих. Если вы - учебные рекуррентные нейронные сети с сохранением информации (RNNs), например, с помощью данных о последовательности, которые были разделены в меньшие последовательности, чтобы обучить нейронные сети содержащий LSTM или слои ГРУ, необходимо также управлять состоянием между рабочими.

layers = [
    imageInputLayer([28 28 1],'Name','input','Normalization','none')
    convolution2dLayer(5,20,'Name','conv1')
    batchNormalizationLayer('Name','bn1')
    reluLayer('Name','relu1')
    convolution2dLayer(3,20,'Padding',1,'Name','conv2')
    batchNormalizationLayer('Name','bn2')
    reluLayer('Name','relu2')
    convolution2dLayer(3,20,'Padding',1,'Name','conv3')
    batchNormalizationLayer('Name','bn3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(numClasses,'Name','fc')];

lgraph = layerGraph(layers);

Создайте dlnetwork объект от графика слоев. dlnetwork объекты допускают обучение с пользовательскими циклами.

dlnet = dlnetwork(lgraph)
dlnet = 
  dlnetwork with properties:

         Layers: [11×1 nnet.cnn.layer.Layer]
    Connections: [10×2 table]
     Learnables: [14×3 table]
          State: [6×3 table]
     InputNames: {'input'}
    OutputNames: {'fc'}
    Initialized: 1

Настройте параллельную среду

Определите, доступны ли графические процессоры для MATLAB, чтобы использовать с canUseGPU функция.

  • Если существуют доступные графические процессоры, то обучаются на графических процессорах. Создайте параллельный пул со столькими же рабочих сколько графические процессоры.

  • Если нет никаких доступных графических процессоров, то не обучаются на центральных процессорах. Создайте параллельный пул с количеством по умолчанию рабочих.

if canUseGPU
    executionEnvironment = "gpu";
    numberOfGPUs = gpuDeviceCount("available");
    pool = parpool(numberOfGPUs);
else
    executionEnvironment = "cpu";
    pool = parpool;
end
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 4).

Получите количество рабочих в параллельном пуле. Позже в этом примере, вы делите рабочую нагрузку согласно этому номеру.

N = pool.NumWorkers;

Обучите модель

Задайте опции обучения.

numEpochs = 20;
miniBatchSize = 128;
velocity = [];

Для обучения графического процессора методические рекомендации должны увеличить мини-пакетный размер линейно с количеством графических процессоров, для того, чтобы сохранить рабочую нагрузку на каждом графическом процессоре постоянной. Для более связанного совета смотрите Глубокое обучение для MATLAB на Нескольких графических процессорах.

if executionEnvironment == "gpu"
    miniBatchSize = miniBatchSize .* N
end
miniBatchSize = 512

Вычислите мини-пакетный размер для каждого рабочего путем деления полного мини-пакетного размера равномерно между рабочими. Распределите остаток на первых рабочих.

workerMiniBatchSize = floor(miniBatchSize ./ repmat(N,1,N));
remainder = miniBatchSize - sum(workerMiniBatchSize);
workerMiniBatchSize = workerMiniBatchSize + [ones(1,remainder) zeros(1,N-remainder)]
workerMiniBatchSize = 1×4

   128   128   128   128

Эта сеть содержит слои нормализации партии., которые отслеживают среднее значение и отклонение данных, на которых обучена сеть. Начиная с каждого рабочие процессы фрагмент каждого мини-пакета во время каждой итерации, среднего значения и отклонения должен быть агрегирован через всех рабочих. Найдите индексы среднего значения и параметры состояния отклонения слоев нормализации партии. в сетевом свойстве состояний.

batchNormLayers = arrayfun(@(l)isa(l,'nnet.cnn.layer.BatchNormalizationLayer'),dlnet.Layers);
batchNormLayersNames = string({dlnet.Layers(batchNormLayers).Name});
state = dlnet.State;
isBatchNormalizationStateMean = ismember(state.Layer,batchNormLayersNames) & state.Parameter == "TrainedMean";
isBatchNormalizationStateVariance = ismember(state.Layer,batchNormLayersNames) & state.Parameter == "TrainedVariance";

Инициализируйте график процесса обучения.

figure
lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on

Чтобы передать данные обратно от рабочих во время обучения, создайте DataQueue объект. Используйте afterEach настраивать функцию, displayTrainingProgress, чтобы вызвать каждый раз, рабочий отправляет данные. displayTrainingProgress функция поддержки, заданная в конце этого примера, который отображает информацию о процессе обучения, которая прибывает от рабочих.

Q = parallel.pool.DataQueue;
displayFcn = @(x) displayTrainingProgress(x,lineLossTrain);
afterEach(Q,displayFcn);

Обучите модель с помощью пользовательского параллельного учебного цикла, как детализировано в следующих шагах. Чтобы выполнить код одновременно на всех рабочих, используйте spmd блок. В spmd блок, labindex дает индекс выполняющегося в данного момента рабочего код.

Перед обучением разделите datastore для каждого рабочего при помощи partition функция. Используйте разделенный datastore, чтобы создать minibatchqueue на каждом рабочем. Для каждого мини-пакета:

  • Используйте пользовательский мини-пакет, предварительно обрабатывающий функциональный preprocessMiniBatch (заданный в конце этого примера), чтобы нормировать данные, преобразуйте метки в одногорячие закодированные переменные и определите количество наблюдений в мини-пакете.

  • Формат данные изображения с размерностью маркирует 'SSCB' (пространственный, пространственный, канал, пакет). По умолчанию, minibatchqueue объект преобразует данные в dlarray объекты с базовым типом single. Не добавляйте формат в метки класса или количество наблюдений.

  • Обучайтесь на графическом процессоре, если вы доступны. По умолчанию, minibatchqueue объект преобразует каждый выход в gpuArray если графический процессор доступен. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox) (Parallel Computing Toolbox).

В течение каждой эпохи, сброса и перестановки datastore с reset и shuffle функции. Для каждой итерации в эпоху:

  • Убедитесь, что у всех рабочих есть доступные данные прежде, чем начать обрабатывать его параллельно путем выполнения глобального and операция (gop) на результате hasdata функция.

  • Считайте мини-пакет из minibatchqueue при помощи next функция.

  • Вычислите градиенты и потерю сети на каждом рабочем путем вызова dlfeval на modelGradients функция. dlfeval функция выполняет функцию помощника modelGradients с автоматическим включенным дифференцированием, таким образом, modelGradients может вычислить градиенты относительно потери автоматическим способом. modelGradients задан в конце примера и возвращает потерю и градиенты, учитывая сеть, мини-пакет данных и истинные метки.

  • Чтобы получить полную потерю, агрегируйте потери на всех рабочих. Этим примером использование пересекает энтропию для функции потерь и агрегированную потерю, является сумма всех потерь. Перед агрегацией нормируйте каждую потерю путем умножения пропорцией полного мини-пакета, что рабочий продолжает работать. Используйте gplus добавить все потери вместе и реплицировать результаты через рабочих.

  • Чтобы агрегировать и обновить градиенты всех рабочих, используйте dlupdate функция с aggregateGradients функция. aggregateGradients функция, определяемая поддержки в конце этого примера. Эта функция использует gplus добавить вместе и реплицировать градиенты через рабочих, после нормализации согласно пропорции полного мини-пакета, что каждый рабочий продолжает работать.

  • Агрегируйте состояние сети на всех рабочих, использующих aggregateState функция. aggregateState функция, определяемая поддержки в конце этого примера. Слои нормализации партии. в сети отслеживают среднее значение и отклонение данных. Поскольку полный мини-пакет распространен через несколько рабочих, агрегируйте сетевое состояние после каждой итерации, чтобы вычислить среднее значение и отклонение целого мини-пакета.

  • После вычисления итоговых градиентов обновите сетевые настраиваемые параметры с sgdmupdate функция.

  • Передайте информацию о процессе обучения обратно клиенту при помощи send функция с DataQueue. Используйте только одного рабочего, чтобы отправить данные, потому что у всех рабочих есть та же информация о потере. Гарантировать, что данные находятся на центральном процессоре, так, чтобы клиентская машина без графического процессора могла получить доступ к нему, gather использования на dlarray прежде, чем отправить его.

start = tic;
spmd
    % Reset and shuffle the datastore.
    reset(augimdsTrain);
    augimdsTrain = shuffle(augimdsTrain);

    % Partition datastore.
    workerImds = partition(augimdsTrain,N,labindex);

    % Create minibatchqueue using partitioned datastore on each worker
    workerMbq = minibatchqueue(workerImds,3,...
        "MiniBatchSize",workerMiniBatchSize(labindex),...
        "MiniBatchFcn",@preprocessMiniBatch,...
        "MiniBatchFormat",{'SSCB','',''});

    workerVelocity = velocity;
   
    iteration = 0;
    
    for epoch = 1:numEpochs
        shuffle(workerMbq);
        
        % Loop over mini-batches.
        while gop(@and,hasdata(workerMbq))
            iteration = iteration + 1;
            
            % Read a mini-batch of data.
            [dlworkerX,workerY,workerNumObservations] = next(workerMbq);
            
            % Evaluate the model gradients and loss on the worker.
            [workerGradients,dlworkerLoss,workerState] = dlfeval(@modelGradients,dlnet,dlworkerX,workerY);
            
            % Aggregate the losses on all workers.
            workerNormalizationFactor = workerMiniBatchSize(labindex)./miniBatchSize;
            loss = gplus(workerNormalizationFactor*extractdata(dlworkerLoss));
            
            % Aggregate the network state on all workers
            dlnet.State = aggregateState(workerState,workerNormalizationFactor,...
                isBatchNormalizationStateMean,isBatchNormalizationStateVariance);
            
            % Aggregate the gradients on all workers.
            workerGradients.Value = dlupdate(@aggregateGradients,workerGradients.Value,{workerNormalizationFactor});
            
            % Update the network parameters using the SGDM optimizer.
            [dlnet.Learnables,workerVelocity] = sgdmupdate(dlnet.Learnables,workerGradients,workerVelocity);
        end
        
       % Display training progress information.
       if labindex == 1
           data = [epoch loss iteration toc(start)];
           send(Q,gather(data)); 
       end
    end
end

Тестовая модель

После того, как вы обучите сеть, можно протестировать ее точность.

Загрузите тестовые изображения в память при помощи readall на тестовом datastore конкатенируйте их и нормируйте их.

XTest = readall(imdsTest);
XTest = cat(4,XTest{:});
XTest = single(XTest) ./ 255;
YTest = imdsTest.Labels;

После того, как обучение завершено, у всех рабочих есть завершенное то же самое, обучил сеть. Получите любого из них.

dlnetFinal = dlnet{1};

Классифицировать изображения с помощью dlnetwork объект, используйте predict функция на dlarray.

dlYPredScores = predict(dlnetFinal,dlarray(XTest,'SSCB'));

От предсказанных баллов найдите класс с самым высоким счетом с max функция. Прежде чем вы сделаете это, извлеките данные из dlarray с extractdata функция.

[~,idx] = max(extractdata(dlYPredScores),[],1);
YPred = classes(idx);

Чтобы получить точность классификации модели, сравните предсказания на наборе тестов против истинных меток.

accuracy = mean(YPred==YTest)
accuracy = 0.8960

Мини-функция предварительной обработки пакета

preprocessMiniBatch функция предварительно обрабатывает мини-пакет предикторов и меток с помощью следующих шагов:

  1. Определите количество наблюдений в мини-пакете

  2. Предварительно обработайте изображения с помощью preprocessMiniBatchPredictors функция.

  3. Извлеките данные о метке из массива входящей ячейки и конкатенируйте в категориальный массив вдоль второго измерения.

  4. Одногорячий кодируют категориальные метки в числовые массивы. Кодирование в первую размерность производит закодированный массив, который совпадает с формой сетевого выхода.

function [X,Y,numObs] = preprocessMiniBatch(XCell,YCell)

numObs = numel(YCell);

% Preprocess predictors.
X = preprocessMiniBatchPredictors(XCell);

% Extract label data from cell and concatenate.
Y = cat(2,YCell{1:end});

% One-hot encode labels.
Y = onehotencode(Y,1);

end

Мини-пакетные предикторы, предварительно обрабатывающие функцию

preprocessMiniBatchPredictors функция предварительно обрабатывает мини-пакет предикторов путем извлечения данных изображения из входного массива ячеек, и конкатенируйте в числовой массив. Для полутонового входа, конкатенирующего по четвертой размерности, добавляет третью размерность в каждое изображение, чтобы использовать в качестве одноэлементной размерности канала. Данные затем нормированы.

function X = preprocessMiniBatchPredictors(XCell)

% Concatenate.
X = cat(4,XCell{1:end});

% Normalize.
X =  X ./ 255;

end

Функции градиентов модели

Задайте функцию, modelGradients, вычислить градиенты потери относительно настраиваемых параметров сети. Эта функция вычисляет сетевые выходные параметры для мини-пакетного X с forward и softmax и вычисляет потерю, учитывая истинные выходные параметры, с помощью перекрестной энтропии. Когда вы вызываете эту функцию с dlfeval, автоматическое дифференцирование включено, и dlgradient может вычислить градиенты потери относительно learnables автоматически.

function [dlgradients,dlloss,state] = modelGradients(dlnet,dlX,dlY)
    [dlYPred,state] = forward(dlnet,dlX);
    dlYPred = softmax(dlYPred);
    
    dlloss = crossentropy(dlYPred,dlY);
    dlgradients = dlgradient(dlloss,dlnet.Learnables);
end

Отобразите функцию процесса обучения

Задайте функцию, чтобы отобразить информацию о процессе обучения, которая прибывает от рабочих. DataQueue в этом примере вызывает эту функцию каждый раз, когда рабочий отправляет данные.

function displayTrainingProgress (data,line)
     addpoints(line,double(data(3)),double(data(2)))
     D = duration(0,0,data(4),'Format','hh:mm:ss');
     title("Epoch: " + data(1) + ", Elapsed: " + string(D))
     drawnow
end

Совокупная функция градиентов

Задайте функцию, которая агрегировала градиенты на всех рабочих путем добавления их вместе. gplus добавляет вместе и реплицирует все градиенты в рабочих. Прежде, чем добавить их вместе, нормируйте их путем умножения их на фактор, который представляет пропорцию полного мини-пакета, что рабочий продолжает работать. Получать содержимое dlarrayUse extractdata.

function gradients = aggregateGradients(dlgradients,factor)
    gradients = extractdata(dlgradients);
    gradients = gplus(factor*gradients);
end

Совокупная функция состояния

Задайте функцию, которая агрегировала сетевое состояние на всех рабочих. Сетевое состояние содержит обученную статистику нормализации партии. набора данных. Поскольку каждый рабочий только видит фрагмент мини-пакета, агрегируйте сетевое состояние так, чтобы статистические данные были представительными для статистики через все данные. Для каждого мини-пакета объединенное среднее значение вычисляется как взвешенное среднее среднего значения через рабочих для каждой итерации. Объединенное отклонение вычисляется согласно следующей формуле:

sc2=1Mj=1Nmj[sj2+(xj-xc)2]

где Nобщее количество рабочих, Mобщее количество наблюдений в мини-пакете, mj количество наблюдений, обработанных на jрабочий th, xj и sj2 среднее значение и статистика отклонения, вычисленная на того рабочего, и xc объединенное среднее значение через всех рабочих.

function state = aggregateState(state,factor,...
    isBatchNormalizationStateMean,isBatchNormalizationStateVariance)

    stateMeans = state.Value(isBatchNormalizationStateMean);
    stateVariances = state.Value(isBatchNormalizationStateVariance);

    for j = 1:numel(stateMeans)
        meanVal = stateMeans{j};
        varVal = stateVariances{j};
        
        % Calculate combined mean
        combinedMean = gplus(factor*meanVal);
               
        % Calculate combined variance terms to sum
        varTerm = factor.*(varVal + (meanVal - combinedMean).^2);        
        
        % Update state
        stateMeans{j} = combinedMean;
        stateVariances{j} = gplus(varTerm);
    end

    state.Value(isBatchNormalizationStateMean) = stateMeans;
    state.Value(isBatchNormalizationStateVariance) = stateVariances;
end

Смотрите также

| | | | | | | | |

Похожие темы