forward

Вычислите нейронную сеть для глубокого обучения выход для обучения

Описание

Некоторые слои глубокого обучения ведут себя по-другому во время обучения и заключают (предсказание). Например, во время обучения, слои уволенного случайным образом обнуляют входные элементы, чтобы помочь предотвратить сверхподбор кривой, но во время вывода, слои уволенного не изменяют вход.

Чтобы вычислить сетевые выходные параметры для обучения, используйте forward функция. Чтобы вычислить сетевые выходные параметры для вывода, используйте predict функция.

пример

dlY = forward(dlnet,dlX) возвращает сетевой выход dlY во время обучения, учитывая входные данные dlX.

dlY = forward(dlnet,dlX1,...,dlXM) возвращает сетевой выход dlY во время обучения, учитывая M входные параметры dlX1, ...,dlXM и сеть dlnet это имеет M входные параметры и один выход.

[dlY1,...,dlYN] = forward(___) возвращает N выходные параметры dlY1, …, dlYN во время обучения сетям, которые имеют N выходные параметры с помощью любого из предыдущих синтаксисов.

[dlY1,...,dlYK] = forward(___,'Outputs',layerNames) возвращает выходные параметры dlY1, …, dlYK во время обучения заданным слоям с помощью любого из предыдущих синтаксисов.

[___] = forward(___,'Acceleration',acceleration) также задает оптимизацию эффективности, чтобы использовать во время обучения, в дополнение к входным параметрам в предыдущих синтаксисах.

[___,state] = forward(___) также возвращает обновленное сетевое состояние.

Примеры

свернуть все

В этом примере показано, как обучить сеть, которая классифицирует рукописные цифры с пользовательским расписанием скорости обучения.

Если trainingOptions не предоставляет возможности, в которых вы нуждаетесь (например, пользовательское расписание скорости обучения), затем можно задать собственный учебный цикл с помощью автоматического дифференцирования.

Этот пример обучает сеть, чтобы классифицировать рукописные цифры с основанным на времени расписанием скорости обучения затухания: для каждой итерации решатель использует скорость обучения, данную ρt=ρ01+kt, где t является номером итерации, ρ0 начальная скорость обучения, и k является затуханием.

Загрузите обучающие данные

Загрузите данные о цифрах как datastore изображений с помощью imageDatastore функционируйте и задайте папку, содержащую данные изображения.

dataFolder = fullfile(toolboxdir('nnet'),'nndemos','nndatasets','DigitDataset');
imds = imageDatastore(dataFolder, ...
    'IncludeSubfolders',true, ....
    'LabelSource','foldernames');

Разделите данные в наборы обучения и валидации. Отложите 10% данных для валидации с помощью splitEachLabel функция.

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9,'randomize');

Сеть, используемая в этом примере, требует входных изображений размера 28 28 1. Чтобы автоматически изменить размер учебных изображений, используйте увеличенный datastore изображений. Задайте дополнительные операции увеличения, чтобы выполнить на учебных изображениях: случайным образом переведите изображения до 5 пикселей в горизонтальных и вертикальных осях. Увеличение данных помогает препятствовать тому, чтобы сеть сверхсоответствовала и запомнила точные детали учебных изображений.

inputSize = [28 28 1];
pixelRange = [-5 5];
imageAugmenter = imageDataAugmenter( ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,'DataAugmentation',imageAugmenter);

Чтобы автоматически изменить размер изображений валидации, не выполняя дальнейшее увеличение данных, используйте увеличенный datastore изображений, не задавая дополнительных операций предварительной обработки.

augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

Определите количество классов в обучающих данных.

classes = categories(imdsTrain.Labels);
numClasses = numel(classes);

Сеть Define

Задайте сеть для классификации изображений.

layers = [
    imageInputLayer(inputSize,'Normalization','none','Name','input')
    convolution2dLayer(5,20,'Name','conv1')
    batchNormalizationLayer('Name','bn1')
    reluLayer('Name','relu1')
    convolution2dLayer(3,20,'Padding','same','Name','conv2')
    batchNormalizationLayer('Name','bn2')
    reluLayer('Name','relu2')
    convolution2dLayer(3,20,'Padding','same','Name','conv3')
    batchNormalizationLayer('Name','bn3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')];
lgraph = layerGraph(layers);

Создайте dlnetwork объект от графика слоев.

dlnet = dlnetwork(lgraph)
dlnet = 
  dlnetwork with properties:

         Layers: [12×1 nnet.cnn.layer.Layer]
    Connections: [11×2 table]
     Learnables: [14×3 table]
          State: [6×3 table]
     InputNames: {'input'}
    OutputNames: {'softmax'}

