Оцените эффективность ускоренной функции глубокого обучения

В этом примере показано, как оценить увеличение производительности использования ускоренной функции.

При использовании dlfeval функция в пользовательском учебном цикле, программное обеспечение прослеживает каждый вход dlarray объект градиентов модели функционирует, чтобы определить график расчета, используемый для автоматического дифференцирования. Этот процесс трассировки может занять время и может провести время, повторно вычисляя ту же трассировку. Путем оптимизации, кэшируясь и снова используя трассировки, можно ускорить расчет градиента в функциях глубокого обучения. Можно также оптимизировать, кэшироваться, и трассировки повторного использования, чтобы ускорить другие функции глубокого обучения, которые не требуют автоматического дифференцирования, например, можно также ускорить функции модели и функции, используемые для предсказания.

Чтобы ускорить вызовы функций глубокого обучения, используйте dlaccelerate функция, чтобы создать AcceleratedFunction возразите, что automaticallyoptimizes, кэши, и снова используют трассировки. Можно использовать dlaccelerate функция, чтобы ускорить функции модели и градиенты модели функционирует непосредственно, или ускорять подфункции, используемые этими функциями. Увеличение производительности является самым примечательным для более глубоких сетей и учебных циклов со многими эпохами и итерациями.

Возвращенный AcceleratedFunction объектные кэши трассировки вызовов базовой функции и повторных использований кэшируемый результат, когда тот же входной набор повторяется.

Попытайтесь использовать dlaccelerate для вызовов функции, что:

  • продолжительны

  • имейте dlarray объект, структуры dlarray объекты или dlnetwork объекты как входные параметры

  • не имейте побочных эффектов как запись в файлы или отображение вывода

Этот пример сравнивает обучение и времена предсказания при использовании и не использование ускорения.

Загрузите обучение и тестовые данные

digitTrain4DArrayData функционируйте загружает изображения, их метки цифры и их углы вращения от вертикали. Создайте arrayDatastore объекты для изображений, меток и углов, и затем используют combine функция, чтобы сделать один datastore, который содержит все обучающие данные. Извлеките имена классов и количество недискретных ответов.

[imagesTrain,labelsTrain,anglesTrain] = digitTrain4DArrayData;

dsImagesTrain = arrayDatastore(imagesTrain,'IterationDimension',4);
dsLabelsTrain = arrayDatastore(labelsTrain);
dsAnglesTrain = arrayDatastore(anglesTrain);

dsTrain = combine(dsImagesTrain,dsLabelsTrain,dsAnglesTrain);

classNames = categories(labelsTrain);
numClasses = numel(classNames);
numResponses = size(anglesTrain,2);
numObservations = numel(labelsTrain);

Создайте datastore, содержащий тестовые данные, данные digitTest4DArrayData функция с помощью тех же шагов.

[imagesTest,labelsTest,anglesTest] = digitTest4DArrayData;

dsImagesTest = arrayDatastore(imagesTest,'IterationDimension',4);
dsLabelsTest = arrayDatastore(labelsTest);
dsAnglesTest = arrayDatastore(anglesTest);

dsTest = combine(dsImagesTest,dsLabelsTest,dsAnglesTest);

Задайте модель глубокого обучения

Задайте следующую сеть, которая предсказывает и метки и углы вращения.

  • convolution-batchnorm-ReLU блокируется с 16 фильтрами 5 на 5.

  • Ветвь двух блоков свертки-batchnorm каждый с 32 3х3 фильтрами с операцией ReLU между

  • Связь пропуска со сверткой-batchnorm блокируется с 32 свертками 1 на 1.

  • Объедините обе ветви с помощью сложения, сопровождаемого операцией ReLU

  • Для регрессии выход, ветвь с полностью связанной операцией размера 1 (количество ответов).

  • Для классификации выход, ветвь с полностью связанной операцией размера 10 (количество классов) и softmax операцией.

Задайте и инициализируйте параметры модели и состояние

Создайте struct parametersBaseline содержа параметры модели с помощью modelParameters функция, перечисленная в конце примера. modelParameters функция создает структуры parameters и state это содержит инициализированные параметры модели и состояние, соответственно.

