exponenta event banner

Обучение сети с использованием федеративного обучения

В этом примере показано, как обучать сеть с помощью федеративного обучения. Федеративное обучение - это метод, позволяющий обучать сеть распределенному децентрализованному способу [1].

Федеративное обучение позволяет обучать модель с использованием данных из различных источников без перемещения данных в центральное местоположение, даже если отдельные источники данных не соответствуют общему распределению набора данных. Это называется несамостоятельными и одинаково распределенными (не-IID) данными. Федеративное обучение может быть особенно полезным, когда данные обучения являются большими или когда существует проблема конфиденциальности при передаче данных обучения.

Вместо распределения данных метод федеративного обучения обучает несколько моделей, каждая из которых находится в том же месте, что и источник данных. Можно создать глобальную модель, извлеченную из всех источников данных, путем периодического сбора и объединения обучаемых параметров локально обученных моделей. Таким образом, можно обучить глобальную модель без централизованной обработки каких-либо данных обучения.

В этом примере используется федеративное обучение для параллельного обучения классификационной модели с использованием набора данных, в значительной степени отличного от IID. Модель обучается с использованием набора данных цифр, который состоит из 10000 рукописных изображений чисел от 0 до 9. Пример выполняется параллельно с использованием 10 работников, каждый из которых обрабатывает изображения одной цифры. Усредняя обучаемые параметры сетей после каждого цикла обучения, модели каждого работника улучшают производительность во всех классах, не обрабатывая данные других классов.

Хотя конфиденциальность данных является одним из приложений федеративного обучения, в этом примере не рассматриваются детали обеспечения конфиденциальности и безопасности данных. В этом примере показан базовый алгоритм федеративного обучения.

Настройка параллельной среды

Создайте параллельный пул с тем же количеством работников, что и классы в наборе данных. В этом примере используется локальный параллельный пул с 10 работниками.

pool = parpool('local',10);
numWorkers = pool.NumWorkers;

Загрузить набор данных

Все данные, используемые в этом примере, первоначально хранятся в централизованном местоположении. Чтобы сделать эти данные очень неидентифицированными, необходимо распределить данные между работниками в соответствии с классом. Чтобы создать наборы данных проверки и тестирования, передайте часть данных от работников клиенту. После правильной настройки данных, с данными обучения отдельных классов на рабочих и данными тестирования и валидации всех классов на клиенте, дальнейшая передача данных во время обучения не происходит.

Укажите папку, содержащую данные изображения.

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos',...
    'nndatasets','DigitDataset');

Распределите данные между работниками. Каждый работник получает изображения только одной цифры, так что работник 1 получает все изображения числа 0, работник 2 получает изображения числа 1 и т.д.

Изображения каждой цифры хранятся в отдельной папке с именем этой цифры. На каждом работнике используйте fullfile для указания пути к определенной папке класса. Затем создайте imageDatastore содержит все изображения этой цифры. Далее используйте splitEachLabel функция случайного разделения 30% данных для использования при проверке и тестировании. Наконец, создайте augmentedImageDatastore содержащий данные обучения.

inputSize = [28 28 1];
spmd   
    digitDatasetPath = fullfile(digitDatasetPath,num2str(labindex - 1));
    imds = imageDatastore(digitDatasetPath,...
        'IncludeSubfolders',true,...
        'LabelSource','foldernames');
    [imdsTrain,imdsTestVal] = splitEachLabel(imds,0.7,"randomized");
    
    augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain);
end

Чтобы проверить производительность объединенной глобальной модели во время и после обучения, создайте наборы данных тестирования и проверки, содержащие изображения из всех классов. Объединение данных тестирования и проверки от каждого работника в единое хранилище данных. Затем разделите это хранилище данных на два хранилища данных, каждое из которых содержит 15% общих данных - одно для проверки сети во время обучения, а другое для тестирования сети после обучения.

fileList = [];
labelList = [];

for i = 1:numWorkers
    tmp = imdsTestVal{i};
    
    fileList = cat(1,fileList,tmp.Files);
    labelList = cat(1,labelList,tmp.Labels);    
end

imdsGlobalTestVal = imageDatastore(fileList);
imdsGlobalTestVal.Labels = labelList;

[imdsGlobalTest,imdsGlobalVal] = splitEachLabel(imdsGlobalTestVal,0.5,"randomized");

