Обучите сеть Используя федеративное изучение

В этом примере показано, как обучить сеть с помощью объединенного в федерацию изучения. Федеративное изучение является методом, который позволяет вам обучить сеть распределенным, децентрализованным способом [1].

Федеративное изучение позволяет вам обучать модель с помощью данных из других источников, не перемещая данные в центральное расположение, даже если отдельные источники данных не совпадают с полным распределением набора данных. Это известно как зависимые и тождественно распределенные (non-IID) данные. Федеративное изучение может быть особенно полезным, когда обучающие данные являются большими, или когда существуют опасения конфиденциальности по поводу передачи обучающих данных.

Вместо распределительных данных федеративный метод изучения обучает многоуровневые модели, каждого в том же месте как источник данных. Можно создать глобальную модель, которая извлекла уроки из всех источников данных путем периодического сбора и объединения настраиваемых параметров локально обученных моделей. Таким образом можно обучить глобальную модель, централизованно не обрабатывая обучающих данных.

Этот пример использование объединил обучение в федерацию обучить модель классификации в параллели с помощью высоко non-IID набор данных. Модель обучена с помощью набора данных цифр, который состоит из 10 000 рукописных изображений чисел от 0 до 9. Пример запускается в параллели с помощью 10 рабочих, каждая обработка изображения одной цифры. Путем усреднения настраиваемых параметров сетей после каждого раунда обучения модели на каждом рабочем улучшают производительность через все классы, никогда не обрабатывая данные других классов.

В то время как конфиденциальность данных является одним из приложений федеративного изучения, этот пример не имеет дело с деталями поддержания конфиденциальности данных и безопасности. Этот пример демонстрирует основной федеративный алгоритм обучения.

Настройте параллельную среду

Создайте параллельный пул с тем же количеством рабочих как классы в наборе данных. В данном примере используйте локальный параллельный пул с 10 рабочими.

pool = parpool('local',10);
numWorkers = pool.NumWorkers;

Загрузите набор данных

Все данные, используемые в этом примере, первоначально хранимы в централизованном месте. Чтобы сделать эти данные высоко non-IID, необходимо распределить данные среди рабочих согласно классу. Чтобы создать валидацию и наборы тестовых данных, передайте фрагмент данных от рабочих клиенту. После того, как данные правильно настраиваются с обучающими данными отдельных классов на рабочих и тесте и данных о валидации всех классов на клиенте, нет никакой дальнейшей передачи данных во время обучения.

Задайте папку, содержащую данные изображения.

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos',...
    'nndatasets','DigitDataset');

Распределите данные среди рабочих. Каждый рабочий получает изображения только одной цифры, такой, что рабочий 1 получает все изображения номера 0, рабочий 2 получает изображения номера 1 и т.д.

Изображения каждой цифры хранятся в разделять папке с именем той цифры. На каждом рабочем используйте fullfile функция, чтобы задать путь к определенной папке класса. Затем создайте imageDatastore это содержит все изображения той цифры. Затем используйте splitEachLabel функционируйте к случайным образом отдельным 30% данных для использования в валидации и тестировании. Наконец, создайте augmentedImageDatastore содержа обучающие данные.

inputSize = [28 28 1];
spmd   
    digitDatasetPath = fullfile(digitDatasetPath,num2str(labindex - 1));
    imds = imageDatastore(digitDatasetPath,...
        'IncludeSubfolders',true,...
        'LabelSource','foldernames');
    [imdsTrain,imdsTestVal] = splitEachLabel(imds,0.7,"randomized");
    
    augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain);
end

Чтобы проверить производительность объединенной глобальной модели в течение и после обучения, создайте тест и наборы данных валидации, содержащие изображения от всех классов. Объедините тест и данные о валидации от каждого рабочего в один datastore. Затем разделите этот datastore в два хранилища данных, что каждый содержит 15% полных данных - один для того, чтобы проверить сеть во время обучения и другого для тестирования сети после обучения.

fileList = [];
labelList = [];

for i = 1:numWorkers
    tmp = imdsTestVal{i};
    
    fileList = cat(1,fileList,tmp.Files);
    labelList = cat(1,labelList,tmp.Labels);    
end

imdsGlobalTestVal = imageDatastore(fileList);
imdsGlobalTestVal.Labels = labelList;

[imdsGlobalTest,imdsGlobalVal] = splitEachLabel(imdsGlobalTestVal,0.5,"randomized");