Выход использует формат parameters.OperationName.ParameterName где parameters структура, OperationName имя операции (например, "conv1") и ParameterName имя параметра (например, "Веса").

[parametersBaseline,stateBaseline] = modelParameters(numClasses,numResponses);

Создайте копию параметров и состояния для базовой модели, чтобы использовать для ускоренной модели.

parametersAccelerated = parametersBaseline;
stateAccelerated = stateBaseline;

Функция модели Define

Создайте функциональный model, перечисленный в конце примера, который вычисляет выходные параметры модели глубокого обучения, описанной ранее.

Функциональный model берет параметры модели parameters, входные данные dlX, флаг doTraining который задает, должен ли к модели возвратить выходные параметры для обучения или предсказания и сетевого state состояния. Сетевые выходные параметры предсказания для меток, предсказания для углов и обновленное сетевое состояние.

Функция градиентов модели Define

Создайте функциональный modelGradients, перечисленный в конце примера, который берет параметры модели, мини-пакет входных данных dlX с соответствующими целями T1 и T2 содержание меток и углов, соответственно, и возвращает градиенты потери относительно настраиваемых параметров, обновленного сетевого состояния и соответствующей потери.

Задайте опции обучения

Задайте опции обучения. Обучайтесь в течение 20 эпох с мини-пакетным размером 32. Отображение графика может заставить обучение занять больше времени, чтобы завершиться. Отключите график путем установки plots переменная к "none". Чтобы включить график, установите эту переменную на "training-progress".

numEpochs = 20;
miniBatchSize = 32;
plots = "none";

Обучите базовую модель

Используйте minibatchqueue обработать и управлять мини-пакетами изображений. Для каждого мини-пакета:

  • Используйте пользовательский мини-пакет, предварительно обрабатывающий функциональный preprocessMiniBatch (заданный в конце этого примера) к одногорячему кодируют метки класса.

  • Формат данные изображения с размерностью маркирует 'SSCB' (пространственный, пространственный, канал, пакет). По умолчанию, minibatchqueue объект преобразует данные в dlarray объекты с базовым типом single. Не добавляйте формат в метки класса или углы.

  • Отбросьте любые частичные мини-пакеты, возвращенные в конце эпохи.

  • Обучайтесь на графическом процессоре, если вы доступны. По умолчанию, minibatchqueue объект преобразует каждый выход в gpuArray если графический процессор доступен. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).

mbq = minibatchqueue(dsTrain,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn',@preprocessMiniBatch,...
    'MiniBatchFormat',{'SSCB','',''}, ...
    'PartialMiniBatch','discard');

Инициализируйте параметры для Адама.

trailingAvg = [];
trailingAvgSq = [];

При необходимости инициализируйте график процесса обучения.

if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end

Обучите модель. В течение каждой эпохи переставьте данные и цикл по мини-пакетам данных. Для каждого мини-пакета:

  • Оцените градиенты модели и потерю с помощью dlfeval и modelGradients функция.

  • Обновите сетевые параметры с помощью adamupdate функция.

  • Обновите график процесса обучения.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs

    % Shuffle data.
    shuffle(mbq)

    % Loop over mini-batches
    while hasdata(mbq)

        iteration = iteration + 1;

        [dlX,dlT1,dlT2] = next(mbq);

        % Evaluate the model gradients, state, and loss using dlfeval and the
        % model gradients function.
        [gradients,stateBaseline,loss] = dlfeval(@modelGradients,parametersBaseline,dlX,dlT1,dlT2,stateBaseline);

        % Update the network parameters using the Adam optimizer.
        [parametersBaseline,trailingAvg,trailingAvgSq] = adamupdate(parametersBaseline,gradients, ...
            trailingAvg,trailingAvgSq,iteration);

        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            loss = double(gather(extractdata(loss)));
            addpoints(lineLossTrain,iteration,loss)
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end
end
elapsedBaseline = toc(start)
elapsedBaseline = 285.8978

Обучите ускоренную модель

Ускорьте функцию градиентов модели использование dlaccelerate функция.

