exponenta event banner

Оценка производительности функции ускоренного глубокого обучения

В этом примере показано, как оценить прирост производительности при использовании ускоренной функции.

При использовании dlfeval функция в индивидуальном учебном цикле, программное обеспечение отслеживает каждый вход dlarray объект функции градиентов модели для определения расчетного графика, используемого для автоматического дифференцирования. Этот процесс трассировки может занять некоторое время и может потратить время на повторное вычисление одной и той же трассировки. Оптимизация, кэширование и повторное использование трасс позволяют ускорить вычисления градиента в функциях глубокого обучения. Можно также оптимизировать, кэшировать и повторно использовать трассировки для ускорения других функций глубокого обучения, которые не требуют автоматического дифференцирования, например, можно также ускорить функции модели и функции, используемые для прогнозирования.

Для ускорения вызовов функций глубокого обучения используйте dlaccelerate для создания функции AcceleratedFunction объект, который автоматически оптимизирует, кэширует и повторно использует трассировку. Вы можете использовать dlaccelerate функция для непосредственного ускорения функций модели и функций градиентов модели или для ускорения субфункций, используемых этими функциями. Повышение производительности наиболее заметно для более глубоких сетей и циклов обучения со многими эпохами и итерациями.

Возвращенный AcceleratedFunction объект кэширует следы вызовов базовой функции и повторно использует кэшированный результат при повторном возникновении одного и того же шаблона ввода.

Попробуйте использовать dlaccelerate для вызовов функций, которые:

  • являются длительными

  • имеют dlarray объект, конструкции dlarray объекты, или dlnetwork объекты в качестве входных данных

  • не имеют побочных эффектов, таких как запись в файлы или отображение выходных данных

В этом примере сравниваются времена обучения и прогнозирования при использовании, а не при использовании ускорения.

Данные по обучению и тестированию нагрузки

digitTrain4DArrayData функция загружает изображения, их цифровые метки и углы поворота от вертикали. Создать arrayDatastore объекты для изображений, меток и углов, а затем используйте combine для создания единого хранилища данных, содержащего все данные обучения. Извлеките имена классов и количество недискретных ответов.

[imagesTrain,labelsTrain,anglesTrain] = digitTrain4DArrayData;

dsImagesTrain = arrayDatastore(imagesTrain,'IterationDimension',4);
dsLabelsTrain = arrayDatastore(labelsTrain);
dsAnglesTrain = arrayDatastore(anglesTrain);

dsTrain = combine(dsImagesTrain,dsLabelsTrain,dsAnglesTrain);

classNames = categories(labelsTrain);
numClasses = numel(classNames);
numResponses = size(anglesTrain,2);
numObservations = numel(labelsTrain);

Создание хранилища данных, содержащего тестовые данные, заданные digitTest4DArrayData с использованием тех же шагов.

[imagesTest,labelsTest,anglesTest] = digitTest4DArrayData;

dsImagesTest = arrayDatastore(imagesTest,'IterationDimension',4);
dsLabelsTest = arrayDatastore(labelsTest);
dsAnglesTest = arrayDatastore(anglesTest);

dsTest = combine(dsImagesTest,dsLabelsTest,dsAnglesTest);

Определение модели глубокого обучения

Определите следующую сеть, которая предсказывает как метки, так и углы поворота.

  • Блок свертки-дозирования-ReLU с 16 фильтрами 5 на 5.

  • Ветвь из двух блоков свертки-последовательности с 32 фильтрами 3 на 3 с операцией ReLU между

  • Соединение пропуска с блоком свертки-последовательности с 32 свертками 1 на 1.

  • Объединение обеих ветвей с помощью добавления с последующей операцией ReLU

  • Для выхода регрессии ветвь с полностью связанной операцией размера 1 (количество откликов).

  • Для вывода классификации ветвь с полностью связанной операцией размера 10 (количество классов) и операцией softmax.

Определение и инициализация параметров и состояния модели

