Обучите сеть с помощью федеративного обучения

В этом примере показано, как обучить сеть с помощью федеративного обучения. Федеративное обучение является методом, который позволяет вам обучать сеть распределенным, децентрализованным способом [1].

Федеративное обучение позволяет вам обучать модель с помощью данных из разных источников, не перемещая данные в центральное место, даже если отдельные источники данных не совпадают с общим распределением набора данных. Это известно как несамостоятельные и идентично распределенные (не-IID) данные. Федеративное обучение может быть особенно полезным, когда обучающие данные большие, или когда существуют проблемы конфиденциальности при передаче обучающих данных.

Вместо распределения данных метод федеративного обучения обучает несколько моделей, каждая из которых находится в том же месте, что и источник данных. Можно создать глобальную модель, которая извлечена из всех источников данных путем периодического сбора и объединения настраиваемых параметров локально обученных моделей. Таким образом, можно обучить глобальную модель, не обрабатывая централизованно какие-либо обучающие данные.

Этот пример использует федеративное обучение, чтобы параллельно обучить классификационную модель с использованием набора данных, не являющегося IID. Модель обучается с помощью набора данных цифр, который состоит из 10000 рукописных изображений цифр от 0 до 9. Пример запускается параллельно с использованием 10 рабочих мест, каждый обрабатывает изображения с одной цифрой. Путем усреднения настраиваемых параметров сетей после каждого цикла обучения, модели на каждом работнике улучшают эффективность во всех классах, никогда не обрабатывая данные других классов.

Хотя конфиденциальность данных является одним из приложений федеративного обучения, этот пример не касается деталей поддержания конфиденциальности и безопасности данных. Этот пример демонстрирует основной алгоритм федеративного обучения.

Настройка параллельного окружения

Создайте параллельный пул с таким же количеством работников, как и классы в наборе данных. В данном примере используйте локальный параллельный пул с 10 рабочими.

pool = parpool('local',10);
numWorkers = pool.NumWorkers;

Загрузка набора данных

Все данные, используемые в этом примере, первоначально хранятся в централизованном местоположении. Чтобы сделать эти данные сильно отличными от IID, необходимо распределить данные между работниками в соответствии с классом. Чтобы создать наборы валидации и тестовых данных, передайте фрагментов данных от работников клиенту. После правильной настройки данных, с обучающими данными отдельных классов по работникам и тестовыми и валидационными данными всех классов на клиенте, дальнейшая передача данных во время обучения отсутствует.

Укажите папку, содержащую данные изображения.

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos',...
    'nndatasets','DigitDataset');

Распределите данные между работниками. Каждый рабочий получает изображения только одной цифры, так что рабочий 1 получает все изображения числа 0, рабочий 2 получает изображения числа 1 и т.д.

Изображения каждой цифры хранятся в отдельной папке с именем этой цифры. На каждом работнике используйте fullfile функция для задания пути к определенной папке класса. Затем создайте imageDatastore который содержит все изображения этой цифры. Далее используйте splitEachLabel функция для случайного разделения 30% данных для использования при валидации и проверке. Наконец, создайте augmentedImageDatastore содержащие обучающие данные.

inputSize = [28 28 1];
spmd   
    digitDatasetPath = fullfile(digitDatasetPath,num2str(labindex - 1));
    imds = imageDatastore(digitDatasetPath,...
        'IncludeSubfolders',true,...
        'LabelSource','foldernames');
    [imdsTrain,imdsTestVal] = splitEachLabel(imds,0.7,"randomized");
    
    augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain);
end

Чтобы протестировать эффективность комбинированной глобальной модели во время и после обучения, создайте наборы данных тестирования и валидации, содержащие изображения из всех классов. Объедините данные тестирования и валидации от каждого работника в один datastore. Затем разделите этот datastore на два хранилища данных, каждый из которых содержит 15% от общих данных - один для проверки сети во время обучения, а другой для тестирования сети после обучения.

fileList = [];
labelList = [];

for i = 1:numWorkers
    tmp = imdsTestVal{i};
    
    fileList = cat(1,fileList,tmp.Files);
    labelList = cat(1,labelList,tmp.Labels);    
end

imdsGlobalTestVal = imageDatastore(fileList);
imdsGlobalTestVal.Labels = labelList;

[imdsGlobalTest,imdsGlobalVal] = splitEachLabel(imdsGlobalTestVal,0.5,"randomized");