accfun = dlaccelerate(@modelGradients);

Очистите любые ранее кэшируемые трассировки от ускоренной функции с помощью clearCache функция.

clearCache(accfun)

Инициализируйте параметры для Адама.

trailingAvg = [];
trailingAvgSq = [];

При необходимости инициализируйте график процесса обучения.

if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end

Обучите модель с помощью ускоренной функции градиентов модели в вызове dlfeval функция.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs

    % Shuffle data.
    shuffle(mbq)

    % Loop over mini-batches
    while hasdata(mbq)

        iteration = iteration + 1;

        [dlX,dlT1,dlT2] = next(mbq);

        % Evaluate the model gradients, state, and loss using dlfeval and the
        % accelerated function.
        [gradients,stateAccelerated,loss] = dlfeval(accfun, parametersAccelerated, dlX, dlT1, dlT2, stateAccelerated);

        % Update the network parameters using the Adam optimizer.
        [parametersAccelerated,trailingAvg,trailingAvgSq] = adamupdate(parametersAccelerated,gradients, ...
            trailingAvg,trailingAvgSq,iteration);

        % Display the training progress.
        if plots == "training-progress"
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            loss = double(gather(extractdata(loss)));
            addpoints(lineLossTrain,iteration,loss)
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            drawnow
        end
    end
end
elapsedAccelerated = toc(start)
elapsedAccelerated = 188.5316

Проверяйте КПД ускоренной функции путем осмотра HitRate свойство. HitRate свойство содержит процент вызовов функции, которые снова используют кэшируемую трассировку.

accfun.HitRate
ans = 99.9679

Сравните учебные времена

Сравните учебные времена в столбчатой диаграмме.

figure
bar(categorical(["Baseline" "Accelerated"]),[elapsedBaseline elapsedAccelerated]);
ylabel("Time (seconds)")
title("Training Time")

Вычислите ускорение ускорения.

speedup = elapsedBaseline / elapsedAccelerated
speedup = 1.5164

Базовые предсказания времени

Измерьте время, требуемое сделать предсказания с помощью набора тестовых данных.

После обучения создание предсказаний на новых данных не требует меток. Создайте minibatchqueue объект, содержащий только предикторы тестовых данных:

  • Чтобы проигнорировать метки для тестирования, определите номер выходных параметров мини-пакетной очереди к 1.

  • Задайте тот же мини-пакетный размер, используемый для обучения.

  • Предварительно обработайте предикторы с помощью preprocessMiniBatchPredictors функция, перечисленная в конце примера.

  • Для одного выхода datastore задайте мини-пакетный формат 'SSCB' (пространственный, пространственный, канал, пакет).

numOutputs = 1;
mbqTest = minibatchqueue(dsTest,numOutputs, ...
    'MiniBatchSize',miniBatchSize, ...
    'MiniBatchFcn',@preprocessMiniBatchPredictors, ...
    'MiniBatchFormat','SSCB');

Цикл по мини-пакетам и классифицирует изображения с помощью modelPredictions функция, перечисленная в конце примера и меры прошедшее время.

tic
[labelsPred,anglesPred] = modelPredictions(@model,parametersBaseline,stateBaseline,mbqTest,classNames);
elapsedPredictionBaseline = toc
elapsedPredictionBaseline = 5.5070

Время ускоренные предсказания

Поскольку функция предсказаний модели требует мини-пакетной очереди, как введено, функция не поддерживает ускорение. Чтобы ускорить предсказание, ускорьте функцию модели.

Ускорьте функцию модели с помощью dlaccelerate функция.

accfun2 = dlaccelerate(@model);

Очистите любые ранее кэшируемые трассировки от ускоренной функции с помощью clearCache функция.

clearCache(accfun2)

Сбросьте мини-пакетную очередь.

reset(mbqTest)

Цикл по мини-пакетам и классифицирует изображения с помощью modelPredictions функция, перечисленная в конце примера и меры прошедшее время.

tic
[labelsPred,anglesPred] = modelPredictions(accfun2,parametersBaseline,stateBaseline,mbqTest,classNames);
elapsedPredictionAccelerated = toc
elapsedPredictionAccelerated = 4.3057