Создание структуры parametersBaseline содержащий параметры модели с использованием modelParameters функция, перечисленная в конце примера. modelParameters функция создает структуры parameters и state , которые содержат параметры инициализированной модели и состояние соответственно.

В выходных данных используется формат parameters.OperationName.ParameterName где parameters - структура, OperationName - имя операции (например, «conv1») и ParameterName - имя параметра (например, «Веса»).

[parametersBaseline,stateBaseline] = modelParameters(numClasses,numResponses);

Создайте копию параметров и состояния базовой модели для использования в ускоренной модели.

parametersAccelerated = parametersBaseline;
stateAccelerated = stateBaseline;

Определение функции модели

Создание функции model, перечисленных в конце примера, который вычисляет выходные данные модели глубокого обучения, описанной ранее.

Функция model принимает параметры модели parameters, входные данные dlX, флаг doTraining который определяет, должна ли модель возвращать выходные данные для обучения или прогнозирования, и состояние сети state. Сеть выводит прогнозы для меток, прогнозы для углов и обновленное состояние сети.

Определение функции градиентов модели

Создание функции modelGradients, перечисленных в конце примера, который принимает параметры модели, мини-пакет входных данных dlX с соответствующими целями T1 и T2 содержит метки и углы соответственно и возвращает градиенты потерь относительно обучаемых параметров, обновленного состояния сети и соответствующих потерь.

Укажите параметры обучения

Укажите параметры обучения. Поезд на 20 эпох с размером мини-партии 32. Отображение графика может занять больше времени. Отключите график, установив plots переменная для "none". Чтобы включить печать, задайте для этой переменной значение "training-progress".

numEpochs = 20;
miniBatchSize = 32;
plots = "none";

Модель базовой линии поезда

Использовать minibatchqueue для обработки и управления мини-партиями изображений. Для каждой мини-партии:

  • Использование пользовательской функции предварительной обработки мини-партии preprocessMiniBatch (определяется в конце этого примера) для одноконтактного кодирования меток класса.

  • Форматирование данных изображения с метками размеров 'SSCB' (пространственный, пространственный, канальный, пакетный). По умолчанию minibatchqueue объект преобразует данные в dlarray объекты с базовым типом single. Не добавляйте формат к меткам класса или углам.

  • Удалите любые частичные мини-партии, возвращенные в конце эпохи.

  • Обучение на GPU, если он доступен. По умолчанию minibatchqueue объект преобразует каждый вывод в gpuArray если графический процессор доступен. Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox).

mbq = minibatchqueue(dsTrain,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn',@preprocessMiniBatch,...
    'MiniBatchFormat',{'SSCB','',''}, ...
    'PartialMiniBatch','discard');

Инициализация параметров для Adam.

trailingAvg = [];
trailingAvgSq = [];

При необходимости инициализируйте график хода обучения.

if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end

Тренируйте модель. Для каждой эпохи тасуйте данные и закольцовывайте мини-пакеты данных. Для каждой мини-партии:

  • Оценка градиентов и потерь модели с помощью dlfeval и modelGradients функция.

  • Обновление параметров сети с помощью adamupdate функция.

  • Обновите график хода обучения.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs

    % Shuffle data.
    shuffle(mbq)

    % Loop over mini-batches
    while hasdata(mbq)

        iteration = iteration + 1;

        [dlX,dlT1,dlT2] = next(mbq);

        % Evaluate the model gradients, state, and loss using dlfeval and the
        % model gradients function.
        [gradients,stateBaseline,loss] = dlfeval(@modelGradients,parametersBaseline,dlX,dlT1,dlT2,stateBaseline);

        % Update the network parameters using the Adam optimizer.
        [parametersBaseline,trailingAvg,trailingAvgSq] = adamupdate(parametersBaseline,gradients, ...
            trailingAvg,trailingAvgSq,iteration);

        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            loss = double(gather(extractdata(loss)));
            addpoints(lineLossTrain,iteration,loss)
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end
end
elapsedBaseline = toc(start)
elapsedBaseline = 285.8978

Ускоренная модель поезда

Ускорение функции градиентов модели с помощью dlaccelerate функция.

