parfeval
Обучение нескольких Нейронных сетей для глубокого обученияВ этом примере показано, как использовать parfeval
выполнить зачистку параметра на глубине сетевой архитектуры для нейронной сети для глубокого обучения и извлечь данные во время обучения.
Обучение глубокому обучению часто занимает часы или дни, и поиск хорошей архитектуры может оказаться трудным. С помощью параллельных вычислений можно ускорить и автоматизировать поиск хороших моделей. Если у вас есть доступ к машине с несколькими графическими модулями (GPU), можно завершить этот пример на локальной копии набора данных с локальным параллельным пулом. Если вы хотите использовать больше ресурсов, можно расширить процесс глубокого обучения до облака. В этом примере показано, как использовать parfeval
для выполнения протягивания параметра на глубине сетевой архитектуры в кластере в облаке. Использование parfeval
позволяет вам обучаться в фоновом режиме, не блокируя MATLAB, и предоставляет опции, чтобы остановить раньше, если результаты удовлетворительны. Можно изменить скрипт, чтобы выполнить сдвиг параметра по любому другому параметру. Кроме того, этот пример показывает, как получить обратную связь от работников во время расчетов с помощью DataQueue
.
Прежде чем вы сможете запустить этот пример, вам нужно сконфигурировать кластер и загрузить свои данные в облако. В MATLAB можно создавать кластеры в облаке непосредственно с рабочего стола MATLAB. На вкладке «Вкладке Home», в меню Parallel, выберите Create and Manage Clusters. В Диспетчере профилей кластеров щелкните Создать облако. Также можно использовать MathWorks Cloud Center для создания и доступа к вычислительным кластерам. Дополнительные сведения см. в разделе Начало работы с облачным центром. В данном примере убедитесь, что кластер установлен по умолчанию на вкладке MATLAB Home, в Parallel > Select a Default Cluster. После этого загрузите свои данные в блок S3 Amazon и используйте их непосредственно из MATLAB. Этот пример использует копию CIFAR-10 набора данных, который уже хранится в Amazon S3. Инструкции см. в разделе Загрузка данных глубокого обучения в облако.
Загрузите наборы обучающих и тестовых данных из облака с помощью imageDatastore
. Разделите набор обучающих данных на наборы обучения и валидации и сохраните набор тестовых данных, чтобы протестировать лучшую сеть от сдвига параметра. В этом примере вы используете копию CIFAR-10 набора данных, хранящегося в Amazon S3. Чтобы убедиться, что работники имеют доступ к datastore в облаке, убедитесь, что переменные окружения для учетных данных AWS заданы правильно. Смотрите Загрузку данных глубокого обучения в облако.
imds = imageDatastore('s3://cifar10cloud/cifar10/train', ... 'IncludeSubfolders',true, ... 'LabelSource','foldernames'); imdsTest = imageDatastore('s3://cifar10cloud/cifar10/test', ... 'IncludeSubfolders',true, ... 'LabelSource','foldernames'); [imdsTrain,imdsValidation] = splitEachLabel(imds,0.9);
Обучите сеть с дополненными данными об изображениях путем создания augmentedImageDatastore
объект. Используйте случайные переводы и горизонтальные отражения. Увеличение количества данных помогает предотвратить сверхподбор кривой сети и запоминание точных деталей обучающих изображений.
imageSize = [32 32 3]; pixelRange = [-4 4]; imageAugmenter = imageDataAugmenter( ... 'RandXReflection',true, ... 'RandXTranslation',pixelRange, ... 'RandYTranslation',pixelRange); augmentedImdsTrain = augmentedImageDatastore(imageSize,imdsTrain, ... 'DataAugmentation',imageAugmenter, ... 'OutputSizeMode','randcrop');
Определите опции обучения. Установите размер мини-пакета и масштабируйте начальную скорость обучения линейно в соответствии с размером мини-пакета. Установите частоту валидации так, чтобы trainNetwork
проверяет сеть один раз в эпоху.
miniBatchSize = 128; initialLearnRate = 1e-1 * miniBatchSize/256; validationFrequency = floor(numel(imdsTrain.Labels)/miniBatchSize); options = trainingOptions('sgdm', ... 'MiniBatchSize',miniBatchSize, ... % Set the mini-batch size 'Verbose',false, ... % Do not send command line output. 'InitialLearnRate',initialLearnRate, ... % Set the scaled learning rate. 'L2Regularization',1e-10, ... 'MaxEpochs',30, ... 'Shuffle','every-epoch', ... 'ValidationData',imdsValidation, ... 'ValidationFrequency', validationFrequency);
Задайте глубины для сетевой архитектуры, на которых необходимо выполнить сдвиг параметра. Выполните параллельное тестирование параметров нескольких сетей одновременно с помощью parfeval
. Используйте цикл, чтобы выполнить итерацию через различные сетевые архитектуры в сдвиге. Создайте вспомогательную функцию createNetworkArchitecture
в конце скрипта, который принимает входной параметр, чтобы контролировать глубину сети и создает архитектуру для CIFAR-10. Использование parfeval
для разгрузки расчетов, выполненных trainNetwork
работнику в кластере. parfeval
возвращает будущую переменную для хранения обученных сетей и обучающей информации при выполнении расчетов.
netDepths = 1:4; for idx = 1:numel(netDepths) networksFuture(idx) = parfeval(@trainNetwork,2, ... augmentedImdsTrain,createNetworkArchitecture(netDepths(idx)),options); end
Starting parallel pool (parpool) using the 'MyCluster' profile ... Connected to the parallel pool (number of workers: 4).
parfeval
не блокирует MATLAB, что означает, что можно продолжить выполнение команд. В этом случае получайте обученные сети и их обучающую информацию при помощи fetchOutputs
на networksFuture
. The fetchOutputs
функция ожидает, пока будущие переменные не закончат.
[trainedNetworks,trainingInfo] = fetchOutputs(networksFuture);
Получите окончательные точности валидации сетей путем доступа к trainingInfo
структура.
accuracies = [trainingInfo.FinalValidationAccuracy]
accuracies = 1×4
72.5600 77.2600 79.4000 78.6800
Выберите лучшую сеть с точки зрения точности. Проверяйте его эффективность на соответствие тестовых данных набору.
[~, I] = max(accuracies); bestNetwork = trainedNetworks(I(1)); YPredicted = classify(bestNetwork,imdsTest); accuracy = sum(YPredicted == imdsTest.Labels)/numel(imdsTest.Labels)
accuracy = 0.7840
Вычислите матрицу неточностей для тестовых данных.
figure('Units','normalized','Position',[0.2 0.2 0.4 0.4]); confusionchart(imdsTest.Labels,YPredicted,'RowSummary','row-normalized','ColumnSummary','column-normalized');
Подготовка и инициализация графиков, показывающих процесс обучения каждого из работников. Использование animatedLine
для удобного способа показать меняющиеся данные.
f = figure; f.Visible = true; for i=1:4 subplot(2,2,i) xlabel('Iteration'); ylabel('Training accuracy'); lines(i) = animatedline; end
Отправка данных о процессе обучения от работников клиенту при помощи DataQueue
, а затем постройте график данных. Обновляйте графики каждый раз, когда работники отправляют обратную связь о процессе обучения при помощи afterEach
. Значение параметра opts
содержит информацию о работнике, итерации обучения и точности обучения.
D = parallel.pool.DataQueue; afterEach(D, @(opts) updatePlot(lines, opts{:}));
Задайте глубины для сетевой архитектуры, на которых можно выполнить сдвиг параметра, и выполните сдвиг параллельного параметра с помощью parfeval
. Разрешить работникам доступ к любой вспомогательной функции в этом скрипте путем добавления скрипта к текущему пулу в качестве вложенного файла. Определите выходную функцию в опциях обучения, чтобы отправить процесс обучения от работников клиенту. Опции обучения зависят от индекса работника и должны быть включены в for
цикл.
netDepths = 1:4; addAttachedFiles(gcp,mfilename); for idx = 1:numel(netDepths) miniBatchSize = 128; initialLearnRate = 1e-1 * miniBatchSize/256; % Scale the learning rate according to the mini-batch size. validationFrequency = floor(numel(imdsTrain.Labels)/miniBatchSize); options = trainingOptions('sgdm', ... 'OutputFcn',@(state) sendTrainingProgress(D,idx,state), ... % Set an output function to send intermediate results to the client. 'MiniBatchSize',miniBatchSize, ... % Set the corresponding MiniBatchSize in the sweep. 'Verbose',false, ... % Do not send command line output. 'InitialLearnRate',initialLearnRate, ... % Set the scaled learning rate. 'L2Regularization',1e-10, ... 'MaxEpochs',30, ... 'Shuffle','every-epoch', ... 'ValidationData',imdsValidation, ... 'ValidationFrequency', validationFrequency); networksFuture(idx) = parfeval(@trainNetwork,2, ... augmentedImdsTrain,createNetworkArchitecture(netDepths(idx)),options); end
parfeval
вызывает trainNetwork
на рабочем месте в кластере. Расчеты происходят на фоне, поэтому можно продолжить работу в MATLAB. Если вы хотите остановить parfeval
расчеты, можно вызвать cancel
о соответствующей переменной будущего. Для примера, если вы наблюдаете, что сеть не работает, можно отменить ее будущее. Когда вы это делаете, следующая поставленная в очередь будущая переменная начинает свои расчеты.
В этом случае извлеките обученные сети и их обучающую информацию путем вызова fetchOutputs
о будущих переменных.
[trainedNetworks,trainingInfo] = fetchOutputs(networksFuture);
Получите окончательную точность валидации для каждой сети.
accuracies = [trainingInfo.FinalValidationAccuracy]
accuracies = 1×4
72.9200 77.4800 76.9200 77.0400
Задайте сетевую архитектуру для CIFAR-10 набора данных с функцией и используйте входной параметр, чтобы настроить глубину сети. Чтобы упростить код, используйте сверточные блоки, которые свертывают вход. Слои объединения понижают пространственные размерности.
function layers = createNetworkArchitecture(netDepth) imageSize = [32 32 3]; netWidth = round(16/sqrt(netDepth)); % netWidth controls the number of filters in a convolutional block layers = [ imageInputLayer(imageSize) convolutionalBlock(netWidth,netDepth) maxPooling2dLayer(2,'Stride',2) convolutionalBlock(2*netWidth,netDepth) maxPooling2dLayer(2,'Stride',2) convolutionalBlock(4*netWidth,netDepth) averagePooling2dLayer(8) fullyConnectedLayer(10) softmaxLayer classificationLayer ]; end
Задайте функцию для создания сверточного блока в сетевой архитектуре.
function layers = convolutionalBlock(numFilters,numConvLayers) layers = [ convolution2dLayer(3,numFilters,'Padding','same') batchNormalizationLayer reluLayer ]; layers = repmat(layers,numConvLayers,1); end
Определите функцию для отправки процесса обучения клиенту через DataQueue
.
function sendTrainingProgress(D,idx,info) if info.State == "iteration" send(D,{idx,info.Iteration,info.TrainingAccuracy}); end end
Задайте функцию обновления, чтобы обновить графики, когда рабочий отправляет промежуточный результат.
function updatePlot(lines,idx,iter,acc) addpoints(lines(idx),iter,acc); drawnow limitrate nocallbacks end
imageDatastore
| trainingOptions
| trainNetwork
| afterEach
(Parallel Computing Toolbox) | parfeval
(Parallel Computing Toolbox)