Проверяйте КПД ускоренной функции путем осмотра HitRate свойство. HitRate свойство содержит процент вызовов функции, которые снова используют кэшируемую трассировку.

accfun2.HitRate
ans = 98.7261

Сравните времена предсказания

Сравните времена предсказания в столбчатой диаграмме.

figure
bar(categorical(["Baseline" "Accelerated"]),[elapsedPredictionBaseline elapsedPredictionAccelerated]);
ylabel("Time (seconds)")
title("Prediction Time")

Вычислите ускорение ускорения.

speedup = elapsedPredictionBaseline / elapsedPredictionAccelerated
speedup = 1.2790

Функция параметров модели

modelParameters функция создает структуры parameters и state это содержит инициализированные параметры модели и состояние, соответственно для модели, описанной в разделе Define Deep Learning Model. Функция берет в качестве входа количество классов и количество ответов и инициализирует настраиваемые параметры. Функция:

  • инициализирует веса слоя с помощью initializeGlorot функция

  • инициализирует смещения слоя с помощью initializeZeros функция

  • инициализирует смещение нормализации партии. и масштабные коэффициенты с initializeZeros функция

  • инициализирует масштабные коэффициенты нормализации партии. initializeOnes функция

  • инициализирует состояние нормализации партии. обученное среднее значение initializeZeros функция

  • инициализирует состояние нормализации партии. обученное отклонение initializeOnes функция, взятая в качестве примера,

Функции, взятые в качестве примера, инициализации присоединены к этому примеру как к вспомогательным файлам. Чтобы получить доступ к этим файлам, откройте пример как live скрипт. Чтобы узнать больше об инициализации настраиваемых параметров для моделей глубокого обучения, смотрите, Инициализируют Настраиваемые параметры для Функции Модели.

Выход использует формат parameters.OperationName.ParameterName где parameters структура, OperationName имя операции (например, "conv1") и ParameterName имя параметра (например, "Веса").

function [parameters,state] = modelParameters(numClasses,numResponses)

% First convolutional layer.
filterSize = [5 5];
numChannels = 1;
numFilters = 16;

sz = [filterSize numChannels numFilters];
numOut = prod(filterSize) * numFilters;
numIn = prod(filterSize) * numFilters;

parameters.conv1.Weights = initializeGlorot(sz,numOut,numIn);
parameters.conv1.Bias = initializeZeros([numFilters 1]);

% First batch normalization layer.
parameters.batchnorm1.Offset = initializeZeros([numFilters 1]);
parameters.batchnorm1.Scale = initializeOnes([numFilters 1]);
state.batchnorm1.TrainedMean = initializeZeros([numFilters 1]);
state.batchnorm1.TrainedVariance = initializeOnes([numFilters 1]);

% Second convolutional layer.
filterSize = [3 3];
numChannels = 16;
numFilters = 32;

sz = [filterSize numChannels numFilters];
numOut = prod(filterSize) * numFilters;
numIn = prod(filterSize) * numFilters;

parameters.conv2.Weights = initializeGlorot(sz,numOut,numIn);
parameters.conv2.Bias = initializeZeros([numFilters 1]);

% Second batch normalization layer.
parameters.batchnorm2.Offset = initializeZeros([numFilters 1]);
parameters.batchnorm2.Scale = initializeOnes([numFilters 1]);
state.batchnorm2.TrainedMean = initializeZeros([numFilters 1]);
state.batchnorm2.TrainedVariance = initializeOnes([numFilters 1]);

% Third convolutional layer.
filterSize = [3 3];
numChannels = 32;
numFilters = 32;

sz = [filterSize numChannels numFilters];
numOut = prod(filterSize) * numFilters;
numIn = prod(filterSize) * numFilters;

parameters.conv3.Weights = initializeGlorot(sz,numOut,numIn);
parameters.conv3.Bias = initializeZeros([numFilters 1]);