accfun = dlaccelerate(@modelGradients);

Удаление всех ранее кэшированных трассировок ускоренной функции с помощью clearCache функция.

clearCache(accfun)

Инициализация параметров для Adam.

trailingAvg = [];
trailingAvgSq = [];

При необходимости инициализируйте график хода обучения.

if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end

Обучение модели с использованием функции ускоренных градиентов модели при вызове dlfeval функция.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs

    % Shuffle data.
    shuffle(mbq)

    % Loop over mini-batches
    while hasdata(mbq)

        iteration = iteration + 1;

        [dlX,dlT1,dlT2] = next(mbq);

        % Evaluate the model gradients, state, and loss using dlfeval and the
        % accelerated function.
        [gradients,stateAccelerated,loss] = dlfeval(accfun, parametersAccelerated, dlX, dlT1, dlT2, stateAccelerated);

        % Update the network parameters using the Adam optimizer.
        [parametersAccelerated,trailingAvg,trailingAvgSq] = adamupdate(parametersAccelerated,gradients, ...
            trailingAvg,trailingAvgSq,iteration);

        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            loss = double(gather(extractdata(loss)));
            addpoints(lineLossTrain,iteration,loss)
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end
end
elapsedAccelerated = toc(start)
elapsedAccelerated = 188.5316

Проверьте эффективность ускоренной функции, проверив HitRate собственность. HitRate содержит процент вызовов функций, которые повторно используют кэшированную трассировку.

accfun.HitRate
ans = 99.9679

Сравнение времени обучения

Сравните время обучения на гистограмме.

figure
bar(categorical(["Baseline" "Accelerated"]),[elapsedBaseline elapsedAccelerated]);
ylabel("Time (seconds)")
title("Training Time")

Рассчитайте ускорение ускорения.

speedup = elapsedBaseline / elapsedAccelerated
speedup = 1.5164

Прогнозы опорной линии времени

Измерьте время, необходимое для прогнозирования с помощью набора тестовых данных.

После обучения составление прогнозов по новым данным не требует меток. Создать minibatchqueue объект, содержащий только предикторы тестовых данных:

  • Чтобы игнорировать метки для тестирования, установите число выходов мини-пакетной очереди равным 1.

  • Укажите размер мини-партии, используемый для обучения.

  • Предварительная обработка предикторов с помощью preprocessMiniBatchPredictors функция, перечисленная в конце примера.

  • Для одиночного вывода хранилища данных укажите формат мини-пакета. 'SSCB' (пространственный, пространственный, канальный, пакетный).

numOutputs = 1;
mbqTest = minibatchqueue(dsTest,numOutputs, ...
    'MiniBatchSize',miniBatchSize, ...
    'MiniBatchFcn',@preprocessMiniBatchPredictors, ...
    'MiniBatchFormat','SSCB');

Закольцовывайте мини-пакеты и классифицируйте изображения с помощью modelPredictions функция, перечисленная в конце примера и измеряющая прошедшее время.

tic
[labelsPred,anglesPred] = modelPredictions(@model,parametersBaseline,stateBaseline,mbqTest,classNames);
elapsedPredictionBaseline = toc
elapsedPredictionBaseline = 5.5070

Прогнозы с ускорением по времени

Поскольку функция прогнозирования модели требует в качестве входных данных мини-пакетной очереди, функция не поддерживает ускорение. Чтобы ускорить прогнозирование, ускорьте функцию модели.

Ускорение функции модели с помощью dlaccelerate функция.

accfun2 = dlaccelerate(@model);

Удаление всех ранее кэшированных трассировок ускоренной функции с помощью clearCache функция.

clearCache(accfun2)

Сбросьте очередь мини-пакета.

reset(mbqTest)

Закольцовывайте мини-пакеты и классифицируйте изображения с помощью modelPredictions функция, перечисленная в конце примера и измеряющая прошедшее время.

tic
[labelsPred,anglesPred] = modelPredictions(accfun2,parametersBaseline,stateBaseline,mbqTest,classNames);
elapsedPredictionAccelerated = toc
elapsedPredictionAccelerated = 4.3057