augimdsGlobalTest = augmentedImageDatastore(inputSize(1:2),imdsGlobalTest);
augimdsGlobalVal = augmentedImageDatastore(inputSize(1:2),imdsGlobalVal);

Данные теперь располагаются таким образом, что у каждого рабочего есть данные из единого класса, чтобы обучаться на, и клиент содержит валидацию и тестовые данные от всех классов.

Сеть Define

Определите количество классов в наборе данных.

classes = categories(imdsGlobalTest.Labels);
numClasses = numel(classes);

Определить сетевую архитектуру.

layers = [
    imageInputLayer(inputSize,'Normalization','none','Name','input')
    convolution2dLayer(5,32,'Name','conv1')
    reluLayer('Name','relu1')
    maxPooling2dLayer(2,'Name','maxpool1')
    convolution2dLayer(5,64,'Name','conv2')
    reluLayer('Name','relu2')
    maxPooling2dLayer(2,'Name','maxpool2')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')];

Создайте объект dlnetwork из слоев.

dlnet = dlnetwork(layers)
dlnet = 
  dlnetwork with properties:

         Layers: [9×1 nnet.cnn.layer.Layer]
    Connections: [8×2 table]
     Learnables: [6×3 table]
          State: [0×3 table]
     InputNames: {'input'}
    OutputNames: {'softmax'}
    Initialized: 1

Функция градиентов модели Define

Создайте функциональный modelGradients, перечисленный в разделе Model Gradients Function этого примера, который берет dlnetwork возразите и мини-пакет входных данных с соответствующими метками, и возвращает градиенты потери относительно настраиваемых параметров в сети и соответствующей потери.

Задайте федеративное усреднение функции

Создайте функциональный federatedAveraging, перечисленный в раздел Federated Averaging Function этого примера, который берет настраиваемые параметры сетей на каждом рабочем и коэффициенте нормализации для каждого рабочего, и возвращает усредненные настраиваемые параметры через все сети. Используйте средние настраиваемые параметры, чтобы обновить глобальную сеть и сеть на каждом рабочем.

Задайте вычисляют функцию точности

Создайте функциональный computeAccuracy, перечисленный в разделе Compute Accuracy Function этого примера, который берет dlnetwork объект, набор данных в minibatchqueue объект и список классов, и возвращают точность предсказаний через все наблюдения в наборе данных.

Задайте опции обучения

Во время обучения рабочие периодически передают свои сетевые настраиваемые параметры клиенту, так, чтобы клиент мог обновить глобальную модель. Обучение разделено на раунды. В конце каждого раунда обучения усреднены настраиваемые параметры, и глобальная модель обновляется. Модели рабочего затем заменяются новой глобальной моделью, и обучение продвигается рабочие.

Обучайтесь для 300 раундов с 5 эпохами на раунд. Обучение маленькому numer эпох на раунд гарантирует, что сети на рабочих не отличаются слишком далеко, прежде чем они будут усреднены.

numRounds = 300;
numEpochsperRound = 5;
miniBatchSize = 100;

Задайте опции для оптимизации SGD. Укажите, что начальная буква изучает уровень 0,001 и импульс 0.

learnRate = 0.001;
momentum = 0;

Обучите модель

Создайте указатель на функцию к пользовательскому мини-пакету, предварительно обрабатывающему функциональный preprocessMiniBatch (заданный в разделе Mini-Batch Preprocessing Function этого примера).

На каждом рабочем найдите общее количество учебных наблюдений обработанным локально на том рабочем. Используйте этот номер, чтобы нормировать настраиваемые параметры на каждом рабочем, когда вы найдете средние настраиваемые параметры после каждой коммуникации вокруг. Это помогает сбалансировать среднее значение, если существует различие между объемом данных по каждому рабочему.

На каждом рабочем создайте minibatchqueue возразите, что процессы и управляют мини-пакетами изображений во время обучения. Для каждого мини-пакета:

  • Предварительно обработайте данные с помощью пользовательского мини-пакета, предварительно обрабатывающего функциональный preprocessMiniBatch преобразовывать метки в одногорячие закодированные переменные.

  • Формат данные изображения с размерностью маркирует 'SSCB' (пространственный, пространственный, канал, пакет). По умолчанию, minibatchqueue объект преобразует данные в dlarray объекты с базовым типом single. Не добавляйте формат в метки класса.

  • Обучайтесь на графическом процессоре, если вы доступны. По умолчанию, minibatchqueue объект преобразует каждый выход в gpuArray если графический процессор доступен. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).