augimdsGlobalTest = augmentedImageDatastore(inputSize(1:2),imdsGlobalTest);
augimdsGlobalVal = augmentedImageDatastore(inputSize(1:2),imdsGlobalVal);

Данные теперь скомпонованы таким образом, что каждый работник имеет данные из одного класса для обучения, а клиент хранит данные проверки и тестирования из всех классов.

Определение сети

Определите количество классов в наборе данных.

classes = categories(imdsGlobalTest.Labels);
numClasses = numel(classes);

Определите архитектуру сети.

layers = [
    imageInputLayer(inputSize,'Normalization','none','Name','input')
    convolution2dLayer(5,32,'Name','conv1')
    reluLayer('Name','relu1')
    maxPooling2dLayer(2,'Name','maxpool1')
    convolution2dLayer(5,64,'Name','conv2')
    reluLayer('Name','relu2')
    maxPooling2dLayer(2,'Name','maxpool2')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')];

lgraph = layerGraph(layers);

Создайте объект dlnetwork на основе графика слоев.

dlnet = dlnetwork(lgraph)
dlnet = 
  dlnetwork with properties:

         Layers: [9×1 nnet.cnn.layer.Layer]
    Connections: [8×2 table]
     Learnables: [6×3 table]
          State: [0×3 table]
     InputNames: {'input'}
    OutputNames: {'softmax'}
    Initialized: 1

Определение функции градиентов модели

Создание функции modelGradients, перечисленных в разделе «Функция градиентов модели» этого примера, который принимает dlnetwork объект и мини-пакет входных данных с соответствующими метками и возвращает градиенты потерь относительно обучаемых параметров в сети и соответствующих потерь.

Определение федеративной функции усреднения

Создание функции federatedAveraging, перечисленных в разделе «Федеративная усредняющая функция» этого примера, в котором берутся обучаемые параметры сетей для каждого работника и коэффициент нормализации для каждого работника и возвращаются усредненные усредненные обучаемые параметры по всем сетям. Используйте усредненные обучаемые параметры для обновления глобальной сети и сети каждого работника.

Определение функции точности вычислений

Создание функции computeAccuracy, перечисленных в разделе «Функция точности вычисления» этого примера, который принимает dlnetwork объект, набор данных внутри minibatchqueue и список классов, и возвращает точность прогнозов по всем наблюдениям в наборе данных.

Укажите параметры обучения

Во время обучения работники периодически передают свои сетевые обучаемые параметры клиенту, чтобы клиент мог обновить глобальную модель. Тренировка разделена на раунды. В конце каждого цикла обучения усредняются обучаемые параметры и обновляется глобальная модель. Затем модели работников заменяются новой глобальной моделью, и продолжается обучение работников.

Тренироваться на 300 раундов, по 5 эпох за раунд. Обучение для небольшого количества эпох за раунд гарантирует, что сети на рабочих не расходятся слишком далеко, прежде чем они усредняются.

numRounds = 300;
numEpochsperRound = 5;
miniBatchSize = 100;

Укажите параметры оптимизации SGD. Укажите начальную скорость обучения 0,001 и импульс 0.

learnRate = 0.001;
momentum = 0;

Модель поезда

Создание дескриптора функции для пользовательской функции предварительной обработки мини-партии preprocessMiniBatch (определяется в разделе «Функция предварительной обработки мини-партий» данного примера).

Для каждого работника найдите общее количество наблюдений за обучением, обработанных локально для этого работника. Это число используется для нормализации обучаемых параметров каждого работника при нахождении средних обучаемых параметров после каждого раунда связи. Это помогает сбалансировать среднее значение при наличии разницы между объемом данных по каждому работнику.

Для каждого работника создайте minibatchqueue объект, обрабатывающий и управляющий мини-партиями изображений во время обучения. Для каждой мини-партии:

  • Предварительная обработка данных с помощью пользовательской функции предварительной обработки мини-партии preprocessMiniBatch для преобразования меток в однокодированные переменные.

  • Форматирование данных изображения с метками размеров 'SSCB' (пространственный, пространственный, канальный, пакетный). По умолчанию minibatchqueue объект преобразует данные в dlarray объекты с базовым типом single. Не добавляйте формат к меткам класса.

  • Обучение на GPU, если он доступен. По умолчанию minibatchqueue объект преобразует каждый вывод в gpuArray если графический процессор доступен. Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox).

preProcess = @(x,y)preprocessMiniBatch(x,y,classes);