% Third batch normalization layer.
parameters.batchnorm3.Offset = initializeZeros([numFilters 1]);
parameters.batchnorm3.Scale = initializeOnes([numFilters 1]);
state.batchnorm3.TrainedMean = initializeZeros([numFilters 1]);
state.batchnorm3.TrainedVariance = initializeOnes([numFilters 1]);

% Convolutional layer in the skip connection.
filterSize = [1 1];
numChannels = 16;
numFilters = 32;

sz = [filterSize numChannels numFilters];
numOut = prod(filterSize) * numFilters;
numIn = prod(filterSize) * numFilters;

parameters.convSkip.Weights = initializeGlorot(sz,numOut,numIn);
parameters.convSkip.Bias = initializeZeros([numFilters 1]);

% Batch normalization layer in the skip connection.
parameters.batchnormSkip.Offset = initializeZeros([numFilters 1]);
parameters.batchnormSkip.Scale = initializeOnes([numFilters 1]);

state.batchnormSkip.TrainedMean = initializeZeros([numFilters 1]);
state.batchnormSkip.TrainedVariance = initializeOnes([numFilters 1]);

% Fully connected layer corresponding to the classification output.
sz = [numClasses 6272];
numOut = numClasses;
numIn = 6272;
parameters.fc1.Weights = initializeGlorot(sz,numOut,numIn);
parameters.fc1.Bias = initializeZeros([numClasses 1]);

% Fully connected layer corresponding to the regression output.
sz = [numResponses 6272];
numOut = numResponses;
numIn = 6272;
parameters.fc2.Weights = initializeGlorot(sz,numOut,numIn);
parameters.fc2.Bias = initializeZeros([numResponses 1]);

end

Функция модели

Функциональный model берет параметры модели parameters, входные данные dlX, флаг doTraining который задает, должен ли к модели возвратить выходные параметры для обучения или предсказания и сетевого state состояния. Сетевые выходные параметры предсказания для меток, предсказания для углов и обновленное сетевое состояние.

function [dlY1,dlY2,state] = model(parameters,dlX,doTraining,state)

% Convolution
weights = parameters.conv1.Weights;
bias = parameters.conv1.Bias;
dlY = dlconv(dlX,weights,bias,'Padding','same');

% Batch normalization, ReLU
offset = parameters.batchnorm1.Offset;
scale = parameters.batchnorm1.Scale;
trainedMean = state.batchnorm1.TrainedMean;
trainedVariance = state.batchnorm1.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);

    % Update state
    state.batchnorm1.TrainedMean = trainedMean;
    state.batchnorm1.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
end

dlY = relu(dlY);

% Convolution, batch normalization (Skip connection)
weights = parameters.convSkip.Weights;
bias = parameters.convSkip.Bias;
dlYSkip = dlconv(dlY,weights,bias,'Stride',2);

offset = parameters.batchnormSkip.Offset;
scale = parameters.batchnormSkip.Scale;
trainedMean = state.batchnormSkip.TrainedMean;
trainedVariance = state.batchnormSkip.TrainedVariance;

if doTraining
    [dlYSkip,trainedMean,trainedVariance] = batchnorm(dlYSkip,offset,scale,trainedMean,trainedVariance);

    % Update state
    state.batchnormSkip.TrainedMean = trainedMean;
    state.batchnormSkip.TrainedVariance = trainedVariance;
else
    dlYSkip = batchnorm(dlYSkip,offset,scale,trainedMean,trainedVariance);
end

% Convolution
weights = parameters.conv2.Weights;
bias = parameters.conv2.Bias;
dlY = dlconv(dlY,weights,bias,'Padding','same','Stride',2);

% Batch normalization, ReLU
offset = parameters.batchnorm2.Offset;
scale = parameters.batchnorm2.Scale;
trainedMean = state.batchnorm2.TrainedMean;
trainedVariance = state.batchnorm2.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);

    % Update state
    state.batchnorm2.TrainedMean = trainedMean;
    state.batchnorm2.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
end

dlY = relu(dlY);

% Convolution
weights = parameters.conv3.Weights;
bias = parameters.conv3.Bias;
dlY = dlconv(dlY,weights,bias,'Padding','same');