preProcess = @(x,y)preprocessMiniBatch(x,y,classes);

spmd
    sizeOfLocalDataset = augimdsTrain.NumObservations;
    
    mbq = minibatchqueue(augimdsTrain,...
        'MiniBatchSize',miniBatchSize,...
        'MiniBatchFcn',preProcess,...
        'MiniBatchFormat',{'SSCB',''});
end

Создайте minibatchqueue объект, который управляет данными о валидации, чтобы использовать во время обучения. Используйте те же настройки в качестве minibatchqueue на каждом рабочем.

mbqGlobalVal = minibatchqueue(augimdsGlobalVal,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn',preProcess,...
    'MiniBatchFormat',{'SSCB',''});

Инициализируйте график процесса обучения.

figure
lineAccuracyTrain = animatedline('Color',[0.85 0.325 0.098]);
ylim([0 inf])
xlabel("Communication rounds")
ylabel("Accuracy (Global)")
grid on

Инициализируйте скоростной параметр для решателя SGDM.

velocity = [];

Инициализируйте глобальную модель. Чтобы запуститься, глобальная модель имеет те же начальные параметры как нетренированная сеть на каждом рабочем.

globalModel = dlnet;

Обучите модель с помощью пользовательского учебного цикла. Для каждой коммуникации вокруг,

  • Обновите сети на рабочих с последней глобальной сетью.

  • Обучите сети на рабочих в течение пяти эпох.

  • Найдите средние параметры всех сетей с помощью federatedAveraging функция.

  • Замените глобальные сетевые параметры на среднее значение.

  • Вычислите точность обновленной глобальной сети с помощью данных о валидации.

  • Отобразите прогресс обучения.

В течение каждой эпохи переставьте данные и цикл по мини-пакетам данных. Для каждого мини-пакета:

  • Оцените градиенты модели и потерю с помощью dlfeval и modelGradients функции.

  • Обновите локальные сетевые параметры с помощью sgdmupdate функция.

start = tic;
for rounds = 1:numRounds
   
    spmd
        % Send global updated parameters to each worker.
        dlnet.Learnables.Value = globalModel.Learnables.Value;        
        
        % Loop over epochs.
        for epoch = 1:numEpochsperRound
            % Shuffle data.
            shuffle(mbq);
            
            % Loop over mini-batches.
            while hasdata(mbq)
                
                % Read mini-batch of data.
                [dlX,dlT] = next(mbq);
                
                % Evaluate the model gradients, state, and loss using dlfeval and the
                % modelGradients function and update the network state.
                [gradients,loss] = dlfeval(@modelGradients,dlnet,dlX,dlT);
                
                % Update the network parameters using the SGDM optimizer.
                [dlnet,velocity] = sgdmupdate(dlnet,gradients,velocity,learnRate,momentum);
                
            end
        end
        
        % Collect updated learnable parameters on each worker.
        workerLearnables = dlnet.Learnables.Value;
    end
    
    % Find normalization factors for each worker based on ratio of data
    % processed on that worker. 
    sizeOfAllDatasets = sum([sizeOfLocalDataset{:}]);
    normalizationFactor = [sizeOfLocalDataset{:}]/sizeOfAllDatasets;
    
    % Update the global model with new learnable parameters, normalized and
    % averaged across all workers.
    globalModel.Learnables.Value = federatedAveraging(workerLearnables,normalizationFactor);
    
    % Calculate the accuracy of the global model.
    accuracy = computeAccuracy(globalModel,mbqGlobalVal,classes);
    
    % Display the training progress of the global model.
    D = duration(0,0,toc(start),'Format','hh:mm:ss');
    addpoints(lineAccuracyTrain,rounds,double(accuracy))
    title("Communication round: " + rounds + ", Elapsed: " + string(D))
    drawnow
end

После финального раунда обучения обновите сеть на каждом рабочем с итоговыми средними настраиваемыми параметрами. Это важно, если вы хотите продолжить использовать или обучать сеть на рабочих.

spmd
    dlnet.Learnables.Value = globalModel.Learnables.Value;
end

Тестовая модель

Протестируйте точность классификации модели путем сравнения предсказаний на наборе тестов с истинными метками.

Создайте minibatchqueue объект, который управляет тестовыми данными. Используйте те же настройки в качестве minibatchqueue объекты используются во время обучения и валидации.

mbqGlobalTest = minibatchqueue(augimdsGlobalTest,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn',preProcess,...
    'MiniBatchFormat','SSCB');