Проверьте эффективность ускоренной функции, проверив HitRate собственность. HitRate содержит процент вызовов функций, которые повторно используют кэшированную трассировку.

accfun2.HitRate
ans = 98.7261

Сравнить времена прогнозирования

Сравните время прогнозирования на гистограмме.

figure
bar(categorical(["Baseline" "Accelerated"]),[elapsedPredictionBaseline elapsedPredictionAccelerated]);
ylabel("Time (seconds)")
title("Prediction Time")

Рассчитайте ускорение ускорения.

speedup = elapsedPredictionBaseline / elapsedPredictionAccelerated
speedup = 1.2790

Функция параметров модели

modelParameters функция создает структуры parameters и state , которые содержат параметры и состояние инициализированной модели, соответственно, для модели, описанной в разделе Определение модели глубокого обучения. Функция принимает в качестве входных данных количество классов и количество ответов и инициализирует обучаемые параметры. Функция:

  • инициализирует веса слоев с помощью initializeGlorot функция

  • инициализирует смещения слоя с помощью initializeZeros функция

  • инициализирует параметры смещения и масштабирования нормализации пакета с помощью initializeZeros функция

  • инициализирует параметры шкалы нормализации пакета с помощью initializeOnes функция

  • инициализирует состояние нормализации партии обученное среднее с помощью initializeZeros функция

  • инициализирует настроенную дисперсию состояния нормализации пакета с помощью initializeOnes примерная функция

Функции примера инициализации присоединены к этому примеру как вспомогательные файлы. Чтобы получить доступ к этим файлам, откройте пример как живой сценарий. Дополнительные сведения о инициализации обучаемых параметров для моделей глубокого обучения см. в разделе Инициализация обучаемых параметров для функции модели.

В выходных данных используется формат parameters.OperationName.ParameterName где parameters - структура, OperationName - имя операции (например, «conv1») и ParameterName - имя параметра (например, «Веса»).

function [parameters,state] = modelParameters(numClasses,numResponses)

% First convolutional layer.
filterSize = [5 5];
numChannels = 1;
numFilters = 16;

sz = [filterSize numChannels numFilters];
numOut = prod(filterSize) * numFilters;
numIn = prod(filterSize) * numFilters;

parameters.conv1.Weights = initializeGlorot(sz,numOut,numIn);
parameters.conv1.Bias = initializeZeros([numFilters 1]);

% First batch normalization layer.
parameters.batchnorm1.Offset = initializeZeros([numFilters 1]);
parameters.batchnorm1.Scale = initializeOnes([numFilters 1]);
state.batchnorm1.TrainedMean = initializeZeros([numFilters 1]);
state.batchnorm1.TrainedVariance = initializeOnes([numFilters 1]);

% Second convolutional layer.
filterSize = [3 3];
numChannels = 16;
numFilters = 32;

sz = [filterSize numChannels numFilters];
numOut = prod(filterSize) * numFilters;
numIn = prod(filterSize) * numFilters;

parameters.conv2.Weights = initializeGlorot(sz,numOut,numIn);
parameters.conv2.Bias = initializeZeros([numFilters 1]);

% Second batch normalization layer.
parameters.batchnorm2.Offset = initializeZeros([numFilters 1]);
parameters.batchnorm2.Scale = initializeOnes([numFilters 1]);
state.batchnorm2.TrainedMean = initializeZeros([numFilters 1]);
state.batchnorm2.TrainedVariance = initializeOnes([numFilters 1]);

% Third convolutional layer.
filterSize = [3 3];
numChannels = 32;
numFilters = 32;

sz = [filterSize numChannels numFilters];
numOut = prod(filterSize) * numFilters;
numIn = prod(filterSize) * numFilters;

parameters.conv3.Weights = initializeGlorot(sz,numOut,numIn);
parameters.conv3.Bias = initializeZeros([numFilters 1]);