augimdsGlobalTest = augmentedImageDatastore(inputSize(1:2),imdsGlobalTest);
augimdsGlobalVal = augmentedImageDatastore(inputSize(1:2),imdsGlobalVal);

Теперь данные организованы таким образом, что каждый рабочий процесс имеет данные из одного класса для обучения, а клиент проводит валидацию и тестовые данные от всех классов.

Определение сети

Определите количество классов в наборе данных.

classes = categories(imdsGlobalTest.Labels);
numClasses = numel(classes);

Определите сетевую архитектуру.

layers = [
    imageInputLayer(inputSize,'Normalization','none','Name','input')
    convolution2dLayer(5,32,'Name','conv1')
    reluLayer('Name','relu1')
    maxPooling2dLayer(2,'Name','maxpool1')
    convolution2dLayer(5,64,'Name','conv2')
    reluLayer('Name','relu2')
    maxPooling2dLayer(2,'Name','maxpool2')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')];

lgraph = layerGraph(layers);

Создайте объект dlnetwork из графика слоев.

dlnet = dlnetwork(lgraph)
dlnet = 
  dlnetwork with properties:

         Layers: [9×1 nnet.cnn.layer.Layer]
    Connections: [8×2 table]
     Learnables: [6×3 table]
          State: [0×3 table]
     InputNames: {'input'}
    OutputNames: {'softmax'}
    Initialized: 1

Задайте функцию градиентов модели

Создайте функцию modelGradients, перечисленный в разделе Model Gradients Function этого примера, который принимает dlnetwork объект и мини-пакет входных данных с соответствующими метками и возвращает градиенты потерь относительно настраиваемых параметров в сети и соответствующих потерь.

Задайте федеративную функцию усреднения

Создайте функцию federatedAveraging, перечисленный в разделе Federated Averaging Function этого примера, который берёт настраиваемые параметры сетей на каждом работнике и коэффициент нормализации для каждого работника и возвращает усредненные настраиваемые параметры во всех сетях. Используйте средние настраиваемые параметры для обновления глобальной сети и сети на каждом работнике.

Задайте функцию вычислительной точности

Создайте функцию computeAccuracy, перечисленный в разделе «Функция вычислительной точности» этого примера, который принимает dlnetwork объект, набор данных внутри minibatchqueue объект, и список классов, и возвращает точность предсказаний для всех наблюдений в наборе данных.

Настройка опций обучения

Во время обучения работники периодически передают клиенту свои настраиваемые параметры сети, чтобы клиент мог обновить глобальную модель. Обучение разделено на раунды. В конце каждого раунда обучения усредняются усвояемые параметры и обновляется глобальная модель. Модели рабочих затем заменяются новой глобальной моделью, и обучение продолжается для работников.

Тренируйтесь в течение 300 раундов, с 5 эпохами в раундах. Обучение небольшому числу эпох за раунд гарантирует, что сети на рабочих не будут различаться слишком далеко, прежде чем они будут усреднены.

numRounds = 300;
numEpochsperRound = 5;
miniBatchSize = 100;

Задайте опции для оптимизации SGD. Задайте начальную скорость обучения 0,001 и импульс 0.

learnRate = 0.001;
momentum = 0;

Обучите модель

Создайте указатель на функцию для пользовательской функции мини-пакетной предварительной обработки preprocessMiniBatch (определено в разделе функции мини-пакетной предварительной обработки этого примера).

На каждом работнике найдите общее количество обучающих наблюдений, обработанных локально на этом работнике. Используйте это число для нормализации настраиваемых параметров на каждом работнике, когда вы находите средние настраиваемые параметры после каждого раунда communicaton. Это помогает сбалансировать среднее значение, если существует различие между объемом данных по каждому работнику.

На каждом рабочем месте создайте minibatchqueue объект, который обрабатывает и управляет мини-пакетами изображений во время обучения. Для каждого мини-пакета:

  • Предварительно обработайте данные с помощью пользовательской функции мини-пакетной предварительной обработки preprocessMiniBatch для преобразования меток в переменные с одним «горячим» кодированием.

  • Форматируйте данные изображения с помощью меток размерностей 'SSCB' (пространственный, пространственный, канальный, пакетный). По умолчанию в minibatchqueue объект преобразует данные в dlarray объекты с базовым типом single. Не добавляйте формат к меткам классов.

  • Обучите на графическом процессоре, если он доступен. По умолчанию в minibatchqueue объект преобразует каждый выход в gpuArray при наличии графический процессор. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox).

