parfeval Обучение нескольким сетям глубокого обученияВ этом примере показано, как использовать parfeval для выполнения сдвига параметров на глубине сетевой архитектуры для сети глубокого обучения и извлечения данных во время обучения.
Глубокое обучение часто занимает часы или дни, и поиск хороших архитектур может быть затруднен. Параллельные вычисления позволяют ускорить и автоматизировать поиск хороших моделей. Если у вас есть доступ к машине с несколькими графическими процессорами (GPU), вы можете завершить этот пример на локальной копии набора данных с локальным параллельным пулом. Если вы хотите использовать больше ресурсов, вы можете масштабировать углубленное обучение в облаке. В этом примере показано, как использовать parfeval для выполнения сдвига параметров на глубине сетевой архитектуры в кластере в облаке. Используя parfeval позволяет тренироваться в фоновом режиме, не блокируя MATLAB, и предоставляет опции для ранней остановки, если результаты удовлетворительны. Можно изменить сценарий, чтобы выполнить сдвиг параметра для любого другого параметра. Кроме того, в этом примере показано, как получить обратную связь от работников во время вычислений с помощью DataQueue.
Перед запуском этого примера необходимо настроить кластер и загрузить данные в облако. В MATLAB кластеры в облаке можно создавать непосредственно с рабочего стола MATLAB. На вкладке Главная в меню Параллельный выберите Создать кластеры и управление ими. В диспетчере профилей кластера щелкните Создать облачный кластер. Можно также использовать Cloud Center MathWorks для создания вычислительных кластеров и доступа к ним. Дополнительные сведения см. в разделе Начало работы с облачным центром. В этом примере убедитесь, что кластер установлен по умолчанию на вкладке Главная страница MATLAB (Parallel > Select a Default Cluster). После этого загрузите данные в ведро Amazon S3 и используйте их непосредственно из MATLAB. В этом примере используется копия набора данных CIFAR-10, который уже хранится в Amazon S3. Инструкции см. в разделе Загрузка данных глубокого обучения в облако (инструментарий глубокого обучения).
Загрузка обучающих и тестовых наборов данных из облака с помощью imageDatastore. Разбейте набор обучающих данных на наборы обучающих и валидационных данных и сохраните набор тестовых данных, чтобы проверить лучшую сеть из параметров sweep. В этом примере используется копия набора данных CIFAR-10, хранящегося в Amazon S3. Чтобы гарантировать, что работники имеют доступ к хранилищу данных в облаке, убедитесь, что переменные среды для учетных данных AWS заданы правильно. См. раздел Загрузка данных глубокого обучения в облако (инструментарий глубокого обучения).
imds = imageDatastore('s3://cifar10cloud/cifar10/train', ... 'IncludeSubfolders',true, ... 'LabelSource','foldernames'); imdsTest = imageDatastore('s3://cifar10cloud/cifar10/test', ... 'IncludeSubfolders',true, ... 'LabelSource','foldernames'); [imdsTrain,imdsValidation] = splitEachLabel(imds,0.9);
Обучение сети дополненным данным изображения путем создания augmentedImageDatastore объект. Используйте случайные переводы и горизонтальные отражения. Увеличение объема данных помогает предотвратить переоборудование сети и запоминание точных деталей обучающих изображений.
imageSize = [32 32 3]; pixelRange = [-4 4]; imageAugmenter = imageDataAugmenter( ... 'RandXReflection',true, ... 'RandXTranslation',pixelRange, ... 'RandYTranslation',pixelRange); augmentedImdsTrain = augmentedImageDatastore(imageSize,imdsTrain, ... 'DataAugmentation',imageAugmenter, ... 'OutputSizeMode','randcrop');
Определите варианты обучения. Установите размер мини-партии и линейно масштабируйте начальную скорость обучения в соответствии с размером мини-партии. Установите частоту проверки таким образом, чтобы trainNetwork проверяет сеть один раз в эпоху.
miniBatchSize = 128; initialLearnRate = 1e-1 * miniBatchSize/256; validationFrequency = floor(numel(imdsTrain.Labels)/miniBatchSize); options = trainingOptions('sgdm', ... 'MiniBatchSize',miniBatchSize, ... % Set the mini-batch size 'Verbose',false, ... % Do not send command line output. 'InitialLearnRate',initialLearnRate, ... % Set the scaled learning rate. 'L2Regularization',1e-10, ... 'MaxEpochs',30, ... 'Shuffle','every-epoch', ... 'ValidationData',imdsValidation, ... 'ValidationFrequency', validationFrequency);
Укажите глубины сетевой архитектуры, на которых будет выполняться сдвиг параметров. Выполнение параллельного контроля параметров в нескольких сетях одновременно с использованием parfeval. Используйте цикл для итерации различных сетевых архитектур в развертке. Создание вспомогательной функции createNetworkArchitecture в конце сценария, который принимает входной аргумент для управления глубиной сети и создает архитектуру для CIFAR-10. Использовать parfeval для разгрузки вычислений, выполняемых trainNetwork к работнику в кластере. parfeval возвращает будущую переменную для хранения обученных сетей и обучающей информации при выполнении вычислений.
netDepths = 1:4; for idx = 1:numel(netDepths) networksFuture(idx) = parfeval(@trainNetwork,2, ... augmentedImdsTrain,createNetworkArchitecture(netDepths(idx)),options); end
Starting parallel pool (parpool) using the 'MyCluster' profile ... Connected to the parallel pool (number of workers: 4).
parfeval не блокирует MATLAB, что означает возможность продолжения выполнения команд. В этом случае получите обученные сети и их обучающую информацию с помощью fetchOutputs на networksFuture. fetchOutputs функция ожидает завершения будущих переменных.
[trainedNetworks,trainingInfo] = fetchOutputs(networksFuture);
Получение окончательной точности проверки сетей путем доступа к trainingInfo структура.
accuracies = [trainingInfo.FinalValidationAccuracy]
accuracies = 1×4
72.5600 77.2600 79.4000 78.6800
Выберите лучшую сеть с точки зрения точности. Проверьте его производительность по набору тестовых данных.
[~, I] = max(accuracies); bestNetwork = trainedNetworks(I(1)); YPredicted = classify(bestNetwork,imdsTest); accuracy = sum(YPredicted == imdsTest.Labels)/numel(imdsTest.Labels)
accuracy = 0.7840
Вычислите матрицу путаницы для тестовых данных.
figure('Units','normalized','Position',[0.2 0.2 0.4 0.4]); confusionchart(imdsTest.Labels,YPredicted,'RowSummary','row-normalized','ColumnSummary','column-normalized');