Функция градиентов модели Define

Создайте функциональный modelGradients, перечисленный в конце примера, который берет dlnetwork объект, мини-пакет входных данных с соответствующими метками и возвращают градиенты потери относительно настраиваемых параметров в сети и соответствующей потери.

Задайте опции обучения

Обучайтесь в течение десяти эпох с мини-пакетным размером 128.

numEpochs = 10;
miniBatchSize = 128;

Задайте опции для оптимизации SGDM. Укажите, что начальная буква изучает уровень 0,01 с затуханием 0,01, и импульс 0.9.

initialLearnRate = 0.01;
decay = 0.01;
momentum = 0.9;

Обучите модель

Создайте minibatchqueue возразите, что процессы и управляют мини-пакетами изображений во время обучения. Для каждого мини-пакета:

  • Используйте пользовательский мини-пакет, предварительно обрабатывающий функциональный preprocessMiniBatch (заданный в конце этого примера), чтобы преобразовать метки в одногорячие закодированные переменные.

  • Формат данные изображения с размерностью маркирует 'SSCB' (пространственный, пространственный, канал, пакет). По умолчанию, minibatchqueue объект преобразует данные в dlarray объекты с базовым типом single. Не добавляйте формат в метки класса.

  • Обучайтесь на графическом процессоре, если вы доступны. По умолчанию, minibatchqueue объект преобразует каждый выход в gpuArray если графический процессор доступен. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).

mbq = minibatchqueue(augimdsTrain,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn',@preprocessMiniBatch,...
    'MiniBatchFormat',{'SSCB',''});

Инициализируйте график процесса обучения.

figure
lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on

Инициализируйте скоростной параметр для решателя SGDM.

velocity = [];

Обучите сеть с помощью пользовательского учебного цикла. В течение каждой эпохи переставьте данные и цикл по мини-пакетам данных. Для каждого мини-пакета:

  • Оцените градиенты модели, состояние и потерю с помощью dlfeval и modelGradients функции и обновление сетевое состояние.

  • Определите скорость обучения для основанного на времени расписания скорости обучения затухания.

  • Обновите сетевые параметры с помощью sgdmupdate функция.

  • Отобразите прогресс обучения.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    % Shuffle data.
    shuffle(mbq);
    
    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;
        
        % Read mini-batch of data.
        [dlX, dlY] = next(mbq);
        
        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelGradients function and update the network state.
        [gradients,state,loss] = dlfeval(@modelGradients,dlnet,dlX,dlY);
        dlnet.State = state;
        
        % Determine learning rate for time-based decay learning rate schedule.
        learnRate = initialLearnRate/(1 + decay*iteration);
        
        % Update the network parameters using the SGDM optimizer.
        [dlnet,velocity] = sgdmupdate(dlnet,gradients,velocity,learnRate,momentum);
        
        % Display the training progress.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        addpoints(lineLossTrain,iteration,loss)
        title("Epoch: " + epoch + ", Elapsed: " + string(D))
        drawnow
    end
end

Тестовая модель

Протестируйте точность классификации модели путем сравнения предсказаний на наборе валидации с истинными метками.

После обучения создание предсказаний на новых данных не требует меток. Создайте minibatchqueue объект, содержащий только предикторы тестовых данных:

  • Чтобы проигнорировать метки для тестирования, определите номер выходных параметров мини-пакетной очереди к 1.

  • Задайте тот же мини-пакетный размер, используемый для обучения.

  • Предварительно обработайте предикторы с помощью preprocessMiniBatchPredictors функция, перечисленная в конце примера.

  • Для одного выхода datastore задайте мини-пакетный формат 'SSCB' (пространственный, пространственный, канал, пакет).

numOutputs = 1;
mbqTest = minibatchqueue(augimdsValidation,numOutputs, ...
    'MiniBatchSize',miniBatchSize, ...
    'MiniBatchFcn',@preprocessMiniBatchPredictors, ...
    'MiniBatchFormat','SSCB');

Цикл по мини-пакетам и классифицирует изображения с помощью modelPredictions функция, перечисленная в конце примера.

predictions = modelPredictions(dlnet,mbqTest,classes);

Оцените точность классификации.

YTest = imdsValidation.Labels;
accuracy = mean(predictions == YTest)
accuracy = 0.9530

Функция градиентов модели

modelGradients функционируйте берет dlnetwork объект dlnet, мини-пакет входных данных dlX с соответствием маркирует Y и возвращает градиенты потери относительно настраиваемых параметров в dlnet, сетевое состояние и потеря. Чтобы вычислить градиенты автоматически, используйте dlgradient функция.

function [gradients,state,loss] = modelGradients(dlnet,dlX,Y)