% Third batch normalization layer.
parameters.batchnorm3.Offset = initializeZeros([numFilters 1]);
parameters.batchnorm3.Scale = initializeOnes([numFilters 1]);
state.batchnorm3.TrainedMean = initializeZeros([numFilters 1]);
state.batchnorm3.TrainedVariance = initializeOnes([numFilters 1]);

% Convolutional layer in the skip connection.
filterSize = [1 1];
numChannels = 16;
numFilters = 32;

sz = [filterSize numChannels numFilters];
numOut = prod(filterSize) * numFilters;
numIn = prod(filterSize) * numFilters;

parameters.convSkip.Weights = initializeGlorot(sz,numOut,numIn);
parameters.convSkip.Bias = initializeZeros([numFilters 1]);

% Batch normalization layer in the skip connection.
parameters.batchnormSkip.Offset = initializeZeros([numFilters 1]);
parameters.batchnormSkip.Scale = initializeOnes([numFilters 1]);

state.batchnormSkip.TrainedMean = initializeZeros([numFilters 1]);
state.batchnormSkip.TrainedVariance = initializeOnes([numFilters 1]);

% Fully connected layer corresponding to the classification output.
sz = [numClasses 6272];
numOut = numClasses;
numIn = 6272;
parameters.fc1.Weights = initializeGlorot(sz,numOut,numIn);
parameters.fc1.Bias = initializeZeros([numClasses 1]);

% Fully connected layer corresponding to the regression output.
sz = [numResponses 6272];
numOut = numResponses;
numIn = 6272;
parameters.fc2.Weights = initializeGlorot(sz,numOut,numIn);
parameters.fc2.Bias = initializeZeros([numResponses 1]);

end

Функция модели

Функция model принимает параметры модели parameters, входные данные dlX, флаг doTraining который определяет, должна ли модель возвращать выходные данные для обучения или прогнозирования, и состояние сети state. Сеть выводит прогнозы для меток, прогнозы для углов и обновленное состояние сети.

function [dlY1,dlY2,state] = model(parameters,dlX,doTraining,state)

% Convolution
weights = parameters.conv1.Weights;
bias = parameters.conv1.Bias;
dlY = dlconv(dlX,weights,bias,'Padding','same');

% Batch normalization, ReLU
offset = parameters.batchnorm1.Offset;
scale = parameters.batchnorm1.Scale;
trainedMean = state.batchnorm1.TrainedMean;
trainedVariance = state.batchnorm1.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);

    % Update state
    state.batchnorm1.TrainedMean = trainedMean;
    state.batchnorm1.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
end

dlY = relu(dlY);

% Convolution, batch normalization (Skip connection)
weights = parameters.convSkip.Weights;
bias = parameters.convSkip.Bias;
dlYSkip = dlconv(dlY,weights,bias,'Stride',2);

offset = parameters.batchnormSkip.Offset;
scale = parameters.batchnormSkip.Scale;
trainedMean = state.batchnormSkip.TrainedMean;
trainedVariance = state.batchnormSkip.TrainedVariance;

if doTraining
    [dlYSkip,trainedMean,trainedVariance] = batchnorm(dlYSkip,offset,scale,trainedMean,trainedVariance);

    % Update state
    state.batchnormSkip.TrainedMean = trainedMean;
    state.batchnormSkip.TrainedVariance = trainedVariance;
else
    dlYSkip = batchnorm(dlYSkip,offset,scale,trainedMean,trainedVariance);
end

% Convolution
weights = parameters.conv2.Weights;
bias = parameters.conv2.Bias;
dlY = dlconv(dlY,weights,bias,'Padding','same','Stride',2);

% Batch normalization, ReLU
offset = parameters.batchnorm2.Offset;
scale = parameters.batchnorm2.Scale;
trainedMean = state.batchnorm2.TrainedMean;
trainedVariance = state.batchnorm2.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);

    % Update state
    state.batchnorm2.TrainedMean = trainedMean;
    state.batchnorm2.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
end

dlY = relu(dlY);

% Convolution
weights = parameters.conv3.Weights;
bias = parameters.conv3.Bias;
dlY = dlconv(dlY,weights,bias,'Padding','same');