preProcess = @(x,y)preprocessMiniBatch(x,y,classes);

spmd
    sizeOfLocalDataset = augimdsTrain.NumObservations;
    
    mbq = minibatchqueue(augimdsTrain,...
        'MiniBatchSize',miniBatchSize,...
        'MiniBatchFcn',preProcess,...
        'MiniBatchFormat',{'SSCB',''});
end

Создайте minibatchqueue объект, который управляет данными валидации для использования во время обучения. Используйте те же настройки, что и minibatchqueue на каждого работника.

mbqGlobalVal = minibatchqueue(augimdsGlobalVal,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn',preProcess,...
    'MiniBatchFormat',{'SSCB',''});

Инициализируйте график процесса обучения.

figure
lineAccuracyTrain = animatedline('Color',[0.85 0.325 0.098]);
ylim([0 inf])
xlabel("Communication rounds")
ylabel("Accuracy (Global)")
grid on

Инициализируйте параметр скорости для решателя SGDM.

velocity = [];

Инициализируйте глобальную модель. Чтобы начать, глобальная модель имеет те же начальные параметры, что и необученная сеть на каждом рабочем.

globalModel = dlnet;

Обучите модель с помощью пользовательского цикла обучения. Для каждого раунда связи,

  • Обновляйте сети рабочих с помощью последней глобальной сети.

  • Обучите сети на рабочих в течение пяти эпох.

  • Найдите средние параметры всех сетей, использующих federatedAveraging функция.

  • Замените параметры глобальной сети средним значением.

  • Вычислите точность обновленной глобальной сети с помощью данных валидации.

  • Отображение процесса обучения.

Для каждой эпохи перетасуйте данные и закольцовывайте по мини-пакетам данных. Для каждого мини-пакета:

  • Оцените градиенты модели и потери с помощью dlfeval и modelGradients функций.

  • Обновите параметры локальной сети с помощью sgdmupdate функция.

start = tic;
for rounds = 1:numRounds
   
    spmd
        % Send global updated parameters to each worker.
        dlnet.Learnables.Value = globalModel.Learnables.Value;        
        
        % Loop over epochs.
        for epoch = 1:numEpochsperRound
            % Shuffle data.
            shuffle(mbq);
            
            % Loop over mini-batches.
            while hasdata(mbq)
                
                % Read mini-batch of data.
                [dlX,dlT] = next(mbq);
                
                % Evaluate the model gradients, state, and loss using dlfeval and the
                % modelGradients function and update the network state.
                [gradients,loss] = dlfeval(@modelGradients,dlnet,dlX,dlT);
                
                % Update the network parameters using the SGDM optimizer.
                [dlnet,velocity] = sgdmupdate(dlnet,gradients,velocity,learnRate,momentum);
                
            end
        end
        
        % Collect updated learnable parameters on each worker.
        workerLearnables = dlnet.Learnables.Value;
    end
    
    % Find normalization factors for each worker based on ratio of data
    % processed on that worker. 
    sizeOfAllDatasets = sum([sizeOfLocalDataset{:}]);
    normalizationFactor = [sizeOfLocalDataset{:}]/sizeOfAllDatasets;
    
    % Update the global model with new learnable parameters, normalized and
    % averaged across all workers.
    globalModel.Learnables.Value = federatedAveraging(workerLearnables,normalizationFactor);
    
    % Calculate the accuracy of the global model.
    accuracy = computeAccuracy(globalModel,mbqGlobalVal,classes);
    
    % Display the training progress of the global model.
    D = duration(0,0,toc(start),'Format','hh:mm:ss');
    addpoints(lineAccuracyTrain,rounds,double(accuracy))
    title("Communication round: " + rounds + ", Elapsed: " + string(D))
    drawnow
end

После заключительного раунда обучения обновляйте сеть на каждом работнике с конечными средними настраиваемыми параметрами. Это важно, если вы хотите продолжать использовать или обучать сеть в рабочих местах.

spmd
    dlnet.Learnables.Value = globalModel.Learnables.Value;
end

Экспериментальная модель

Протестируйте классификационную точность модели путем сравнения предсказаний на тестовом наборе с истинными метками.

Создайте minibatchqueue объект, который управляет тестовыми данными. Используйте те же настройки, что и minibatchqueue объекты, используемые во время обучения и валидации.