Подготовьте и инициализируйте графики, показывающие ход обучения каждого из работников. Использовать animatedLine удобный способ отображения изменяющихся данных.
f = figure; f.Visible = true; for i=1:4 subplot(2,2,i) xlabel('Iteration'); ylabel('Training accuracy'); lines(i) = animatedline; end

Отправка данных о ходе обучения от работников клиенту с помощью DataQueue, а затем постройте график данных. Обновлять графики каждый раз, когда работники отправляют отзыв о ходе обучения с помощью afterEach. Параметр opts содержит сведения о работнике, итерации обучения и точности обучения.
D = parallel.pool.DataQueue;
afterEach(D, @(opts) updatePlot(lines, opts{:}));Укажите глубины сетевой архитектуры, на которых будет выполняться сдвиг параметров, и выполните сдвиг параллельных параметров с помощью parfeval. Разрешить работникам доступ к любой функции помощника в этом сценарии, добавив сценарий в текущий пул в качестве присоединенного файла. Определите функцию вывода в параметрах обучения для отправки результатов обучения от работников клиенту. Варианты обучения зависят от индекса работника и должны быть включены в for цикл.
netDepths = 1:4; addAttachedFiles(gcp,mfilename); for idx = 1:numel(netDepths) miniBatchSize = 128; initialLearnRate = 1e-1 * miniBatchSize/256; % Scale the learning rate according to the mini-batch size. validationFrequency = floor(numel(imdsTrain.Labels)/miniBatchSize); options = trainingOptions('sgdm', ... 'OutputFcn',@(state) sendTrainingProgress(D,idx,state), ... % Set an output function to send intermediate results to the client. 'MiniBatchSize',miniBatchSize, ... % Set the corresponding MiniBatchSize in the sweep. 'Verbose',false, ... % Do not send command line output. 'InitialLearnRate',initialLearnRate, ... % Set the scaled learning rate. 'L2Regularization',1e-10, ... 'MaxEpochs',30, ... 'Shuffle','every-epoch', ... 'ValidationData',imdsValidation, ... 'ValidationFrequency', validationFrequency); networksFuture(idx) = parfeval(@trainNetwork,2, ... augmentedImdsTrain,createNetworkArchitecture(netDepths(idx)),options); end
parfeval призывает trainNetwork на работнике в кластере. Вычисления выполняются в фоновом режиме, что позволяет продолжить работу в MATLAB. Если вы хотите остановить parfeval вычисление, вы можете вызвать cancel на соответствующей будущей переменной. Например, если вы видите, что сеть недоработана, вы можете отменить ее будущее. При этом следующая будущая переменная, находящаяся в очереди, начинает свои вычисления.
В этом случае извлеките обученные сети и их обучающую информацию путем вызова fetchOutputs о будущих переменных.
[trainedNetworks,trainingInfo] = fetchOutputs(networksFuture);

Получите окончательную точность проверки для каждой сети.
accuracies = [trainingInfo.FinalValidationAccuracy]
accuracies = 1×4
72.9200 77.4800 76.9200 77.0400
Определите архитектуру сети для CIFAR-10 набора данных с помощью функции и используйте входной аргумент для настройки глубины сети. Чтобы упростить код, используйте сверточные блоки, которые свернут вход. Слои объединения уменьшают пространственные размеры.
function layers = createNetworkArchitecture(netDepth) imageSize = [32 32 3]; netWidth = round(16/sqrt(netDepth)); % netWidth controls the number of filters in a convolutional block layers = [ imageInputLayer(imageSize) convolutionalBlock(netWidth,netDepth) maxPooling2dLayer(2,'Stride',2) convolutionalBlock(2*netWidth,netDepth) maxPooling2dLayer(2,'Stride',2) convolutionalBlock(4*netWidth,netDepth) averagePooling2dLayer(8) fullyConnectedLayer(10) softmaxLayer classificationLayer ]; end
Определите функцию для создания сверточного блока в сетевой архитектуре.
function layers = convolutionalBlock(numFilters,numConvLayers) layers = [ convolution2dLayer(3,numFilters,'Padding','same') batchNormalizationLayer reluLayer ]; layers = repmat(layers,numConvLayers,1); end
Определение функции для отправки данных о ходе обучения клиенту через DataQueue.
function sendTrainingProgress(D,idx,info) if info.State == "iteration" send(D,{idx,info.Iteration,info.TrainingAccuracy}); end end
Определите функцию обновления для обновления графиков при отправке работником промежуточного результата.
function updatePlot(lines,idx,iter,acc) addpoints(lines(idx),iter,acc); drawnow limitrate nocallbacks end
afterEach | imageDatastore | parfeval | trainingOptions (инструментарий для глубокого обучения) | trainNetwork (инструментарий для глубокого обучения)