Обучите сеть параллельно с пользовательским циклом обучения

В этом примере показано, как настроить пользовательский цикл обучения для параллельного обучения сети. В этом примере параллельные рабочие обучают на фрагментах общей мини-партии. Если у вас есть графический процессор, то обучение происходит на графический процессор. Во время обучения, DataQueue объект отправляет информацию о процессе обучения обратно клиенту MATLAB.

Загрузка набора данных

Загрузите набор данных цифр и создайте datastore для набора данных. Разделите datastore на обучающие и тестовые хранилища данных рандомизированным способом.

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

[imdsTrain,imdsTest] = splitEachLabel(imds,0.9,"randomized");

Определите различные классы в набор обучающих данных.

classes = categories(imdsTrain.Labels);
numClasses = numel(classes);

Определение сети

Определите свою сетевую архитектуру и сделайте ее в график слоев при помощи layerGraph функция. Эта сетевая архитектура включает нормализации партии . уровня, которые отслеживают среднюю и отклонение статистику набора данных. При параллельном обучении объедините статистику от всех работников в конце каждого шага итерации, чтобы состояние сети отражало весь мини-пакет. В противном случае состояние сети может различаться между работниками. Если вы обучаете рекуррентные нейронные сети с учетом состояния (RNN), например, используя данные последовательности, которые были разделены на меньшие последовательности, для обучения сетей, содержащих слои LSTM или GRU, вы также должны управлять состоянием между работниками.