spmd
    sizeOfLocalDataset = augimdsTrain.NumObservations;
    
    mbq = minibatchqueue(augimdsTrain,...
        'MiniBatchSize',miniBatchSize,...
        'MiniBatchFcn',preProcess,...
        'MiniBatchFormat',{'SSCB',''});
end

Создать minibatchqueue объект, который управляет данными проверки, используемыми во время обучения. Использовать те же настройки, что и minibatchqueue на каждом работнике.

mbqGlobalVal = minibatchqueue(augimdsGlobalVal,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn',preProcess,...
    'MiniBatchFormat',{'SSCB',''});

Инициализируйте график хода обучения.

figure
lineAccuracyTrain = animatedline('Color',[0.85 0.325 0.098]);
ylim([0 inf])
xlabel("Communication rounds")
ylabel("Accuracy (Global)")
grid on

Инициализируйте параметр скорости для решателя SGDM.

velocity = [];

Инициализируйте глобальную модель. Для запуска глобальная модель имеет те же начальные параметры, что и неподготовленная сеть для каждого работника.

globalModel = dlnet;

Обучение модели с помощью пользовательского цикла обучения. Для каждого раунда связи,

  • Обновление сетей работников с помощью последней глобальной сети.

  • Тренируйте сети на рабочих в течение пяти эпох.

  • Найдите средние параметры всех сетей с помощью federatedAveraging функция.

  • Замените параметры глобальной сети средним значением.

  • Вычислите точность обновленной глобальной сети с помощью данных проверки.

  • Просмотрите ход обучения.

Для каждой эпохи тасуйте данные и закольцовывайте мини-пакеты данных. Для каждой мини-партии:

  • Оцените градиенты модели и потери с помощью dlfeval и modelGradients функции.

  • Обновление параметров локальной сети с помощью sgdmupdate функция.

start = tic;
for rounds = 1:numRounds
   
    spmd
        % Send global updated parameters to each worker.
        dlnet.Learnables.Value = globalModel.Learnables.Value;        
        
        % Loop over epochs.
        for epoch = 1:numEpochsperRound
            % Shuffle data.
            shuffle(mbq);
            
            % Loop over mini-batches.
            while hasdata(mbq)
                
                % Read mini-batch of data.
                [dlX,dlT] = next(mbq);
                
                % Evaluate the model gradients, state, and loss using dlfeval and the
                % modelGradients function and update the network state.
                [gradients,loss] = dlfeval(@modelGradients,dlnet,dlX,dlT);
                
                % Update the network parameters using the SGDM optimizer.
                [dlnet,velocity] = sgdmupdate(dlnet,gradients,velocity,learnRate,momentum);
                
            end
        end
        
        % Collect updated learnable parameters on each worker.
        workerLearnables = dlnet.Learnables.Value;
    end
    
    % Find normalization factors for each worker based on ratio of data
    % processed on that worker. 
    sizeOfAllDatasets = sum([sizeOfLocalDataset{:}]);
    normalizationFactor = [sizeOfLocalDataset{:}]/sizeOfAllDatasets;
    
    % Update the global model with new learnable parameters, normalized and
    % averaged across all workers.
    globalModel.Learnables.Value = federatedAveraging(workerLearnables,normalizationFactor);
    
    % Calculate the accuracy of the global model.
    accuracy = computeAccuracy(globalModel,mbqGlobalVal,classes);
    
    % Display the training progress of the global model.
    D = duration(0,0,toc(start),'Format','hh:mm:ss');
    addpoints(lineAccuracyTrain,rounds,double(accuracy))
    title("Communication round: " + rounds + ", Elapsed: " + string(D))
    drawnow
end

После завершения последнего цикла обучения обновите сеть на каждом работнике с окончательными усредненными параметрами. Это важно, если вы хотите продолжать использовать или обучать сеть на рабочих.

spmd
    dlnet.Learnables.Value = globalModel.Learnables.Value;
end

Тестовая модель

Проверьте точность классификации модели путем сравнения прогнозов на тестовом наборе с истинными метками.

Создать minibatchqueue объект, управляющий тестовыми данными. Использовать те же настройки, что и minibatchqueue объекты, используемые во время обучения и проверки.

mbqGlobalTest = minibatchqueue(augimdsGlobalTest,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn',preProcess,...
    'MiniBatchFormat','SSCB');