[dlYPred,state] = forward(dlnet,dlX);

loss = crossentropy(dlYPred,Y);
gradients = dlgradient(loss,dlnet.Learnables);

loss = double(gather(extractdata(loss)));

end

Функция предсказаний модели

modelPredictions функционируйте берет dlnetwork объект dlnet, minibatchqueue из входных данных mbq, и сетевые классы, и вычисляют предсказания модели путем итерации по всем данным в minibatchqueue объект. Функция использует onehotdecode функционируйте, чтобы найти предсказанный класс с самым высоким счетом.

function predictions = modelPredictions(dlnet,mbq,classes)

predictions = [];

while hasdata(mbq)
    
    dlXTest = next(mbq);
    dlYPred = predict(dlnet,dlXTest);
    
    YPred = onehotdecode(dlYPred,classes,1)';
    
    predictions = [predictions; YPred];
end

end

Мини-функция предварительной обработки пакета

preprocessMiniBatch функция предварительно обрабатывает мини-пакет предикторов и меток с помощью следующих шагов:

  1. Предварительно обработайте изображения с помощью preprocessMiniBatchPredictors функция.

  2. Извлеките данные о метке из массива входящей ячейки и конкатенируйте в категориальный массив вдоль второго измерения.

  3. Одногорячий кодируют категориальные метки в числовые массивы. Кодирование в первую размерность производит закодированный массив, который совпадает с формой сетевого выхода.

function [X,Y] = preprocessMiniBatch(XCell,YCell)

% Preprocess predictors.
X = preprocessMiniBatchPredictors(XCell);

% Extract label data from cell and concatenate.
Y = cat(2,YCell{1:end});

% One-hot encode labels.
Y = onehotencode(Y,1);

end

Мини-пакетные предикторы, предварительно обрабатывающие функцию

preprocessMiniBatchPredictors функция предварительно обрабатывает мини-пакет предикторов путем извлечения данных изображения из входного массива ячеек, и конкатенируйте в числовой массив. Для полутонового входа, конкатенирующего по четвертой размерности, добавляет третью размерность в каждое изображение, чтобы использовать в качестве одноэлементной размерности канала.

function X = preprocessMiniBatchPredictors(XCell)

% Concatenate.
X = cat(4,XCell{1:end});

end

Входные параметры

свернуть все

Сеть для пользовательских учебных циклов в виде dlnetwork объект.

Входные данные в виде отформатированного dlarray. Для получения дополнительной информации о dlarray форматы, смотрите fmt входной параметр dlarray.

Слои, чтобы извлечь выходные параметры из в виде массива строк или массива ячеек из символьных векторов, содержащего имена слоя.

  • Если layerNames(i) соответствует слою с одним выходом, затем layerNames(i) имя слоя.

  • Если layerNames(i) соответствует слою с несколькими выходными параметрами, затем layerNames(i) имя слоя, сопровождаемое символьным "/"и имя слоя вывело: 'layerName/outputName'.

Оптимизация эффективности в виде одного из следующего:

  • 'auto' — Автоматически примените много оптимизации, подходящей для входной сети и аппаратных ресурсов.

  • 'none' — Отключите все ускорение.

Опцией по умолчанию является 'auto'.

Используя 'auto' ускоряющая опция может предложить выигрыши в производительности, но за счет увеличенного начального времени выполнения. Последующие вызовы совместимыми параметрами быстрее. Используйте оптимизацию эффективности, когда вы запланируете вызвать функцию многократно с помощью различных входных данных с тем же размером и формой.

Выходные аргументы

свернуть все

Выходные данные, возвращенные как отформатированный dlarray. Для получения дополнительной информации о dlarray форматы, смотрите fmt входной параметр dlarray.

Обновленное сетевое состояние, возвращенное как таблица.

Сетевое состояние является таблицей с тремя столбцами:

  • Layer – Имя слоя в виде строкового скаляра.

  • Parameter – Имя параметра состояния в виде строкового скаляра.

  • Value – Значение параметра состояния в виде dlarray объект.

Состояния слоя содержат информацию, вычисленную во время операции слоя, которая будет сохранена для использования в последующих прямых передачах слоя. Например, ячейка утверждают и скрытое состояние слоев LSTM или рабочая статистика в слоях нормализации партии.

Для текущих слоев, таких как слои LSTM, с HasStateInputs набор свойств к 1 (TRUE), таблица состояния не содержит записи для состояний того слоя.

Обновите состояние dlnetwork использование State свойство.

Вопросы совместимости

развернуть все

Поведение изменяется в R2021a

Расширенные возможности

Введенный в R2019b
Для просмотра документации необходимо авторизоваться на сайте