В этом примере показано, как обучить сеть с помощью федеративного обучения. Федеративное обучение является методом, который позволяет вам обучать сеть распределенным, децентрализованным способом [1].
Федеративное обучение позволяет вам обучать модель с помощью данных из разных источников, не перемещая данные в центральное место, даже если отдельные источники данных не совпадают с общим распределением набора данных. Это известно как несамостоятельные и идентично распределенные (не-IID) данные. Федеративное обучение может быть особенно полезным, когда обучающие данные большие, или когда существуют проблемы конфиденциальности при передаче обучающих данных.
Вместо распределения данных метод федеративного обучения обучает несколько моделей, каждая из которых находится в том же месте, что и источник данных. Можно создать глобальную модель, которая извлечена из всех источников данных путем периодического сбора и объединения настраиваемых параметров локально обученных моделей. Таким образом, можно обучить глобальную модель, не обрабатывая централизованно какие-либо обучающие данные.
Этот пример использует федеративное обучение, чтобы параллельно обучить классификационную модель с использованием набора данных, не являющегося IID. Модель обучается с помощью набора данных цифр, который состоит из 10000 рукописных изображений цифр от 0 до 9. Пример запускается параллельно с использованием 10 рабочих мест, каждый обрабатывает изображения с одной цифрой. Путем усреднения настраиваемых параметров сетей после каждого цикла обучения, модели на каждом работнике улучшают эффективность во всех классах, никогда не обрабатывая данные других классов.
Хотя конфиденциальность данных является одним из приложений федеративного обучения, этот пример не касается деталей поддержания конфиденциальности и безопасности данных. Этот пример демонстрирует основной алгоритм федеративного обучения.
Создайте параллельный пул с таким же количеством работников, как и классы в наборе данных. В данном примере используйте локальный параллельный пул с 10 рабочими.
pool = parpool('local',10);
numWorkers = pool.NumWorkers;
Все данные, используемые в этом примере, первоначально хранятся в централизованном местоположении. Чтобы сделать эти данные сильно отличными от IID, необходимо распределить данные между работниками в соответствии с классом. Чтобы создать наборы валидации и тестовых данных, передайте фрагментов данных от работников клиенту. После правильной настройки данных, с обучающими данными отдельных классов по работникам и тестовыми и валидационными данными всех классов на клиенте, дальнейшая передача данных во время обучения отсутствует.
Укажите папку, содержащую данные изображения.
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos',... 'nndatasets','DigitDataset');
Распределите данные между работниками. Каждый рабочий получает изображения только одной цифры, так что рабочий 1 получает все изображения числа 0, рабочий 2 получает изображения числа 1 и т.д.
Изображения каждой цифры хранятся в отдельной папке с именем этой цифры. На каждом работнике используйте fullfile
функция для задания пути к определенной папке класса. Затем создайте imageDatastore
который содержит все изображения этой цифры. Далее используйте splitEachLabel
функция для случайного разделения 30% данных для использования при валидации и проверке. Наконец, создайте augmentedImageDatastore
содержащие обучающие данные.
inputSize = [28 28 1]; spmd digitDatasetPath = fullfile(digitDatasetPath,num2str(labindex - 1)); imds = imageDatastore(digitDatasetPath,... 'IncludeSubfolders',true,... 'LabelSource','foldernames'); [imdsTrain,imdsTestVal] = splitEachLabel(imds,0.7,"randomized"); augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain); end
Чтобы протестировать эффективность комбинированной глобальной модели во время и после обучения, создайте наборы данных тестирования и валидации, содержащие изображения из всех классов. Объедините данные тестирования и валидации от каждого работника в один datastore. Затем разделите этот datastore на два хранилища данных, каждый из которых содержит 15% от общих данных - один для проверки сети во время обучения, а другой для тестирования сети после обучения.
fileList = []; labelList = []; for i = 1:numWorkers tmp = imdsTestVal{i}; fileList = cat(1,fileList,tmp.Files); labelList = cat(1,labelList,tmp.Labels); end imdsGlobalTestVal = imageDatastore(fileList); imdsGlobalTestVal.Labels = labelList; [imdsGlobalTest,imdsGlobalVal] = splitEachLabel(imdsGlobalTestVal,0.5,"randomized"); augimdsGlobalTest = augmentedImageDatastore(inputSize(1:2),imdsGlobalTest); augimdsGlobalVal = augmentedImageDatastore(inputSize(1:2),imdsGlobalVal);
Теперь данные организованы таким образом, что каждый рабочий процесс имеет данные из одного класса для обучения, а клиент проводит валидацию и тестовые данные от всех классов.
Определите количество классов в наборе данных.
classes = categories(imdsGlobalTest.Labels); numClasses = numel(classes);
Определите сетевую архитектуру.
layers = [ imageInputLayer(inputSize,'Normalization','none','Name','input') convolution2dLayer(5,32,'Name','conv1') reluLayer('Name','relu1') maxPooling2dLayer(2,'Name','maxpool1') convolution2dLayer(5,64,'Name','conv2') reluLayer('Name','relu2') maxPooling2dLayer(2,'Name','maxpool2') fullyConnectedLayer(numClasses,'Name','fc') softmaxLayer('Name','softmax')]; lgraph = layerGraph(layers);
Создайте объект dlnetwork из графика слоев.
dlnet = dlnetwork(lgraph)
dlnet = dlnetwork with properties: Layers: [9×1 nnet.cnn.layer.Layer] Connections: [8×2 table] Learnables: [6×3 table] State: [0×3 table] InputNames: {'input'} OutputNames: {'softmax'} Initialized: 1
Создайте функцию modelGradients
, перечисленный в разделе Model Gradients Function этого примера, который принимает dlnetwork
объект и мини-пакет входных данных с соответствующими метками и возвращает градиенты потерь относительно настраиваемых параметров в сети и соответствующих потерь.
Создайте функцию federatedAveraging
, перечисленный в разделе Federated Averaging Function этого примера, который берёт настраиваемые параметры сетей на каждом работнике и коэффициент нормализации для каждого работника и возвращает усредненные настраиваемые параметры во всех сетях. Используйте средние настраиваемые параметры для обновления глобальной сети и сети на каждом работнике.
Создайте функцию computeAccuracy
, перечисленный в разделе «Функция вычислительной точности» этого примера, который принимает dlnetwork
объект, набор данных внутри minibatchqueue
объект, и список классов, и возвращает точность предсказаний для всех наблюдений в наборе данных.
Во время обучения работники периодически передают клиенту свои настраиваемые параметры сети, чтобы клиент мог обновить глобальную модель. Обучение разделено на раунды. В конце каждого раунда обучения усредняются усвояемые параметры и обновляется глобальная модель. Модели рабочих затем заменяются новой глобальной моделью, и обучение продолжается для работников.
Тренируйтесь в течение 300 раундов, с 5 эпохами в раундах. Обучение небольшому числу эпох за раунд гарантирует, что сети на рабочих не будут различаться слишком далеко, прежде чем они будут усреднены.
numRounds = 300; numEpochsperRound = 5; miniBatchSize = 100;
Задайте опции для оптимизации SGD. Задайте начальную скорость обучения 0,001 и импульс 0.
learnRate = 0.001; momentum = 0;
Создайте указатель на функцию для пользовательской функции мини-пакетной предварительной обработки preprocessMiniBatch
(определено в разделе функции мини-пакетной предварительной обработки этого примера).
На каждом работнике найдите общее количество обучающих наблюдений, обработанных локально на этом работнике. Используйте это число для нормализации настраиваемых параметров на каждом работнике, когда вы находите средние настраиваемые параметры после каждого раунда communicaton. Это помогает сбалансировать среднее значение, если существует различие между объемом данных по каждому работнику.
На каждом рабочем месте создайте minibatchqueue
объект, который обрабатывает и управляет мини-пакетами изображений во время обучения. Для каждого мини-пакета:
Предварительно обработайте данные с помощью пользовательской функции мини-пакетной предварительной обработки preprocessMiniBatch
для преобразования меток в переменные с одним «горячим» кодированием.
Форматируйте данные изображения с помощью меток размерностей 'SSCB'
(пространственный, пространственный, канальный, пакетный). По умолчанию в minibatchqueue
объект преобразует данные в dlarray
объекты с базовым типом single
. Не добавляйте формат к меткам классов.
Обучите на графическом процессоре, если он доступен. По умолчанию в minibatchqueue
объект преобразует каждый выход в gpuArray
при наличии графический процессор. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox).
preProcess = @(x,y)preprocessMiniBatch(x,y,classes); spmd sizeOfLocalDataset = augimdsTrain.NumObservations; mbq = minibatchqueue(augimdsTrain,... 'MiniBatchSize',miniBatchSize,... 'MiniBatchFcn',preProcess,... 'MiniBatchFormat',{'SSCB',''}); end
Создайте minibatchqueue
объект, который управляет данными валидации для использования во время обучения. Используйте те же настройки, что и minibatchqueue
на каждого работника.
mbqGlobalVal = minibatchqueue(augimdsGlobalVal,... 'MiniBatchSize',miniBatchSize,... 'MiniBatchFcn',preProcess,... 'MiniBatchFormat',{'SSCB',''});
Инициализируйте график процесса обучения.
figure lineAccuracyTrain = animatedline('Color',[0.85 0.325 0.098]); ylim([0 inf]) xlabel("Communication rounds") ylabel("Accuracy (Global)") grid on
Инициализируйте параметр скорости для решателя SGDM.
velocity = [];
Инициализируйте глобальную модель. Чтобы начать, глобальная модель имеет те же начальные параметры, что и необученная сеть на каждом рабочем.
globalModel = dlnet;
Обучите модель с помощью пользовательского цикла обучения. Для каждого раунда связи,
Обновляйте сети рабочих с помощью последней глобальной сети.
Обучите сети на рабочих в течение пяти эпох.
Найдите средние параметры всех сетей, использующих federatedAveraging
функция.
Замените параметры глобальной сети средним значением.
Вычислите точность обновленной глобальной сети с помощью данных валидации.
Отображение процесса обучения.
Для каждой эпохи перетасуйте данные и закольцовывайте по мини-пакетам данных. Для каждого мини-пакета:
Оцените градиенты модели и потери с помощью dlfeval
и modelGradients
функций.
Обновите параметры локальной сети с помощью sgdmupdate
функция.
start = tic; for rounds = 1:numRounds spmd % Send global updated parameters to each worker. dlnet.Learnables.Value = globalModel.Learnables.Value; % Loop over epochs. for epoch = 1:numEpochsperRound % Shuffle data. shuffle(mbq); % Loop over mini-batches. while hasdata(mbq) % Read mini-batch of data. [dlX,dlT] = next(mbq); % Evaluate the model gradients, state, and loss using dlfeval and the % modelGradients function and update the network state. [gradients,loss] = dlfeval(@modelGradients,dlnet,dlX,dlT); % Update the network parameters using the SGDM optimizer. [dlnet,velocity] = sgdmupdate(dlnet,gradients,velocity,learnRate,momentum); end end % Collect updated learnable parameters on each worker. workerLearnables = dlnet.Learnables.Value; end % Find normalization factors for each worker based on ratio of data % processed on that worker. sizeOfAllDatasets = sum([sizeOfLocalDataset{:}]); normalizationFactor = [sizeOfLocalDataset{:}]/sizeOfAllDatasets; % Update the global model with new learnable parameters, normalized and % averaged across all workers. globalModel.Learnables.Value = federatedAveraging(workerLearnables,normalizationFactor); % Calculate the accuracy of the global model. accuracy = computeAccuracy(globalModel,mbqGlobalVal,classes); % Display the training progress of the global model. D = duration(0,0,toc(start),'Format','hh:mm:ss'); addpoints(lineAccuracyTrain,rounds,double(accuracy)) title("Communication round: " + rounds + ", Elapsed: " + string(D)) drawnow end
После заключительного раунда обучения обновляйте сеть на каждом работнике с конечными средними настраиваемыми параметрами. Это важно, если вы хотите продолжать использовать или обучать сеть в рабочих местах.
spmd dlnet.Learnables.Value = globalModel.Learnables.Value; end
Протестируйте классификационную точность модели путем сравнения предсказаний на тестовом наборе с истинными метками.
Создайте minibatchqueue
объект, который управляет тестовыми данными. Используйте те же настройки, что и minibatchqueue
объекты, используемые во время обучения и валидации.
mbqGlobalTest = minibatchqueue(augimdsGlobalTest,... 'MiniBatchSize',miniBatchSize,... 'MiniBatchFcn',preProcess,... 'MiniBatchFormat','SSCB');
Используйте computePredictions
функция для вычисления предсказанных классов и вычисления точности предсказаний по всем тестовым данным.
accuracy = computeAccuracy(globalModel,mbqGlobalTest,classes)
accuracy = single
0.9873
После завершения расчетов можно удалить параллельный пул. The gcp
функция возвращает текущий объект параллельного пула, чтобы можно было удалить пул.
delete(gcp('nocreate'));
The modelGradients
функция принимает dlnetwork
dlnet объекта
мини-пакет входных данных dlX
с соответствующими метками T
и возвращает градиенты потерь относительно настраиваемых параметров в dlnet
и потеря. Чтобы вычислить градиенты автоматически, используйте dlgradient
функция. Чтобы вычислить предсказания сети во время обучения, используйте forward
функция.
function [gradients,loss] = modelGradients(dlnet,dlX,T) dlYPred = forward(dlnet,dlX); loss = crossentropy(dlYPred,T); gradients = dlgradient(loss,dlnet.Learnables); end
The computePredictions
функция принимает dlnetwork
dlnet объекта
, а minibatchqueue
mbq объекта
, и список классов, и возвращает точность всех предсказаний на предоставленном наборе данных. Чтобы вычислить предсказания сети во время валидации или после завершения обучения, используйте predict
функция.
function accuracy = computeAccuracy(dlnet,mbq,classes) predictions = []; correctPredictions = []; shuffle(mbq); while hasdata(mbq) [dlXTest,dlTTest] = next(mbq); TTest = onehotdecode(dlTTest,classes,1)'; dlYPred = predict(dlnet,dlXTest); YPred = onehotdecode(dlYPred,classes,1)'; predictions = [predictions; YPred]; correctPredictions = [correctPredictions; YPred == TTest]; end predSum = sum(correctPredictions); accuracy = single(predSum./size(correctPredictions,1)); end
Функция preprocessMiniBatch предварительно обрабатывает данные с помощью следующих шагов:
Извлеките данные изображения из входящего массива ячеек и соедините в числовой массив. Конкатенация данных изображения по четвертому измерению добавляет третье измерение к каждому изображению, которое используется в качестве размерности одинарного канала.
Извлеките данные метки из входящих массивов ячеек и соедините в категориальный массив по второму измерению.
Однократное кодирование категориальных меток в числовые массивы. Кодирование в первую размерность создает закодированный массив, который совпадает с формой выходного сигнала сети.
function [X,Y] = preprocessMiniBatch(XCell,YCell,classes) % Concatenate. X = cat(4,XCell{1:end}); % Extract label data from cell and concatenate. Y = cat(2,YCell{1:end}); % One-hot encode labels. Y = onehotencode(Y,1,'ClassNames',classes); end
Функция federatedAveraging
функция принимает настраиваемые параметры сетей на каждом работнике и коэффициент нормализации для каждого работника и возвращает усредненные настраиваемые параметры во всех сетях. Используйте средние настраиваемые параметры для обновления глобальной сети и сети на каждом работнике.
function learnables = federatedAveraging(workerLearnables,normalizationFactor) numWorkers = size(normalizationFactor,2); % Initialize container for averaged learnables with same size as existing % learnables. Use learnables of first worker network as an example. exampleLearnables = workerLearnables{1}; learnables = cell(height(exampleLearnables),1); for i = 1:height(learnables) learnables{i} = zeros(size(exampleLearnables{i}),'like',(exampleLearnables{i})); end % Add the normalized learnable parameters of all workers to % calculate average values. for i = 1:numWorkers tmp = workerLearnables{i}; for values = 1:numel(learnables) learnables{values} = learnables{values} + normalizationFactor(i).*tmp{values}; end end end
dlarray
| dlfeval
| dlgradient
| dlnetwork
| dlupdate
| minibatchqueue
| sgdmupdate