% Batch normalization
offset = parameters.batchnorm3.Offset;
scale = parameters.batchnorm3.Scale;
trainedMean = state.batchnorm3.TrainedMean;
trainedVariance = state.batchnorm3.TrainedVariance;

if doTraining
    [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);

    % Update state
    state.batchnorm3.TrainedMean = trainedMean;
    state.batchnorm3.TrainedVariance = trainedVariance;
else
    dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);
end

% Addition, ReLU
dlY = dlYSkip + dlY;
dlY = relu(dlY);

% Fully connect, softmax (labels)
weights = parameters.fc1.Weights;
bias = parameters.fc1.Bias;
dlY1 = fullyconnect(dlY,weights,bias);
dlY1 = softmax(dlY1);

% Fully connect (angles)
weights = parameters.fc2.Weights;
bias = parameters.fc2.Bias;
dlY2 = fullyconnect(dlY,weights,bias);

end

Функция градиентов модели

modelGradients функция, берет параметры модели, мини-пакет входных данных dlX с соответствующими целями T1 и T2 содержание меток и углов, соответственно, и возвращает градиенты потери относительно настраиваемых параметров, обновленного сетевого состояния и соответствующей потери.

function [gradients,state,loss] = modelGradients(parameters,dlX,T1,T2,state)

doTraining = true;
[dlY1,dlY2,state] = model(parameters,dlX,doTraining,state);

lossLabels = crossentropy(dlY1,T1);
lossAngles = mse(dlY2,T2);

loss = lossLabels + 0.1*lossAngles;
gradients = dlgradient(loss,parameters);

end

Функция предсказаний модели

modelPredictions функционируйте берет параметры модели, состояние, minibatchqueue из входных данных mbq, и сетевые классы, и вычисляют предсказания модели путем итерации по всем данным в minibatchqueue объект. Функция использует onehotdecode функционируйте, чтобы найти предсказанный класс с самым высоким счетом.

function [predictions1, predictions2] = modelPredictions(modelFcn,parameters,state,mbq,classes)

doTraining = false;
predictions1 = [];
predictions2 = [];

while hasdata(mbq)

    dlXTest = next(mbq);

    [dlYPred1,dlYPred2] = modelFcn(parameters,dlXTest,doTraining,state);

    YPred1 = onehotdecode(dlYPred1,classes,1)';
    YPred2 = extractdata(dlYPred2)';

    predictions1 = [predictions1; YPred1];
    predictions2 = [predictions2; YPred2];
end

end

Функция предварительной обработки мини-пакета

preprocessMiniBatch функция предварительно обрабатывает данные с помощью следующих шагов:

  1. Извлеките данные изображения из массива входящей ячейки и конкатенируйте в числовой массив. Конкатенация данных изображения по четвертой размерности добавляет третью размерность в каждое изображение, чтобы использоваться в качестве одноэлементной размерности канала.

  2. Извлеките метку и угловые данные из массивов входящей ячейки и конкатенируйте вдоль второго измерения в категориальный массив и числовой массив, соответственно.

  3. Одногорячий кодируют категориальные метки в числовые массивы. Кодирование в первую размерность производит закодированный массив, который совпадает с формой сетевого выхода.

function [X,Y,angle] = preprocessMiniBatch(XCell,YCell,angleCell)

% Preprocess predictors.
X = preprocessMiniBatchPredictors(XCell);

% Extract label data from cell and concatenate
Y = cat(2,YCell{:});

% Extract angle data from cell and concatenate
angle = cat(2,angleCell{:});

% One-hot encode labels
Y = onehotencode(Y,1);

end

Мини-пакетные предикторы, предварительно обрабатывающие функцию

preprocessMiniBatchPredictors функция предварительно обрабатывает мини-пакет предикторов путем извлечения данных изображения из входного массива ячеек, и конкатенируйте в числовой массив. Для полутонового входа, конкатенирующего по четвертой размерности, добавляет третью размерность в каждое изображение, чтобы использовать в качестве одноэлементной размерности канала.

function X = preprocessMiniBatchPredictors(XCell)

% Concatenate.
X = cat(4,XCell{1:end});

end

Смотрите также

| | | | |

Похожие темы