mbqGlobalTest = minibatchqueue(augimdsGlobalTest,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn',preProcess,...
    'MiniBatchFormat','SSCB');

Используйте computePredictions функция для вычисления предсказанных классов и вычисления точности предсказаний по всем тестовым данным.

accuracy = computeAccuracy(globalModel,mbqGlobalTest,classes)
accuracy = single
    0.9873

После завершения расчетов можно удалить параллельный пул. The gcp функция возвращает текущий объект параллельного пула, чтобы можно было удалить пул.

delete(gcp('nocreate'));

Функция градиентов модели

The modelGradients функция принимает dlnetwork dlnet объектамини-пакет входных данных dlX с соответствующими метками T и возвращает градиенты потерь относительно настраиваемых параметров в dlnet и потеря. Чтобы вычислить градиенты автоматически, используйте dlgradient функция. Чтобы вычислить предсказания сети во время обучения, используйте forward функция.

function [gradients,loss] = modelGradients(dlnet,dlX,T)

    dlYPred = forward(dlnet,dlX);
    
    loss = crossentropy(dlYPred,T);
    gradients = dlgradient(loss,dlnet.Learnables);

end

Функция вычислительной точности

The computePredictions функция принимает dlnetwork dlnet объекта, а minibatchqueue mbq объекта, и список классов, и возвращает точность всех предсказаний на предоставленном наборе данных. Чтобы вычислить предсказания сети во время валидации или после завершения обучения, используйте predict функция.

function accuracy = computeAccuracy(dlnet,mbq,classes)

    predictions = [];
    correctPredictions = [];
    
    shuffle(mbq);
    while hasdata(mbq)
        
        [dlXTest,dlTTest] = next(mbq);
        
        TTest = onehotdecode(dlTTest,classes,1)';
        
        dlYPred = predict(dlnet,dlXTest);
        YPred = onehotdecode(dlYPred,classes,1)';
        
        predictions = [predictions; YPred];
        correctPredictions = [correctPredictions; YPred == TTest];
    end
    
    predSum = sum(correctPredictions);
    accuracy = single(predSum./size(correctPredictions,1));

end

Функция мини-пакетной предварительной обработки

Функция preprocessMiniBatch предварительно обрабатывает данные с помощью следующих шагов:

  1. Извлеките данные изображения из входящего массива ячеек и соедините в числовой массив. Конкатенация данных изображения по четвертому измерению добавляет третье измерение к каждому изображению, которое используется в качестве размерности одинарного канала.

  2. Извлеките данные метки из входящих массивов ячеек и соедините в категориальный массив по второму измерению.

  3. Однократное кодирование категориальных меток в числовые массивы. Кодирование в первую размерность создает закодированный массив, который совпадает с формой выходного сигнала сети.

function [X,Y] = preprocessMiniBatch(XCell,YCell,classes)

    % Concatenate.
    X = cat(4,XCell{1:end});
    
    % Extract label data from cell and concatenate.
    Y = cat(2,YCell{1:end});
    
    % One-hot encode labels.
    Y = onehotencode(Y,1,'ClassNames',classes);

end

Федеративная функция усреднения

Функция federatedAveraging функция принимает настраиваемые параметры сетей на каждом работнике и коэффициент нормализации для каждого работника и возвращает усредненные настраиваемые параметры во всех сетях. Используйте средние настраиваемые параметры для обновления глобальной сети и сети на каждом работнике.

function learnables = federatedAveraging(workerLearnables,normalizationFactor)

    numWorkers = size(normalizationFactor,2);
    
    % Initialize container for averaged learnables with same size as existing
    % learnables. Use learnables of first worker network as an example.
    exampleLearnables = workerLearnables{1};
    learnables = cell(height(exampleLearnables),1);
    
    for i = 1:height(learnables)   
        learnables{i} = zeros(size(exampleLearnables{i}),'like',(exampleLearnables{i}));
    end
    
    % Add the normalized learnable parameters of all workers to
    % calculate average values.
    for i = 1:numWorkers
        tmp = workerLearnables{i};
        for values = 1:numel(learnables)
            learnables{values} = learnables{values} + normalizationFactor(i).*tmp{values};
        end
    end
    
end

Ссылки

[1] McMahan, H. Brendan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Agüera y Arcas. «Коммуникационно-эффективное обучение глубоких сетей на основе децентрализованных данных». Препринт, отправлен. Февраль 2017 года. https://arxiv.org/abs/1602.05629.

См. также

| | | | | |

Похожие темы