Используйте computePredictions функция, чтобы вычислить предсказанные классы и вычислить точность предсказаний через все тестовые данные.

accuracy = computeAccuracy(globalModel,mbqGlobalTest,classes)
accuracy = single
    0.9873

После того, как вы будете сделаны с вашими расчетами, можно удалить параллельный пул. gcp функция возвращает текущий параллельный объект пула, таким образом, можно удалить пул.

delete(gcp('nocreate'));

Функция градиентов модели

modelGradients функционируйте берет dlnetwork объект dlnet, мини-пакет входных данных dlX с соответствием маркирует T и возвращает градиенты потери относительно настраиваемых параметров в dlnet и потеря. Чтобы вычислить градиенты автоматически, используйте dlgradient функция. Чтобы вычислить предсказания сети во время обучения, используйте forward функция.

function [gradients,loss] = modelGradients(dlnet,dlX,T)

    dlYPred = forward(dlnet,dlX);
    
    loss = crossentropy(dlYPred,T);
    gradients = dlgradient(loss,dlnet.Learnables);

end

Вычислите функцию точности

computePredictions функционируйте берет dlnetwork объект dlnet, minibatchqueue объект mbq, и список классов, и возвращает точность всех предсказаний на обеспеченном наборе данных. Чтобы вычислить предсказания сети во время валидации или после, обучение закончено, используйте predict функция.

function accuracy = computeAccuracy(dlnet,mbq,classes)

    correctPredictions = [];
    
    shuffle(mbq);
    while hasdata(mbq)
        
        [dlXTest,dlTTest] = next(mbq);
        
        TTest = onehotdecode(dlTTest,classes,1)';
        
        dlYPred = predict(dlnet,dlXTest);
        YPred = onehotdecode(dlYPred,classes,1)';
        
        correctPredictions = [correctPredictions; YPred == TTest];
    end
    
    predSum = sum(correctPredictions);
    accuracy = single(predSum./size(correctPredictions,1));

end

Функция предварительной обработки мини-пакета

Функция preprocessMiniBatch предварительно обрабатывает данные с помощью следующих шагов:

  1. Извлеките данные изображения из массива входящей ячейки и конкатенируйте в числовой массив. Конкатенация данных изображения по четвертой размерности добавляет третью размерность в каждое изображение, чтобы использоваться в качестве одноэлементной размерности канала.

  2. Извлеките данные о метке из массивов входящей ячейки и конкатенируйте в категориальный массив вдоль второго измерения.

  3. Одногорячий кодируют категориальные метки в числовые массивы. Кодирование в первую размерность производит закодированный массив, который совпадает с формой сетевого выхода.

function [X,Y] = preprocessMiniBatch(XCell,YCell,classes)

    % Concatenate.
    X = cat(4,XCell{1:end});
    
    % Extract label data from cell and concatenate.
    Y = cat(2,YCell{1:end});
    
    % One-hot encode labels.
    Y = onehotencode(Y,1,'ClassNames',classes);

end

Федеративное усреднение функции

Функциональный federatedAveraging функционируйте берет настраиваемые параметры сетей на каждом рабочем и коэффициенте нормализации для каждого рабочего, и возвращает усредненные настраиваемые параметры через все сети. Используйте средние настраиваемые параметры, чтобы обновить глобальную сеть и сеть на каждом рабочем.

function learnables = federatedAveraging(workerLearnables,normalizationFactor)

    numWorkers = size(normalizationFactor,2);
    
    % Initialize container for averaged learnables with same size as existing
    % learnables. Use learnables of first worker network as an example.
    exampleLearnables = workerLearnables{1};
    learnables = cell(height(exampleLearnables),1);
    
    for i = 1:height(learnables)   
        learnables{i} = zeros(size(exampleLearnables{i}),'like',(exampleLearnables{i}));
    end
    
    % Add the normalized learnable parameters of all workers to
    % calculate average values.
    for i = 1:numWorkers
        tmp = workerLearnables{i};
        for values = 1:numel(learnables)
            learnables{values} = learnables{values} + normalizationFactor(i).*tmp{values};
        end
    end
    
end

Ссылки

[1] Макмэхэн, Х. Брендан, Эйдер Мур, Дэниел Рамадж, Сет Хэмпсон и Блез Агуера y Arcas. "Эффективное коммуникацией Приобретение знаний о Глубоких Сетях из Децентрализованных Данных". Предварительно распечатайте, представленный. Февраль 2017. https://arxiv.org/abs/1602.05629.

Смотрите также

| | | | | |

Похожие темы