% Batch normalization
offset = parameters.batchnorm3.Offset;
scale = parameters.batchnorm3.Scale;
trainedMean = state.batchnorm3.TrainedMean;
trainedVariance = state.batchnorm3.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);

    % Update state
    state.batchnorm3.TrainedMean = trainedMean;
    state.batchnorm3.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
end

% Addition, ReLU
dlY = dlYSkip + dlY;
dlY = relu(dlY);

% Fully connect, softmax (labels)
weights = parameters.fc1.Weights;
bias = parameters.fc1.Bias;
dlY1 = fullyconnect(dlY,weights,bias);
dlY1 = softmax(dlY1);

% Fully connect (angles)
weights = parameters.fc2.Weights;
bias = parameters.fc2.Bias;
dlY2 = fullyconnect(dlY,weights,bias);

end

Функция градиентов модели

modelGradients функция, принимает параметры модели, мини-пакет входных данных dlX с соответствующими целями T1 и T2 содержит метки и углы соответственно и возвращает градиенты потерь относительно обучаемых параметров, обновленного состояния сети и соответствующих потерь.

function [gradients,state,loss] = modelGradients(parameters,dlX,T1,T2,state)

doTraining = true;
[dlY1,dlY2,state] = model(parameters,dlX,doTraining,state);

lossLabels = crossentropy(dlY1,T1);
lossAngles = mse(dlY2,T2);

loss = lossLabels + 0.1*lossAngles;
gradients = dlgradient(loss,parameters);

end

Функция прогнозирования модели

modelPredictions функция принимает параметры модели, состояние, minibatchqueue входных данных mbqи сетевые классы, и вычисляет предсказания модели путем итерации по всем данным в minibatchqueue объект. Функция использует onehotdecode функция для поиска прогнозируемого класса с наивысшим баллом.

function [predictions1, predictions2] = modelPredictions(modelFcn,parameters,state,mbq,classes)

doTraining = false;
predictions1 = [];
predictions2 = [];

while hasdata(mbq)

    dlXTest = next(mbq);

    [dlYPred1,dlYPred2] = modelFcn(parameters,dlXTest,doTraining,state);

    YPred1 = onehotdecode(dlYPred1,classes,1)';
    YPred2 = extractdata(dlYPred2)';

    predictions1 = [predictions1; YPred1];
    predictions2 = [predictions2; YPred2];
end

end

Функция предварительной обработки мини-партий

preprocessMiniBatch функция выполняет предварительную обработку данных с помощью следующих шагов:

  1. Извлеките данные изображения из входящего массива ячеек и объедините их в числовой массив. Конкатенация данных изображения над четвертым размером добавляет к каждому изображению третий размер, который будет использоваться в качестве размера одиночного канала.

  2. Извлеките данные метки и угла из входящих массивов ячеек и объедините их вдоль второго размера в категориальный массив и числовой массив соответственно.

  3. Одноконтурное кодирование категориальных меток в числовые массивы. Кодирование в первом измерении создает закодированный массив, который соответствует форме сетевого вывода.

function [X,Y,angle] = preprocessMiniBatch(XCell,YCell,angleCell)

% Preprocess predictors.
X = preprocessMiniBatchPredictors(XCell);

% Extract label data from cell and concatenate
Y = cat(2,YCell{:});

% Extract angle data from cell and concatenate
angle = cat(2,angleCell{:});

% One-hot encode labels
Y = onehotencode(Y,1);

end

Функция предварительной обработки мини-пакетных предикторов

preprocessMiniBatchPredictors функция предварительно обрабатывает мини-пакет предикторов путем извлечения данных изображения из массива входных ячеек и конкатенации в числовой массив. Для ввода в оттенках серого при конкатенации над четвертым размером к каждому изображению добавляется третий размер для использования в качестве размера одиночного канала.

function X = preprocessMiniBatchPredictors(XCell)

% Concatenate.
X = cat(4,XCell{1:end});

end

См. также

| | | | |

Связанные темы