layers = [
    imageInputLayer([28 28 1],'Name','input','Normalization','none')
    convolution2dLayer(5,20,'Name','conv1')
    batchNormalizationLayer('Name','bn1')
    reluLayer('Name','relu1')
    convolution2dLayer(3,20,'Padding',1,'Name','conv2')
    batchNormalizationLayer('Name','bn2')
    reluLayer('Name','relu2')
    convolution2dLayer(3,20,'Padding',1,'Name','conv3')
    batchNormalizationLayer('Name','bn3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(numClasses,'Name','fc')];

lgraph = layerGraph(layers);

Создайте dlnetwork объект из графика слоев. dlnetwork объекты допускают обучение с помощью пользовательских циклов.

dlnet = dlnetwork(lgraph)
dlnet = 
  dlnetwork with properties:

         Layers: [11×1 nnet.cnn.layer.Layer]
    Connections: [10×2 table]
     Learnables: [14×3 table]
          State: [6×3 table]
     InputNames: {'input'}
    OutputNames: {'fc'}
    Initialized: 1

Настройка параллельного окружения

Определите, доступны ли графические процессоры для использования MATLAB с canUseGPU функция.

  • Если доступны графические процессоры, обучите их на графических процессорах. Создайте параллельный пул с таким количеством рабочих процессов, как графические процессоры.

  • Если нет доступных графических процессоров, обучите их на центральных процессорах. Создайте параллельный пул с количеством работников по умолчанию.

if canUseGPU
    executionEnvironment = "gpu";
    numberOfGPUs = gpuDeviceCount("available");
    pool = parpool(numberOfGPUs);
else
    executionEnvironment = "cpu";
    pool = parpool;
end

Получите количество работников в параллельном пуле. Позже в этом примере вы разделяете рабочую нагрузку согласно этому числу.

N = pool.NumWorkers;

Обучите модель

Задайте опции обучения.

numEpochs = 20;
miniBatchSize = 128;
velocity = [];

Для обучения GPU рекомендуется линейно увеличить размер мини-пакета с количеством графических процессоров, порядок сохранить рабочую нагрузку на каждом графическом процессоре постоянной. Для получения дополнительной информации смотрите Обучение с несколькими графическими процессорами.

if executionEnvironment == "gpu"
    miniBatchSize = miniBatchSize .* N
end
miniBatchSize = 512

Рассчитать размер мини-пакета для каждого работника путем деления общего размера мини-пакета равномерно между работниками. Распределите оставшуюся часть по первым рабочим.

workerMiniBatchSize = floor(miniBatchSize ./ repmat(N,1,N));
remainder = miniBatchSize - sum(workerMiniBatchSize);
workerMiniBatchSize = workerMiniBatchSize + [ones(1,remainder) zeros(1,N-remainder)]
workerMiniBatchSize = 1×4

   128   128   128   128

Инициализируйте график процесса обучения.

% Set up the training plot
figure
lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on

Чтобы отправить данные от работников во время обучения, создайте DataQueue объект. Использование afterEach чтобы настроить функцию, displayTrainingProgress, чтобы вызывать каждый раз, когда рабочий отправляет данные. displayTrainingProgress является вспомогательной функцией, определенной в конце этого примера, которая отображает информацию о процессе обучения, полученную от работников.

Q = parallel.pool.DataQueue;
displayFcn = @(x) displayTrainingProgress(x,lineLossTrain);
afterEach(Q,displayFcn);

Обучите модель с помощью пользовательского параллельного цикла обучения, как подробно описано на следующих шагах. Чтобы выполнить код одновременно на всех работниках, используйте spmd блок. В пределах spmd блок, labindex задает индекс работника выполняющегося в данного момента кода.

Перед обучением разделите datastore для каждого работника с помощью partition function, и задать ReadSize в мини-пакет размера рабочего.

Для каждой эпохи сбросьте и перетащите datastore с reset и shuffle функций. Для каждой итерации в эпоху:

  • Убедитесь, что все работники имеют данные, доступные до начала их параллельной обработки, путем выполнения глобального and операция (gop) о результате hasdata функция.

  • Считайте мини-пакет из datastore при помощи read и объедините извлеченные изображения в четырехмерный массив изображений. Нормализуйте изображения так, чтобы пиксели принимали значения между 0 и 1.

  • Преобразуйте метки в матрицу фиктивных переменных, которая помещает метки против наблюдений. Фиктивные переменные содержат 1 для метки наблюдения и 0 в противном случае.

  • Преобразуйте мини-пакет данных в dlarray объект с базовым типом одиночный и задает метки размерностей 'SSCB' (пространственный, пространственный, канальный, пакетный). Для обучения графический процессор преобразуйте данные в gpuArray.

  • Вычислите градиенты и потери сети для каждого работника путем вызова dlfeval на modelGradients функция. The dlfeval функция оценивает вспомогательную функцию modelGradients с включенной автоматической дифференциацией, так modelGradients может вычислить градиенты относительно потерь автоматическим способом. modelGradients определяется в конце примера и возвращает потери и градиенты, заданные для сети, мини-пакета данных и истинных меток.

  • Для получения общих потерь суммируйте потери по всем работникам. Этот пример использует перекрестную энтропию для функции потерь, и агрегированные потери являются суммой всех потерь. Перед агрегированием нормализуйте каждую потерю путем умножения на долю от общего мини-пакета, над которым работает рабочий. Использование gplus чтобы добавить все потери вместе и повторить результаты между работниками.

  • Чтобы агрегировать и обновить градиенты всех работников, используйте dlupdate функция со aggregateGradients функция. aggregateGradients является вспомогательной функцией, заданной в конце этого примера. Эта функция использует gplus чтобы сложить вместе и реплицировать градиенты между работниками, после нормализации в соответствии с долей общего мини-пакета, над которым работает каждый работник.

  • Агрегируйте состояние сети для всех работников, использующих aggregateState функция. aggregateState является вспомогательной функцией, заданной в конце этого примера. Слои batchnormalization в сети отслеживают среднее значение и отклонение данных. Поскольку полный мини-пакет распределен по нескольким рабочим местам, агрегируйте состояние сети после каждой итерации, чтобы вычислить среднее значение и отклонение всего минибатча.

  • После вычисления окончательных градиентов обновите настраиваемые параметры сети с помощью sgdmupdate функция.

  • Отправка информации о процессе обучения обратно клиенту при помощи send функция со DataQueue. Для отправки данных используется только один рабочий, поскольку все рабочие имеют одинаковую информацию о потерях. Чтобы убедиться, что данные находятся на центральном процессоре, чтобы клиентская машина без графического процессора могла получить к нему доступ, используйте gather на dlarray перед отправкой.

start = tic;
spmd
    % Partition datastore.
    workerImds = partition(imdsTrain,N,labindex);
    workerImds.ReadSize = workerMiniBatchSize(labindex);
    
    workerVelocity = velocity;
   
    iteration = 0;
    
    for epoch = 1:numEpochs
        % Reset and shuffle the datastore.
        reset(workerImds);
        workerImds = shuffle(workerImds);
        
        % Loop over mini-batches.
        while gop(@and,hasdata(workerImds))
            iteration = iteration + 1;
            
            % Read a mini-batch of data.
            [workerXBatch,workerTBatch] = read(workerImds);
            workerXBatch = cat(4,workerXBatch{:});
            workerNumObservations = numel(workerTBatch.Label);

            % Normalize the images.
            workerXBatch =  single(workerXBatch) ./ 255;
            
            % Convert the labels to dummy variables.
            workerY = zeros(numClasses,workerNumObservations,'single');
            for c = 1:numClasses
                workerY(c,workerTBatch.Label==classes(c)) = 1;
            end
            
            % Convert the mini-batch of data to dlarray.
            dlworkerX = dlarray(workerXBatch,'SSCB');
            
            % If training on GPU, then convert data to gpuArray.
            if executionEnvironment == "gpu"
                dlworkerX = gpuArray(dlworkerX);
            end
            
            % Evaluate the model gradients and loss on the worker.
            [workerGradients,dlworkerLoss,workerState] = dlfeval(@modelGradients,dlnet,dlworkerX,workerY);
            
            % Aggregate the losses on all workers.
            workerNormalizationFactor = workerMiniBatchSize(labindex)./miniBatchSize;
            loss = gplus(workerNormalizationFactor*extractdata(dlworkerLoss));
            
            % Aggregate the network state on all workers
            dlnet.State = aggregateState(workerState,workerNormalizationFactor);
            
            % Aggregate the gradients on all workers.
            workerGradients.Value = dlupdate(@aggregateGradients,workerGradients.Value,{workerNormalizationFactor});
            
            % Update the network parameters using the SGDM optimizer.
            [dlnet.Learnables,workerVelocity] = sgdmupdate(dlnet.Learnables,workerGradients,workerVelocity);
        end
        
       % Display training progress information.
       if labindex == 1
           data = [epoch loss iteration toc(start)];
           send(Q,gather(data)); 
       end
    end
end

Экспериментальная модель

После того, как вы обучите сеть, можно проверить ее точность.

Загрузите тестовые изображения в память при помощи readall на тестовом datastore, конкатенируйте их и нормализуйте их.

XTest = readall(imdsTest);
XTest = cat(4,XTest{:});
XTest = single(XTest) ./ 255;
YTest = imdsTest.Labels;

После завершения обучения все работники имеют одинаковую полную обученную сеть. Найдите любой из них.

dlnetFinal = dlnet{1};

Чтобы классифицировать изображения с помощью dlnetwork объект, используйте predict функция на dlarray.

dlYPredScores = predict(dlnetFinal,dlarray(XTest,'SSCB'));

Из предсказанных счетов найдите класс с самым высоким счетом со max функция. Прежде чем вы это сделаете, извлеките данные из dlarray с extractdata функция.

[~,idx] = max(extractdata(dlYPredScores),[],1);
YPred = classes(idx);

Чтобы получить классификационную точность модели, сравните предсказания на тестовом наборе с истинными метками.

accuracy = mean(YPred==YTest)
accuracy = 0.9910

Моделируйте функции градиентов

Задайте функцию, modelGradients, для вычисления градиентов потерь относительно настраиваемых параметров сети. Эта функция вычисляет выходы сети для мини-пакета X с forward и softmax и вычисляет потери, учитывая истинные выходы, используя перекрестную энтропию. Когда вы вызываете эту функцию с dlfeval, включена автоматическая дифференциация и dlgradient может автоматически вычислять градиенты потерь относительно обучаемых.

function [dlgradients,dlloss,state] = modelGradients(dlnet,dlX,dlY)
    [dlYPred,state] = forward(dlnet,dlX);
    dlYPred = softmax(dlYPred);
    
    dlloss = crossentropy(dlYPred,dlY);
    dlgradients = dlgradient(dlloss,dlnet.Learnables);
end

Отобразите функцию процесса обучения

Задайте функцию для отображения информации о процессе обучения, поступающей от работников. The DataQueue в этом примере эта функция вызывается каждый раз, когда работник отправляет данные.

function displayTrainingProgress (data,line)
     addpoints(line,double(data(3)),double(data(2)))
     D = duration(0,0,data(4),'Format','hh:mm:ss');
     title("Epoch: " + data(1) + ", Elapsed: " + string(D))
     drawnow
end

Совокупная функция градиентов

Задайте функцию, которая агрегирует градиенты для всех работников, добавляя их вместе. gplus складывает вместе и тиражирует все градиенты рабочих процессов. Прежде чем сложить их вместе, нормализуйте их, умножив их на множитель, который представляет долю общего мини-пакета, над которым работает рабочий. Получение содержимого dlarray, use extractdata.

function gradients = aggregateGradients(dlgradients,factor)
    gradients = extractdata(dlgradients);
    gradients = gplus(factor*gradients);
end

Совокупная функция состояния

Задайте функцию, которая агрегирует состояние сети для всех работников. Сетевое состояние содержит обученную статистику нормализации партии . набора данных. Поскольку каждый рабочий процесс видит только фрагмент мини-пакета, агрегируйте состояние сети так, чтобы статистика была показательной для статистики по всем данным. Для каждой мини-партии объединенное среднее значение вычисляется как средневзвешенное среднее значение для всех работников для каждой итерации. Комбинированное отклонение вычисляется по следующей формуле:

sc2=1Mj=1Nmj[sj2+(xj-xc)2]

где N- общее количество работников, M- общее количество наблюдений в мини-пакете, mj - количество наблюдений, обработанных на jрабочий, xj и sj2 являются статистическими данными о среднем значении и отклонениях, рассчитанными на этом работнике, и xc - комбинированное среднее для всех работников.

function state = aggregateState(state,factor)

    numrows = size(state,1);
    
    for j = 1:numrows
        isBatchNormalizationState = state.Parameter(j) =="TrainedMean"...
            && state.Parameter(j+1) =="TrainedVariance"...
            && state.Layer(j) == state.Layer(j+1);
        
        if isBatchNormalizationState
            meanVal = state.Value{j};
            varVal = state.Value{j+1};
            
            % Calculate combined mean
            combinedMean = gplus(factor*meanVal);
                   
            % Caclulate combined variance terms to sum
            combinedVarTerm = factor.*(varVal + (meanVal - combinedMean).^2);        
            
            % Update state
            state.Value(j) = {combinedMean};
            state.Value(j+1) = {gplus(combinedVarTerm)};
           
        end
    end
end

См. также

| | | | | | | | |

Похожие темы