В этом примере показано, как обучить сеть с помощью объединенного в федерацию изучения. Федеративное изучение является методом, который позволяет вам обучить сеть распределенным, децентрализованным способом [1].
Федеративное изучение позволяет вам обучать модель с помощью данных из других источников, не перемещая данные в центральное расположение, даже если отдельные источники данных не совпадают с полным распределением набора данных. Это известно как зависимые и тождественно распределенные (non-IID) данные. Федеративное изучение может быть особенно полезным, когда обучающие данные являются большими, или когда существуют опасения конфиденциальности по поводу передачи обучающих данных.
Вместо распределительных данных федеративный метод изучения обучает многоуровневые модели, каждого в том же месте как источник данных. Можно создать глобальную модель, которая извлекла уроки из всех источников данных путем периодического сбора и объединения настраиваемых параметров локально обученных моделей. Таким образом можно обучить глобальную модель, централизованно не обрабатывая обучающих данных.
Этот пример использование объединил обучение в федерацию обучить модель классификации в параллели с помощью высоко non-IID набор данных. Модель обучена с помощью набора данных цифр, который состоит из 10 000 рукописных изображений чисел от 0 до 9. Пример запускается в параллели с помощью 10 рабочих, каждая обработка изображения одной цифры. Путем усреднения настраиваемых параметров сетей после каждого раунда обучения модели на каждом рабочем улучшают производительность через все классы, никогда не обрабатывая данные других классов.
В то время как конфиденциальность данных является одним из приложений федеративного изучения, этот пример не имеет дело с деталями поддержания конфиденциальности данных и безопасности. Этот пример демонстрирует основной федеративный алгоритм обучения.
Создайте параллельный пул с тем же количеством рабочих как классы в наборе данных. В данном примере используйте локальный параллельный пул с 10 рабочими.
pool = parpool('local',10);
numWorkers = pool.NumWorkers;
Все данные, используемые в этом примере, первоначально хранимы в централизованном месте. Чтобы сделать эти данные высоко non-IID, необходимо распределить данные среди рабочих согласно классу. Чтобы создать валидацию и наборы тестовых данных, передайте фрагмент данных от рабочих клиенту. После того, как данные правильно настраиваются с обучающими данными отдельных классов на рабочих и тесте и данных о валидации всех классов на клиенте, нет никакой дальнейшей передачи данных во время обучения.
Задайте папку, содержащую данные изображения.
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos',... 'nndatasets','DigitDataset');
Распределите данные среди рабочих. Каждый рабочий получает изображения только одной цифры, такой, что рабочий 1 получает все изображения номера 0, рабочий 2 получает изображения номера 1 и т.д.
Изображения каждой цифры хранятся в разделять папке с именем той цифры. На каждом рабочем используйте fullfile
функция, чтобы задать путь к определенной папке класса. Затем создайте imageDatastore
это содержит все изображения той цифры. Затем используйте splitEachLabel
функционируйте к случайным образом отдельным 30% данных для использования в валидации и тестировании. Наконец, создайте augmentedImageDatastore
содержа обучающие данные.
inputSize = [28 28 1]; spmd digitDatasetPath = fullfile(digitDatasetPath,num2str(labindex - 1)); imds = imageDatastore(digitDatasetPath,... 'IncludeSubfolders',true,... 'LabelSource','foldernames'); [imdsTrain,imdsTestVal] = splitEachLabel(imds,0.7,"randomized"); augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain); end
Чтобы проверить производительность объединенной глобальной модели в течение и после обучения, создайте тест и наборы данных валидации, содержащие изображения от всех классов. Объедините тест и данные о валидации от каждого рабочего в один datastore. Затем разделите этот datastore в два хранилища данных, что каждый содержит 15% полных данных - один для того, чтобы проверить сеть во время обучения и другого для тестирования сети после обучения.
fileList = []; labelList = []; for i = 1:numWorkers tmp = imdsTestVal{i}; fileList = cat(1,fileList,tmp.Files); labelList = cat(1,labelList,tmp.Labels); end imdsGlobalTestVal = imageDatastore(fileList); imdsGlobalTestVal.Labels = labelList; [imdsGlobalTest,imdsGlobalVal] = splitEachLabel(imdsGlobalTestVal,0.5,"randomized"); augimdsGlobalTest = augmentedImageDatastore(inputSize(1:2),imdsGlobalTest); augimdsGlobalVal = augmentedImageDatastore(inputSize(1:2),imdsGlobalVal);
Данные теперь располагаются таким образом, что у каждого рабочего есть данные из единого класса, чтобы обучаться на, и клиент содержит валидацию и тестовые данные от всех классов.
Определите количество классов в наборе данных.
classes = categories(imdsGlobalTest.Labels); numClasses = numel(classes);
Определить сетевую архитектуру.
layers = [ imageInputLayer(inputSize,'Normalization','none','Name','input') convolution2dLayer(5,32,'Name','conv1') reluLayer('Name','relu1') maxPooling2dLayer(2,'Name','maxpool1') convolution2dLayer(5,64,'Name','conv2') reluLayer('Name','relu2') maxPooling2dLayer(2,'Name','maxpool2') fullyConnectedLayer(numClasses,'Name','fc') softmaxLayer('Name','softmax')];
Создайте объект dlnetwork из слоев.
dlnet = dlnetwork(layers)
dlnet = dlnetwork with properties: Layers: [9×1 nnet.cnn.layer.Layer] Connections: [8×2 table] Learnables: [6×3 table] State: [0×3 table] InputNames: {'input'} OutputNames: {'softmax'} Initialized: 1
Создайте функциональный modelGradients
, перечисленный в разделе Model Gradients Function этого примера, который берет dlnetwork
возразите и мини-пакет входных данных с соответствующими метками, и возвращает градиенты потери относительно настраиваемых параметров в сети и соответствующей потери.
Создайте функциональный federatedAveraging
, перечисленный в раздел Federated Averaging Function этого примера, который берет настраиваемые параметры сетей на каждом рабочем и коэффициенте нормализации для каждого рабочего, и возвращает усредненные настраиваемые параметры через все сети. Используйте средние настраиваемые параметры, чтобы обновить глобальную сеть и сеть на каждом рабочем.
Создайте функциональный computeAccuracy
, перечисленный в разделе Compute Accuracy Function этого примера, который берет dlnetwork
объект, набор данных в minibatchqueue
объект и список классов, и возвращают точность предсказаний через все наблюдения в наборе данных.
Во время обучения рабочие периодически передают свои сетевые настраиваемые параметры клиенту, так, чтобы клиент мог обновить глобальную модель. Обучение разделено на раунды. В конце каждого раунда обучения усреднены настраиваемые параметры, и глобальная модель обновляется. Модели рабочего затем заменяются новой глобальной моделью, и обучение продвигается рабочие.
Обучайтесь для 300 раундов с 5 эпохами на раунд. Обучение маленькому numer эпох на раунд гарантирует, что сети на рабочих не отличаются слишком далеко, прежде чем они будут усреднены.
numRounds = 300; numEpochsperRound = 5; miniBatchSize = 100;
Задайте опции для оптимизации SGD. Укажите, что начальная буква изучает уровень 0,001 и импульс 0.
learnRate = 0.001; momentum = 0;
Создайте указатель на функцию к пользовательскому мини-пакету, предварительно обрабатывающему функциональный preprocessMiniBatch
(заданный в разделе Mini-Batch Preprocessing Function этого примера).
На каждом рабочем найдите общее количество учебных наблюдений обработанным локально на том рабочем. Используйте этот номер, чтобы нормировать настраиваемые параметры на каждом рабочем, когда вы найдете средние настраиваемые параметры после каждой коммуникации вокруг. Это помогает сбалансировать среднее значение, если существует различие между объемом данных по каждому рабочему.
На каждом рабочем создайте minibatchqueue
возразите, что процессы и управляют мини-пакетами изображений во время обучения. Для каждого мини-пакета:
Предварительно обработайте данные с помощью пользовательского мини-пакета, предварительно обрабатывающего функциональный preprocessMiniBatch
преобразовывать метки в одногорячие закодированные переменные.
Формат данные изображения с размерностью маркирует 'SSCB'
(пространственный, пространственный, канал, пакет). По умолчанию, minibatchqueue
объект преобразует данные в dlarray
объекты с базовым типом single
. Не добавляйте формат в метки класса.
Обучайтесь на графическом процессоре, если вы доступны. По умолчанию, minibatchqueue
объект преобразует каждый выход в gpuArray
если графический процессор доступен. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).
preProcess = @(x,y)preprocessMiniBatch(x,y,classes); spmd sizeOfLocalDataset = augimdsTrain.NumObservations; mbq = minibatchqueue(augimdsTrain,... 'MiniBatchSize',miniBatchSize,... 'MiniBatchFcn',preProcess,... 'MiniBatchFormat',{'SSCB',''}); end
Создайте minibatchqueue
объект, который управляет данными о валидации, чтобы использовать во время обучения. Используйте те же настройки в качестве minibatchqueue
на каждом рабочем.
mbqGlobalVal = minibatchqueue(augimdsGlobalVal,... 'MiniBatchSize',miniBatchSize,... 'MiniBatchFcn',preProcess,... 'MiniBatchFormat',{'SSCB',''});
Инициализируйте график процесса обучения.
figure lineAccuracyTrain = animatedline('Color',[0.85 0.325 0.098]); ylim([0 inf]) xlabel("Communication rounds") ylabel("Accuracy (Global)") grid on
Инициализируйте скоростной параметр для решателя SGDM.
velocity = [];
Инициализируйте глобальную модель. Чтобы запуститься, глобальная модель имеет те же начальные параметры как нетренированная сеть на каждом рабочем.
globalModel = dlnet;
Обучите модель с помощью пользовательского учебного цикла. Для каждой коммуникации вокруг,
Обновите сети на рабочих с последней глобальной сетью.
Обучите сети на рабочих в течение пяти эпох.
Найдите средние параметры всех сетей с помощью federatedAveraging
функция.
Замените глобальные сетевые параметры на среднее значение.
Вычислите точность обновленной глобальной сети с помощью данных о валидации.
Отобразите прогресс обучения.
В течение каждой эпохи переставьте данные и цикл по мини-пакетам данных. Для каждого мини-пакета:
Оцените градиенты модели и потерю с помощью dlfeval
и modelGradients
функции.
Обновите локальные сетевые параметры с помощью sgdmupdate
функция.
start = tic; for rounds = 1:numRounds spmd % Send global updated parameters to each worker. dlnet.Learnables.Value = globalModel.Learnables.Value; % Loop over epochs. for epoch = 1:numEpochsperRound % Shuffle data. shuffle(mbq); % Loop over mini-batches. while hasdata(mbq) % Read mini-batch of data. [dlX,dlT] = next(mbq); % Evaluate the model gradients, state, and loss using dlfeval and the % modelGradients function and update the network state. [gradients,loss] = dlfeval(@modelGradients,dlnet,dlX,dlT); % Update the network parameters using the SGDM optimizer. [dlnet,velocity] = sgdmupdate(dlnet,gradients,velocity,learnRate,momentum); end end % Collect updated learnable parameters on each worker. workerLearnables = dlnet.Learnables.Value; end % Find normalization factors for each worker based on ratio of data % processed on that worker. sizeOfAllDatasets = sum([sizeOfLocalDataset{:}]); normalizationFactor = [sizeOfLocalDataset{:}]/sizeOfAllDatasets; % Update the global model with new learnable parameters, normalized and % averaged across all workers. globalModel.Learnables.Value = federatedAveraging(workerLearnables,normalizationFactor); % Calculate the accuracy of the global model. accuracy = computeAccuracy(globalModel,mbqGlobalVal,classes); % Display the training progress of the global model. D = duration(0,0,toc(start),'Format','hh:mm:ss'); addpoints(lineAccuracyTrain,rounds,double(accuracy)) title("Communication round: " + rounds + ", Elapsed: " + string(D)) drawnow end
После финального раунда обучения обновите сеть на каждом рабочем с итоговыми средними настраиваемыми параметрами. Это важно, если вы хотите продолжить использовать или обучать сеть на рабочих.
spmd dlnet.Learnables.Value = globalModel.Learnables.Value; end
Протестируйте точность классификации модели путем сравнения предсказаний на наборе тестов с истинными метками.
Создайте minibatchqueue
объект, который управляет тестовыми данными. Используйте те же настройки в качестве minibatchqueue
объекты используются во время обучения и валидации.
mbqGlobalTest = minibatchqueue(augimdsGlobalTest,... 'MiniBatchSize',miniBatchSize,... 'MiniBatchFcn',preProcess,... 'MiniBatchFormat','SSCB');
Используйте computePredictions
функция, чтобы вычислить предсказанные классы и вычислить точность предсказаний через все тестовые данные.
accuracy = computeAccuracy(globalModel,mbqGlobalTest,classes)
accuracy = single
0.9873
После того, как вы будете сделаны с вашими расчетами, можно удалить параллельный пул. gcp
функция возвращает текущий параллельный объект пула, таким образом, можно удалить пул.
delete(gcp('nocreate'));
modelGradients
функционируйте берет dlnetwork
объект dlnet
, мини-пакет входных данных dlX
с соответствием маркирует T
и возвращает градиенты потери относительно настраиваемых параметров в dlnet
и потеря. Чтобы вычислить градиенты автоматически, используйте dlgradient
функция. Чтобы вычислить предсказания сети во время обучения, используйте forward
функция.
function [gradients,loss] = modelGradients(dlnet,dlX,T) dlYPred = forward(dlnet,dlX); loss = crossentropy(dlYPred,T); gradients = dlgradient(loss,dlnet.Learnables); end
computePredictions
функционируйте берет dlnetwork
объект dlnet
, minibatchqueue
объект mbq
, и список классов, и возвращает точность всех предсказаний на обеспеченном наборе данных. Чтобы вычислить предсказания сети во время валидации или после, обучение закончено, используйте predict
функция.
function accuracy = computeAccuracy(dlnet,mbq,classes) correctPredictions = []; shuffle(mbq); while hasdata(mbq) [dlXTest,dlTTest] = next(mbq); TTest = onehotdecode(dlTTest,classes,1)'; dlYPred = predict(dlnet,dlXTest); YPred = onehotdecode(dlYPred,classes,1)'; correctPredictions = [correctPredictions; YPred == TTest]; end predSum = sum(correctPredictions); accuracy = single(predSum./size(correctPredictions,1)); end
Функция preprocessMiniBatch предварительно обрабатывает данные с помощью следующих шагов:
Извлеките данные изображения из массива входящей ячейки и конкатенируйте в числовой массив. Конкатенация данных изображения по четвертой размерности добавляет третью размерность в каждое изображение, чтобы использоваться в качестве одноэлементной размерности канала.
Извлеките данные о метке из массивов входящей ячейки и конкатенируйте в категориальный массив вдоль второго измерения.
Одногорячий кодируют категориальные метки в числовые массивы. Кодирование в первую размерность производит закодированный массив, который совпадает с формой сетевого выхода.
function [X,Y] = preprocessMiniBatch(XCell,YCell,classes) % Concatenate. X = cat(4,XCell{1:end}); % Extract label data from cell and concatenate. Y = cat(2,YCell{1:end}); % One-hot encode labels. Y = onehotencode(Y,1,'ClassNames',classes); end
Функциональный federatedAveraging
функционируйте берет настраиваемые параметры сетей на каждом рабочем и коэффициенте нормализации для каждого рабочего, и возвращает усредненные настраиваемые параметры через все сети. Используйте средние настраиваемые параметры, чтобы обновить глобальную сеть и сеть на каждом рабочем.
function learnables = federatedAveraging(workerLearnables,normalizationFactor) numWorkers = size(normalizationFactor,2); % Initialize container for averaged learnables with same size as existing % learnables. Use learnables of first worker network as an example. exampleLearnables = workerLearnables{1}; learnables = cell(height(exampleLearnables),1); for i = 1:height(learnables) learnables{i} = zeros(size(exampleLearnables{i}),'like',(exampleLearnables{i})); end % Add the normalized learnable parameters of all workers to % calculate average values. for i = 1:numWorkers tmp = workerLearnables{i}; for values = 1:numel(learnables) learnables{values} = learnables{values} + normalizationFactor(i).*tmp{values}; end end end
dlarray
| dlnetwork
| sgdmupdate
| dlupdate
| dlfeval
| dlgradient
| minibatchqueue