Используйте computePredictions функция для вычисления прогнозируемых классов и вычисления точности прогнозов по всем тестовым данным.

accuracy = computeAccuracy(globalModel,mbqGlobalTest,classes)
accuracy = single
    0.9873

После завершения вычислений можно удалить параллельный пул. gcp функция возвращает текущий объект параллельного пула, чтобы можно было удалить пул.

delete(gcp('nocreate'));

Функция градиентов модели

modelGradients функция принимает dlnetwork объект dlnet, мини-пакет входных данных dlX с соответствующими метками T и возвращает градиенты потерь относительно обучаемых параметров в dlnet и потери. Для автоматического вычисления градиентов используйте dlgradient функция. Чтобы вычислить прогнозы сети во время обучения, используйте forward функция.

function [gradients,loss] = modelGradients(dlnet,dlX,T)

    dlYPred = forward(dlnet,dlX);
    
    loss = crossentropy(dlYPred,T);
    gradients = dlgradient(loss,dlnet.Learnables);

end

Функция точности вычислений

computePredictions функция принимает dlnetwork объект dlnet, a minibatchqueue объект mbqи список классов, и возвращает точность всех прогнозов по предоставленному набору данных. Чтобы вычислить прогнозы сети во время проверки или после завершения обучения, используйте predict функция.

function accuracy = computeAccuracy(dlnet,mbq,classes)

    predictions = [];
    correctPredictions = [];
    
    shuffle(mbq);
    while hasdata(mbq)
        
        [dlXTest,dlTTest] = next(mbq);
        
        TTest = onehotdecode(dlTTest,classes,1)';
        
        dlYPred = predict(dlnet,dlXTest);
        YPred = onehotdecode(dlYPred,classes,1)';
        
        predictions = [predictions; YPred];
        correctPredictions = [correctPredictions; YPred == TTest];
    end
    
    predSum = sum(correctPredictions);
    accuracy = single(predSum./size(correctPredictions,1));

end

Функция предварительной обработки мини-партий

Функция предварительной обработки MiniBatch выполняет предварительную обработку данных с помощью следующих шагов:

  1. Извлеките данные изображения из входящего массива ячеек и объедините их в числовой массив. Конкатенация данных изображения над четвертым размером добавляет к каждому изображению третий размер, который будет использоваться в качестве размера одиночного канала.

  2. Извлеките данные метки из входящих массивов ячеек и объедините их в категориальный массив вдоль второго измерения.

  3. Одноконтурное кодирование категориальных меток в числовые массивы. Кодирование в первом измерении создает закодированный массив, который соответствует форме сетевого вывода.

function [X,Y] = preprocessMiniBatch(XCell,YCell,classes)

    % Concatenate.
    X = cat(4,XCell{1:end});
    
    % Extract label data from cell and concatenate.
    Y = cat(2,YCell{1:end});
    
    % One-hot encode labels.
    Y = onehotencode(Y,1,'ClassNames',classes);

end

Федеративная усредняющая функция

Функция federatedAveraging функция берет обучаемые параметры сетей каждого работника и коэффициент нормализации для каждого работника и возвращает усредненные обучаемые параметры по всем сетям. Используйте усредненные обучаемые параметры для обновления глобальной сети и сети каждого работника.

function learnables = federatedAveraging(workerLearnables,normalizationFactor)

    numWorkers = size(normalizationFactor,2);
    
    % Initialize container for averaged learnables with same size as existing
    % learnables. Use learnables of first worker network as an example.
    exampleLearnables = workerLearnables{1};
    learnables = cell(height(exampleLearnables),1);
    
    for i = 1:height(learnables)   
        learnables{i} = zeros(size(exampleLearnables{i}),'like',(exampleLearnables{i}));
    end
    
    % Add the normalized learnable parameters of all workers to
    % calculate average values.
    for i = 1:numWorkers
        tmp = workerLearnables{i};
        for values = 1:numel(learnables)
            learnables{values} = learnables{values} + normalizationFactor(i).*tmp{values};
        end
    end
    
end

Ссылки

[1] МакМахан, Х. Брендан, Эйдер Мур, Дэниел Рамаж, Сет Хэмпсон и Блез Агуэра-и-Аркас. «Эффективное обучение глубоким сетям на основе децентрализованных данных». Препринт, отправлен. Февраль 2017 года. https://arxiv.org/abs/1602.05629.

См. также

| | | | | |

Связанные темы