Распознавание активности из данных о видео и оптическом потоке Используя глубокое обучение

В этом примере показано, как обучить 2D потоковую сверточную нейронную сеть Расширенного 3-D (I3D) распознаванию активности с помощью RGB и данных об оптическом потоке из видео [1].

Основанное на видении распознавание активности включает предсказание действия объекта, такого как обход, плавание или нахождение, с помощью набора видеокадров. Распознавание активности от видео имеет много приложений, таких как человеко-машинное взаимодействие, изучение робота, обнаружение аномалии, наблюдение и обнаружение объектов. Например, онлайновое предсказание нескольких действий для входящих видео от нескольких камер может быть важно для изучения робота. Сравненный с классификацией изображений, распознавание действия с помощью видео сложно к модели из-за шумных меток в наборах видеоданных, множестве действий, которые агенты в видео могут выполнить, которые являются в большой степени неустойчивым классом, и вычислить неэффективность в предварительном обучении на больших наборах видеоданных. Некоторые методы глубокого обучения, такие как 2D поток I3D сверточные сети [1], показали улучшенную производительность путем усиления предварительного обучения на больших наборах данных классификации изображений.

Загрузка данных

Этот пример обучает сеть I3D с помощью набора данных HMDB51. Используйте downloadHMDB51 поддерживание функции, перечисленной в конце этого примера, чтобы загрузить набор данных HMDB51 на папку под названием hmdb51.

downloadFolder = fullfile(tempdir,"hmdb51");
downloadHMDB51(downloadFolder);

После того, как загрузка завершена, извлеките файл RAR hmdb51_org.rar к hmdb51 папка. Затем используйте checkForHMDB51Folder поддерживание функции, перечисленной в конце этого примера, чтобы подтвердить, что загруженные и извлеченные файлы существуют.

allClasses = checkForHMDB51Folder(downloadFolder);

Набор данных содержит приблизительно 2 Гбайт видеоданных для 7 000 клипов более чем 51 класс, такие как напиток, запуск, и обменяться рукопожатием. Каждый видеокадр имеет высоту 240 пикселей и минимальную ширину 176 пикселей. Количество систем координат лежит в диапазоне от 18 до приблизительно 1 000.

Чтобы уменьшать учебное время, этот пример обучает сеть распознавания активности, чтобы классифицировать 5 классов действия вместо всего 51 класса в наборе данных. Установите useAllData к true обучаться со всем 51 классом.

useAllData = false;

if useAllData
    classes = allClasses;
else
    classes = ["kiss","laugh","pick","pour","pushup"];
end
dataFolder = fullfile(downloadFolder, "hmdb51_org");

Разделите набор данных в набор обучающих данных для того, чтобы обучить сеть и набор тестов для оценки сети. Используйте 80% данных для набора обучающих данных и остальных для набора тестов. Используйте imageDatastore разделять данные на основе каждой метки в обучение и тестовые данные устанавливает путем случайного выбора пропорции файлов от каждой метки.

imds = imageDatastore(fullfile(dataFolder,classes),...
    'IncludeSubfolders', true,...
    'LabelSource', 'foldernames',...
    'FileExtensions', '.avi');

[trainImds,testImds] = splitEachLabel(imds,0.8,'randomized');

trainFilenames = trainImds.Files;
testFilenames  = testImds.Files;

Чтобы нормировать входные данные для сети, минимальные и максимальные значения для набора данных введены в файле MAT inputStatistics.mat, присоединенный к этому примеру. Чтобы найти минимальные и максимальные значения для различного набора данных, используйте inputStatistics поддерживание функции, перечисленной в конце этого примера.

inputStatsFilename = 'inputStatistics.mat';
if ~exist(inputStatsFilename, 'file')
    disp("Reading all the training data for input statistics...")
    inputStats = inputStatistics(dataFolder);
else
    d = load(inputStatsFilename);
    inputStats = d.inputStats;    
end

Создайте хранилища данных для того, чтобы обучить нейронные сети

Создайте два FileDatastore объекты для обучения и валидации при помощи createFileDatastore поддерживание функции, заданной в конце этого примера. Каждый datastore читает видеофайл, чтобы обеспечить данные о RGB, данные об оптическом потоке и соответствующую информацию о метке.

Задайте количество систем координат для каждого чтения datastore. Типичные значения равняются 16, 32, 64, или 128. Используя большее количество систем координат помогает получить больше временной информации, но требует большей памяти для обучения и предсказания. Определите номер систем координат к 64, чтобы сбалансировать использование памяти относительно эффективности. Вы можете должны быть понизить это значение в зависимости от своих системных ресурсов.

numFrames = 64;

Задайте высоту и ширину систем координат для datastore, чтобы читать. Фиксация высоты и ширины к тому же значению делает данные о пакетной обработке для сети легче. Типичные значения [112, 112], [224, 224], и [256, 256]. Минимальная высота и ширина видеокадров в наборе данных HMDB51 240 и 176, соответственно. Задайте [112, 112], чтобы получить большее число систем координат за счет пространственной информации. Если вы хотите задать формат кадра для datastore, чтобы считать, что больше, чем минимальные значения, такой, когда [256, 256], сначала измените размер систем координат с помощью imresize.

frameSize = [112,112];

Установите inputSize к inputStats структура так функция чтения fileDatastore может считать заданный входной размер.

inputSize = [frameSize, numFrames];
inputStats.inputSize = inputSize;
inputStats.Classes = classes;

Создайте два FileDatastore объекты, один для обучения и другого для валидации.

isDataForValidation = false;
dsTrain = createFileDatastore(trainFilenames,inputStats,isDataForValidation);

isDataForValidation = true;
dsVal = createFileDatastore(testFilenames,inputStats,isDataForValidation);

disp("Training data size: " + string(numel(dsTrain.Files)))
Training data size: 436
disp("Validation data size: " + string(numel(dsVal.Files)))
Validation data size: 109

Архитектура сети Define

Сеть I3D

Используя 3-D CNN естественный подход к извлечению пространственно-временных функций от видео. Можно создать сеть I3D из предварительно обученной 2D сети классификации изображений, такой как Начало v1 или ResNet-50 путем расширения 2D фильтров и объединения ядер в 3-D. Эта процедура снова использует веса, усвоенные из задачи классификации изображений загрузить видео задачу распознавания.

Следующая фигура является выборкой, показывающей, как раздуть 2D слой свертки к 3-D слою свертки. Инфляция включает расширение размера фильтра, весов и смещения путем добавления третьей размерности (временная размерность).

2D поток сеть I3D

Видеоданные, как может рассматриваться, имеют две части: пространственный компонент и временный компонент.

  • Пространственный компонент включает информацию о форме, структуре и цвете объектов в видео. Данные о RGB содержат эту информацию.

  • Временный компонент включает информацию о движении объектов через системы координат и изображает важные перемещения между камерой и объектами в сцене. Вычисление оптического потока является общим методом для извлечения временной информации от видео.

2D потоковый CNN включает пространственную подсеть и временную подсеть [2]. Сверточная нейронная сеть, обученная на плотном оптическом потоке и потоке видеоданных, может достигнуть лучшей эффективности с ограниченными обучающими данными, чем со сложенными системами координат RGB сырых данных. Следующий рисунок показывает типичную 2D потоковую сеть I3D.

Создайте 2D поток сеть I3D

В этом примере вы создаете использование сети I3D GoogLeNet, сеть, предварительно обученная на базе данных ImageNet.

Задайте количество каналов как 3 для подсети RGB и 2 для подсети оптического потока. Два канала для данных об оптическом потоке x и y компоненты скорости, Vx и Vy, соответственно.

rgbChannels = 3;
flowChannels = 2;

Получите минимальные и максимальные значения для RGB и данных об оптическом потоке из inputStats структура загружена от inputStatistics.mat файл. Эти значения необходимы для image3dInputLayer из сетей I3D, чтобы нормировать входные данные.

rgbInputSize = [frameSize, numFrames, rgbChannels];
flowInputSize = [frameSize, numFrames, flowChannels];

rgbMin = inputStats.rgbMin;
rgbMax = inputStats.rgbMax;
oflowMin = inputStats.oflowMin(:,:,1:2);
oflowMax = inputStats.oflowMax(:,:,1:2);

rgbMin = reshape(rgbMin,[1,size(rgbMin)]);
rgbMax = reshape(rgbMax,[1,size(rgbMax)]);
oflowMin = reshape(oflowMin,[1,size(oflowMin)]);
oflowMax = reshape(oflowMax,[1,size(oflowMax)]);

Задайте количество классов для того, чтобы обучить сеть.

numClasses = numel(classes);

Создайте I3D RGB и подсети оптического потока при помощи Inflated3D поддерживание функции, которая присоединена к этому примеру. Подсети создаются из GoogLeNet.

cnnNet = googlenet;

netRGB = Inflated3D(numClasses,rgbInputSize,rgbMin,rgbMax,cnnNet);
netFlow = Inflated3D(numClasses,flowInputSize,oflowMin,oflowMax,cnnNet);

Создайте dlnetwork объект от графика слоев каждой из сетей I3D.

dlnetRGB = dlnetwork(netRGB);
dlnetFlow = dlnetwork(netFlow);

Функция градиентов модели Define

Создайте функцию поддержки modelGradients, перечисленный в конце этого примера. modelGradients функционируйте берет в качестве входа подсеть RGB dlnetRGB, подсеть оптического потока dlnetFlow, мини-пакет входных данных dlRGB и dlFlow, и мини-пакет основной истины помечает данные dlY. Функция возвращает учебное значение потерь, градиенты потери относительно настраиваемых параметров соответствующих подсетей и мини-пакетной точности подсетей.

Потеря вычисляется путем вычисления среднего значения перекрестных энтропийных потерь предсказаний от каждой из подсетей. Выходные предсказания сети являются вероятностями между 0 и 1 для каждого из классов.

rgbLoss=crossentropy(rgbPrediction)

flowLoss=crossentropy(flowPrediction)

loss=mean([rgbLoss,flowLoss])

Точность каждой из подсетей вычисляется путем взятия среднего значения RGB и предсказаний оптического потока, и сравнения его с меткой основной истины входных параметров.

Задайте опции обучения

Обучайтесь с мини-пакетным размером 20 для 1 500 итераций. Задайте итерацию, после которой можно сохранить модель с лучшей точностью валидации при помощи SaveBestAfterIteration параметр.

Задайте отжигающие косинус параметры расписания [3] скорости обучения. Для обеих сетей используйте:

  • Минимальная скорость обучения 1e-4.

  • Максимальная скорость обучения 1e-3.

  • Количество косинуса итераций 300, 500, и 700, после которого цикл расписания скорости обучения перезапускает. Опция CosineNumIterations задает ширину каждого цикла косинуса.

Задайте параметры для оптимизации SGDM. Инициализируйте параметры оптимизации SGDM в начале обучения каждой из сетей оптического потока и RGB. Для обеих сетей используйте:

  • Импульс 0,9.

  • Начальный скоростной параметр, инициализированный как [].

  • Фактор регуляризации L2 0,0005.

Задайте, чтобы диспетчеризировать данные в фоновом режиме с помощью параллельного пула. Если DispatchInBackground установлен в истинный, откройте параллельный пул с конкретным количеством параллельных рабочих и создайте DispatchInBackgroundDatastore, если как часть этого примера, который диспетчеризирует данные в фоновом режиме, чтобы ускорить обучение с помощью асинхронной загрузки данных и предварительной обработки. По умолчанию этот пример использует графический процессор, если вы доступны. В противном случае это использует центральный процессор. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA® графический процессор. Для получения информации о поддерживаемом вычислите возможности, смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).

params.Classes = classes;
params.MiniBatchSize = 20;
params.NumIterations = 1500;
params.SaveBestAfterIteration = 900;
params.CosineNumIterations = [300, 500, 700];
params.MinLearningRate = 1e-4;
params.MaxLearningRate = 1e-3;
params.Momentum = 0.9;
params.VelocityRGB = [];
params.VelocityFlow = [];
params.L2Regularization = 0.0005;
params.ProgressPlot = false;
params.Verbose = true;
params.ValidationData = dsVal;
params.DispatchInBackground = false;
params.NumWorkers = 4;

Обучение сети

Обучите подсети с помощью потока данных RGB и данных об оптическом потоке. Установите doTraining переменная к false загружать предварительно обученные подсети, не имея необходимость ожидать обучения завершиться. В качестве альтернативы, если вы хотите обучить подсети, установите doTraining переменная к true.

doTraining = false;

В течение каждой эпохи:

  • Переставьте данные перед цикличным выполнением по мини-пакетам данных.

  • Используйте minibatchqueue циклично выполняться по мини-пакетам. Функция поддержки createMiniBatchQueue, перечисленный в конце этого примера, использует данный учебный datastore, чтобы создать minibatchqueue.

  • Используйте данные о валидации dsVal проверять сети.

  • Отобразите результаты потери и точности за каждую эпоху с помощью функции поддержки displayVerboseOutputEveryEpoch, перечисленный в конце этого примера.

Для каждого мини-пакета:

  • Преобразуйте данные изображения или данные об оптическом потоке и метки к dlarray объекты с базовым одним типом.

  • Обработайте временную размерность данные о видео и оптическом потоке как одна из пространственных размерностей, чтобы позволить обработать использование 3-D CNN. Укажите, что размерность маркирует "SSSCB" (пространственный, пространственный, пространственный, канал, пакет) для RGB или данных об оптическом потоке и "CB" для данных о метке.

minibatchqueue возразите использует функцию поддержки batchRGBAndFlow, перечисленный в конце этого примера, чтобы обработать в пакетном режиме RGB и данные об оптическом потоке.

modelFilename = "I3D-RGBFlow-" + numClasses + "Classes-hmdb51.mat";
if doTraining 
    epoch = 1;
    bestValAccuracy = 0;
    accTrain = [];
    accTrainRGB = [];
    accTrainFlow = [];
    lossTrain = [];
        
    iteration = 1;
    shuffled = shuffleTrainDs(dsTrain);
    
    % Number of outputs is three: One for RGB frames, one for optical flow
    % data, and one for ground truth labels.
    numOutputs = 3;
    mbq = createMiniBatchQueue(shuffled, numOutputs, params);
    start = tic;
    trainTime = start;
    
    % Use the initializeTrainingProgressPlot and initializeVerboseOutput
    % supporting functions, listed at the end of the example, to initialize
    % the training progress plot and verbose output to display the training
    % loss, training accuracy, and validation accuracy.
    plotters = initializeTrainingProgressPlot(params);
    initializeVerboseOutput(params);
    
    while iteration <= params.NumIterations

        % Iterate through the data set.
        [dlX1,dlX2,dlY] = next(mbq);

        % Evaluate the model gradients and loss using dlfeval.
        [gradRGB,gradFlow,loss,acc,accRGB,accFlow,stateRGB,stateFlow] = ...
            dlfeval(@modelGradients,dlnetRGB,dlnetFlow,dlX1,dlX2,dlY);
        
        % Accumulate the loss and accuracies.
        lossTrain = [lossTrain, loss];
        accTrain = [accTrain, acc];
        accTrainRGB = [accTrainRGB, accRGB];
        accTrainFlow = [accTrainFlow, accFlow];
        % Update the network state.
        dlnetRGB.State = stateRGB;
        dlnetFlow.State = stateFlow;
        
        % Update the gradients and parameters for the RGB and optical flow
        % subnetworks using the SGDM optimizer.
        [dlnetRGB,gradRGB,params.VelocityRGB,learnRate] = ...
            updateDlNetwork(dlnetRGB,gradRGB,params,params.VelocityRGB,iteration);
        [dlnetFlow,gradFlow,params.VelocityFlow] = ...
            updateDlNetwork(dlnetFlow,gradFlow,params,params.VelocityFlow,iteration);
        
        if ~hasdata(mbq) || iteration == params.NumIterations
            % Current epoch is complete. Do validation and update progress.
            trainTime = toc(trainTime);

            [validationTime,cmat,lossValidation,accValidation,accValidationRGB,accValidationFlow] = ...
                doValidation(params, dlnetRGB, dlnetFlow);

            % Update the training progress.
            displayVerboseOutputEveryEpoch(params,start,learnRate,epoch,iteration,...
                mean(accTrain),mean(accTrainRGB),mean(accTrainFlow),...
                accValidation,accValidationRGB,accValidationFlow,...
                mean(lossTrain),lossValidation,trainTime,validationTime);
            updateProgressPlot(params,plotters,epoch,iteration,start,mean(lossTrain),mean(accTrain),accValidation);
            
            % Save model with the trained dlnetwork and accuracy values.
            % Use the saveData supporting function, listed at the
            % end of this example.
            if iteration >= params.SaveBestAfterIteration
                if accValidation > bestValAccuracy
                    bestValAccuracy = accValidation;
                    saveData(modelFilename, dlnetRGB, dlnetFlow, cmat, accValidation);
                end
            end
        end
        
        if ~hasdata(mbq) && iteration < params.NumIterations
            % Current epoch is complete. Initialize the training loss, accuracy
            % values, and minibatchqueue for the next epoch.
            accTrain = [];
            accTrainRGB = [];
            accTrainFlow = [];
            lossTrain = [];
        
            trainTime = tic;
            epoch = epoch + 1;
            shuffled = shuffleTrainDs(dsTrain);
            numOutputs = 3;
            mbq = createMiniBatchQueue(shuffled, numOutputs, params);
            
        end 
        
        iteration = iteration + 1;
    end
    
    % Display a message when training is complete.
    endVerboseOutput(params);
    
    disp("Model saved to: " + modelFilename);
end

% Download the pretrained model and video file for prediction.
filename = "activityRecognition-I3D-HMDB51.zip";
downloadURL = "https://ssd.mathworks.com/supportfiles/vision/data/" + filename;

filename = fullfile(downloadFolder,filename);
if ~exist(filename,'file')
    disp('Downloading the pretrained network...');
    websave(filename,downloadURL);
end
% Unzip the contents to the download folder.
unzip(filename,downloadFolder);
if ~doTraining
    modelFilename = fullfile(downloadFolder, modelFilename);
end

Оцените обученную сеть

Используйте набор тестовых данных, чтобы оценить точность обученных подсетей.

Загрузите лучшую модель, сохраненную во время обучения.

d = load(modelFilename);
dlnetRGB = d.data.dlnetRGB;
dlnetFlow = d.data.dlnetFlow;

Создайте minibatchqueue возразите, чтобы загрузить пакеты тестовых данных.

numOutputs = 3;
mbq = createMiniBatchQueue(params.ValidationData, numOutputs, params);

Для каждого пакета тестовых данных сделайте предсказания с помощью RGB и сетей оптического потока, возьмите среднее значение предсказаний и вычислите точность предсказания с помощью матрицы беспорядка.

cmat = sparse(numClasses,numClasses);
while hasdata(mbq)
    [dlRGB, dlFlow, dlY] = next(mbq);
    
    % Pass the video input as RGB and optical flow data through the
    % two-stream subnetworks to get the separate predictions.
    dlYPredRGB = predict(dlnetRGB,dlRGB);
    dlYPredFlow = predict(dlnetFlow,dlFlow);

    % Fuse the predictions by calculating the average of the predictions.
    dlYPred = (dlYPredRGB + dlYPredFlow)/2;
    
    % Calculate the accuracy of the predictions.
    [~,YTest] = max(dlY,[],1);
    [~,YPred] = max(dlYPred,[],1);

    cmat = aggregateConfusionMetric(cmat,YTest,YPred);
end

Вычислите среднюю точность классификации для обучившего нейронные сети.

accuracyEval = sum(diag(cmat))./sum(cmat,"all")
accuracyEval = 
      0.60909

Отобразите матрицу неточностей.

figure
chart = confusionchart(cmat,classes);

Из-за ограниченного количества обучающих выборок, увеличивая точность вне 61% сложно. Чтобы улучшить робастность сети, дополнительное обучение с большим набором данных требуется. Кроме того, предварительное обучение на большем наборе данных, таком как кинетика [1], может помочь улучшить результаты.

Предскажите Используя новое видео

Можно теперь использовать обучивший нейронные сети, чтобы предсказать действия в новых видео. Считайте и отобразите видео pour.avi использование VideoReader и vision.VideoPlayer.

videoFilename = fullfile(downloadFolder, "pour.avi");

videoReader = VideoReader(videoFilename);
videoPlayer = vision.VideoPlayer;
videoPlayer.Name = "pour";

while hasFrame(videoReader)
   frame = readFrame(videoReader);
   step(videoPlayer,frame);
end
release(videoPlayer);

Используйте readRGBAndFlow поддерживание функции, перечисленной в конце этого примера, чтобы считать RGB и данные об оптическом потоке.

isDataForValidation = true;
readFcn = @(f,u)readRGBAndFlow(f,u,inputStats,isDataForValidation);

Функция чтения возвращает логический isDone значение, которое указывает, существует ли больше данных, чтобы читать из файла. Используйте batchRGBAndFlow поддерживание функции, заданной в конце этого примера, чтобы обработать данные в пакетном режиме, чтобы пройти через 2D потоковые подсети, чтобы получить предсказания.

hasdata = true;
userdata = [];
YPred = [];
while hasdata
    [data,userdata,isDone] = readFcn(videoFilename,userdata);
    
    [dlRGB, dlFlow] = batchRGBAndFlow(data(:,1),data(:,2),data(:,3));
    
    % Pass video input as RGB and optical flow data through the two-stream
    % subnetworks to get the separate predictions.
    dlYPredRGB = predict(dlnetRGB,dlRGB);
    dlYPredFlow = predict(dlnetFlow,dlFlow);

    % Fuse the predictions by calculating the average of the predictions.
    dlYPred = (dlYPredRGB + dlYPredFlow)/2;
    [~,YPredCurr] = max(dlYPred,[],1);
    YPred = horzcat(YPred,YPredCurr);
    hasdata = ~isDone;
end
YPred = extractdata(YPred);

Считайте количество правильных предсказаний с помощью histcounts, и получите предсказанное действие с помощью максимального количества правильных предсказаний.

classes = params.Classes;
counts = histcounts(YPred,1:numel(classes));
[~,clsIdx] = max(counts);
action = classes(clsIdx)
action = 
"pour"

Вспомогательные Функции

inputStatistics

inputStatistics функционируйте берет в качестве входа имя папки, содержащей данные HMDB51, и вычисляет минимальные и максимальные значения для данных о RGB и данных об оптическом потоке. Минимальные и максимальные значения используются в качестве входных параметров нормализации к входному слою сетей. Эта функция также получает количество систем координат в каждом из видеофайлов, чтобы использовать позже во время обучения и тестирования сети. Для того, чтобы найти минимальные и максимальные значения для различного набора данных, используйте эту функцию с именем папки, содержащим набор данных.

function inputStats = inputStatistics(dataFolder)
    ds = createDatastore(dataFolder);
    ds.ReadFcn = @getMinMax;

    tic;
    tt = tall(ds);
    varnames = {'rgbMax','rgbMin','oflowMax','oflowMin'};
    stats = gather(groupsummary(tt,[],{'max','min'}, varnames));
    inputStats.Filename = gather(tt.Filename);
    inputStats.NumFrames = gather(tt.NumFrames);
    inputStats.rgbMax = stats.max_rgbMax;
    inputStats.rgbMin = stats.min_rgbMin;
    inputStats.oflowMax = stats.max_oflowMax;
    inputStats.oflowMin = stats.min_oflowMin;
    save('inputStatistics.mat','inputStats');
    toc;
end

function data = getMinMax(filename)
    reader = VideoReader(filename);
    opticFlow = opticalFlowFarneback;
    data = [];
    while hasFrame(reader)
        frame = readFrame(reader);
        [rgb,oflow] = findMinMax(frame,opticFlow);
        data = assignMinMax(data, rgb, oflow);
    end

    totalFrames = floor(reader.Duration * reader.FrameRate);
    totalFrames = min(totalFrames, reader.NumFrames);
    
    [labelName, filename] = getLabelFilename(filename);
    data.Filename = fullfile(labelName, filename);
    data.NumFrames = totalFrames;

    data = struct2table(data,'AsArray',true);
end

function data = assignMinMax(data, rgb, oflow)
    if isempty(data)
        data.rgbMax = rgb.Max;
        data.rgbMin = rgb.Min;
        data.oflowMax = oflow.Max;
        data.oflowMin = oflow.Min;
        return;
    end
    data.rgbMax = max(data.rgbMax, rgb.Max);
    data.rgbMin = min(data.rgbMin, rgb.Min);

    data.oflowMax = max(data.oflowMax, oflow.Max);
    data.oflowMin = min(data.oflowMin, oflow.Min);
end

function [rgbMinMax,oflowMinMax] = findMinMax(rgb, opticFlow)
    rgbMinMax.Max = max(rgb,[],[1,2]);
    rgbMinMax.Min = min(rgb,[],[1,2]);

    gray = rgb2gray(rgb);
    flow = estimateFlow(opticFlow,gray);
    oflow = cat(3,flow.Vx,flow.Vy,flow.Magnitude);

    oflowMinMax.Max = max(oflow,[],[1,2]);
    oflowMinMax.Min = min(oflow,[],[1,2]);
end

function ds = createDatastore(folder)    
    ds = fileDatastore(folder,...
        'IncludeSubfolders', true,...
        'FileExtensions', '.avi',...
        'UniformRead', true,...
        'ReadFcn', @getMinMax);
    disp("NumFiles: " + numel(ds.Files));
end

createFileDatastore

createFileDatastore функция создает FileDatastore объект с помощью данных имен файлов. FileDatastore объект считывает данные в 'partialfile' режим, таким образом, каждое чтение может возвратить частично системы координат чтения в видео. Эта функция помогает с чтением больших видеофайлов, если все системы координат не умещаются в памяти.

function datastore = createFileDatastore(filenames,inputStats,isDataForValidation)
    readFcn = @(f,u)readRGBAndFlow(f,u,inputStats,isDataForValidation);
    datastore = fileDatastore(filenames,...
        'ReadFcn',readFcn,...
        'ReadMode','partialfile');
end

readRGBAndFlow

readRGBAndFlow функционируйте системы координат RGB чтений, соответствующие данные об оптическом потоке и значения метки для данного видеофайла. Во время обучения функция чтения читает определенное количество систем координат согласно сетевому входному размеру со случайным образом выбранной стартовой системой координат. Данные об оптическом потоке вычислены с начала видеофайла, но пропущены, пока стартовая система координат не достигнута. Во время тестирования последовательно читаются все системы координат, и соответствующие данные об оптическом потоке вычисляются. Системы координат RGB и данные об оптическом потоке случайным образом обрезаются к необходимому сетевому входному размеру для обучения и центру, обрезанному для тестирования и валидации.

function [data,userdata,done] = readRGBAndFlow(filename,userdata,inputStats,isDataForValidation)
    if isempty(userdata)
        userdata.reader      = VideoReader(filename);
        userdata.batchesRead = 0;
        userdata.opticalFlow = opticalFlowFarneback;
        
        [totalFrames,userdata.label] = getTotalFramesAndLabel(inputStats,filename);
        if isempty(totalFrames)
            totalFrames = floor(userdata.reader.Duration * userdata.reader.FrameRate);
            totalFrames = min(totalFrames, userdata.reader.NumFrames);
        end
        userdata.totalFrames = totalFrames;
    end
    reader      = userdata.reader;
    totalFrames = userdata.totalFrames;
    label       = userdata.label;
    batchesRead = userdata.batchesRead;
    opticalFlow = userdata.opticalFlow;

    inputSize = inputStats.inputSize;
    H = inputSize(1);
    W = inputSize(2);
    rgbC = 3;
    flowC = 2;
    numFrames = inputSize(3);

    if numFrames > totalFrames
        numBatches = 1;
    else
        numBatches = floor(totalFrames/numFrames);
    end

    imH = userdata.reader.Height;
    imW = userdata.reader.Width;
    imsz = [imH,imW];

    if ~isDataForValidation

        augmentFcn = augmentTransform([imsz,3]);
        cropWindow = randomCropWindow2d(imsz, inputSize(1:2));
        %  1. Randomly select required number of frames,
        %     starting randomly at a specific frame.
        if numFrames >= totalFrames
            idx = 1:totalFrames;
            % Add more frames to fill in the network input size.
            additional = ceil(numFrames/totalFrames);
            idx = repmat(idx,1,additional);
            idx = idx(1:numFrames);
        else
            startIdx = randperm(totalFrames - numFrames);
            startIdx = startIdx(1);
            endIdx = startIdx + numFrames - 1;
            idx = startIdx:endIdx;
        end

        video = zeros(H,W,rgbC,numFrames);
        oflow = zeros(H,W,flowC,numFrames);
        i = 1;
        % Discard the first set of frames to initialize the optical flow.
        for ii = 1:idx(1)-1
            frame = read(reader,ii);
            getRGBAndFlow(frame,opticalFlow,augmentFcn,cropWindow);
        end
        % Read the next set of required number of frames for training.
        for ii = idx
            frame = read(reader,ii);
            [rgb,vxvy] = getRGBAndFlow(frame,opticalFlow,augmentFcn,cropWindow);
            video(:,:,:,i) = rgb;
            oflow(:,:,:,i) = vxvy;
            i = i + 1;
        end
    else
        augmentFcn = @(data)(data);
        cropWindow = centerCropWindow2d(imsz, inputSize(1:2));
        toRead = min([numFrames,totalFrames]);
        video = zeros(H,W,rgbC,toRead);
        oflow = zeros(H,W,flowC,toRead);
        i = 1;
        while hasFrame(reader) && i <= numFrames
            frame = readFrame(reader);
            [rgb,vxvy] = getRGBAndFlow(frame,opticalFlow,augmentFcn,cropWindow);
            video(:,:,:,i) = rgb;
            oflow(:,:,:,i) = vxvy;
            i = i + 1;
        end
        if numFrames > totalFrames
            additional = ceil(numFrames/totalFrames);
            video = repmat(video,1,1,1,additional);
            oflow = repmat(oflow,1,1,1,additional);
            video = video(:,:,:,1:numFrames);
            oflow = oflow(:,:,:,1:numFrames);
        end
    end

    % The network expects the video and optical flow input in 
    % the following dlarray format: 
    % "SSSCB" ==> Height x Width x Frames x Channels x Batch
    %
    % Permute the data 
    %  from
    %      Height x Width x Channels x Frames
    %  to 
    %      Height x Width x Frames x Channels
    video = permute(video, [1,2,4,3]);
    oflow = permute(oflow, [1,2,4,3]);

    data = {video, oflow, label};

    batchesRead = batchesRead + 1;

    userdata.batchesRead = batchesRead;

    % Set the done flag to true, if the reader has read all the frames or
    % if it is training.
    done = batchesRead == numBatches || ~isDataForValidation;
end

function [rgb,vxvy] = getRGBAndFlow(rgb,opticalFlow,augmentFcn,cropWindow)
    rgb = augmentFcn(rgb);
    gray = rgb2gray(rgb);
    flow = estimateFlow(opticalFlow,gray);
    vxvy = cat(3,flow.Vx,flow.Vy,flow.Vy);

    rgb = imcrop(rgb, cropWindow);
    vxvy = imcrop(vxvy, cropWindow);
    vxvy = vxvy(:,:,1:2);
end

function [label,fname] = getLabelFilename(filename)
    [folder,name,ext] = fileparts(string(filename));
    [~,label] = fileparts(folder);
    fname = name + ext;
    label = string(label);
    fname = string(fname);
end

function [totalFrames,label] = getTotalFramesAndLabel(info, filename)
    filenames = info.Filename;
    frames = info.NumFrames;
    [labelName, fname] = getLabelFilename(filename);
    idx = strcmp(filenames, fullfile(labelName,fname));
    totalFrames = frames(idx);    
    label = categorical(string(labelName), string(info.Classes));
end

augmentTransform

augmentTransform функция создает метод увеличения со случайным зеркальным отражением лево-права и масштабными коэффициентами.

function augmentFcn = augmentTransform(sz)
% Randomly flip and scale the image.
tform = randomAffine2d('XReflection',true,'Scale',[1 1.1]);
rout = affineOutputView(sz,tform,'BoundsStyle','CenterOutput');

augmentFcn = @(data)augmentData(data,tform,rout);

    function data = augmentData(data,tform,rout)
        data = imwarp(data,tform,'OutputView',rout);
    end
end

modelGradients

modelGradients функционируйте берет в качестве входа мини-пакет данных о RGB dlRGB, соответствующие данные об оптическом потоке dlFlow, и соответствующий целевой dlY, и возвращает соответствующую потерю, градиенты потери относительно настраиваемых параметров и учебную точность. Чтобы вычислить градиенты, оцените modelGradients функция с помощью dlfeval функция в учебном цикле.

function [gradientsRGB,gradientsFlow,loss,acc,accRGB,accFlow,stateRGB,stateFlow] = modelGradients(dlnetRGB,dlnetFlow,dlRGB,dlFlow,Y)

% Pass video input as RGB and optical flow data through the two-stream
% network.
[dlYPredRGB,stateRGB] = forward(dlnetRGB,dlRGB);
[dlYPredFlow,stateFlow] = forward(dlnetFlow,dlFlow);

% Calculate fused loss, gradients, and accuracy for the two-stream
% predictions.
rgbLoss = crossentropy(dlYPredRGB,Y);
flowLoss = crossentropy(dlYPredFlow,Y);
% Fuse the losses.
loss = mean([rgbLoss,flowLoss]);

gradientsRGB = dlgradient(loss,dlnetRGB.Learnables);
gradientsFlow = dlgradient(loss,dlnetFlow.Learnables);

% Fuse the predictions by calculating the average of the predictions.
dlYPred = (dlYPredRGB + dlYPredFlow)/2;

% Calculate the accuracy of the predictions.
[~,YTest] = max(Y,[],1);
[~,YPred] = max(dlYPred,[],1);

acc = gather(extractdata(sum(YTest == YPred)./numel(YTest)));

% Calculate the accuracy of the RGB and flow predictions.
[~,YTest] = max(Y,[],1);
[~,YPredRGB] = max(dlYPredRGB,[],1);
[~,YPredFlow] = max(dlYPredFlow,[],1);

accRGB = gather(extractdata(sum(YTest == YPredRGB)./numel(YTest)));
accFlow = gather(extractdata(sum(YTest == YPredFlow)./numel(YTest)));
end

doValidation

doValidation функция проверяет сеть с помощью данных о валидации.

function [validationTime, cmat, lossValidation, accValidation, accValidationRGB, accValidationFlow] = doValidation(params, dlnetRGB, dlnetFlow)

    validationTime = tic;

    numOutputs = 3;
    mbq = createMiniBatchQueue(params.ValidationData, numOutputs, params);

    lossValidation = [];
    numClasses = numel(params.Classes);
    cmat = sparse(numClasses,numClasses);
    cmatRGB = sparse(numClasses,numClasses);
    cmatFlow = sparse(numClasses,numClasses);
    while hasdata(mbq)

        [dlX1,dlX2,dlY] = next(mbq);

        [loss,YTest,YPred,YPredRGB,YPredFlow] = predictValidation(dlnetRGB,dlnetFlow,dlX1,dlX2,dlY);

        lossValidation = [lossValidation,loss];
        cmat = aggregateConfusionMetric(cmat,YTest,YPred);
        cmatRGB = aggregateConfusionMetric(cmatRGB,YTest,YPredRGB);
        cmatFlow = aggregateConfusionMetric(cmatFlow,YTest,YPredFlow);
    end
    lossValidation = mean(lossValidation);
    accValidation = sum(diag(cmat))./sum(cmat,"all");
    accValidationRGB = sum(diag(cmatRGB))./sum(cmatRGB,"all");
    accValidationFlow = sum(diag(cmatFlow))./sum(cmatFlow,"all");

    validationTime = toc(validationTime);
end

predictValidation

predictValidation функция вычисляет потерю и значения предсказания с помощью обеспеченного dlnetwork объекты для RGB и данных об оптическом потоке.

function [loss,YTest,YPred,YPredRGB,YPredFlow] = predictValidation(dlnetRGB,dlnetFlow,dlRGB,dlFlow,Y)

% Pass the video input through the two-stream
% network.
dlYPredRGB = predict(dlnetRGB,dlRGB);
dlYPredFlow = predict(dlnetFlow,dlFlow);

% Calculate the cross-entropy separately for the two-stream
% outputs.
rgbLoss = crossentropy(dlYPredRGB,Y);
flowLoss = crossentropy(dlYPredFlow,Y);

% Fuse the losses.
loss = mean([rgbLoss,flowLoss]);

% Fuse the predictions by calculating the average of the predictions.
dlYPred = (dlYPredRGB + dlYPredFlow)/2;

% Calculate the accuracy of the predictions.
[~,YTest] = max(Y,[],1);
[~,YPred] = max(dlYPred,[],1);

[~,YPredRGB] = max(dlYPredRGB,[],1);
[~,YPredFlow] = max(dlYPredFlow,[],1);

end

updateDlnetwork

updateDlnetwork функционируйте обновляет обеспеченный dlnetwork объект с градиентами и другими параметрами с помощью оптимизационной функции SGDM sgdmupdate.

function [dlnet,gradients,velocity,learnRate] = updateDlNetwork(dlnet,gradients,params,velocity,iteration)
    % Determine the learning rate using the cosine-annealing learning rate schedule.
    learnRate = cosineAnnealingLearnRate(iteration, params);

    % Apply L2 regularization to the weights.
    idx = dlnet.Learnables.Parameter == "Weights";
    gradients(idx,:) = dlupdate(@(g,w) g + params.L2Regularization*w, gradients(idx,:), dlnet.Learnables(idx,:));

    % Update the network parameters using the SGDM optimizer.
    [dlnet, velocity] = sgdmupdate(dlnet, gradients, velocity, learnRate, params.Momentum);
end

cosineAnnealingLearnRate

cosineAnnealingLearnRate функция вычисляет скорость обучения на основе текущего номера итерации, минимальную скорость обучения, максимальную скорость обучения и количество итераций для отжига [3].

function lr = cosineAnnealingLearnRate(iteration, params)
    if iteration == params.NumIterations
        lr = params.MinLearningRate;
        return;
    end
    cosineNumIter = [0, params.CosineNumIterations];
    csum = cumsum(cosineNumIter);
    block = find(csum >= iteration, 1,'first');
    cosineIter = iteration - csum(block - 1);
    annealingIteration = mod(cosineIter, cosineNumIter(block));
    cosineIteration = cosineNumIter(block);
    minR = params.MinLearningRate;
    maxR = params.MaxLearningRate;
    cosMult = 1 + cos(pi * annealingIteration / cosineIteration);
    lr = minR + ((maxR - minR) *  cosMult / 2);
end

aggregateConfusionMetric

aggregateConfusionMetric функционируйте инкрементно заполняет матрицу беспорядка на основе предсказанных результатов YPred и ожидаемые результаты YTest.

function cmat = aggregateConfusionMetric(cmat,YTest,YPred)
YTest = gather(extractdata(YTest));
YPred = gather(extractdata(YPred));
[m,n] = size(cmat);
cmat = cmat + full(sparse(YTest,YPred,1,m,n));
end

createMiniBatchQueue

createMiniBatchQueue функция создает minibatchqueue объект, который обеспечивает miniBatchSize объем данных от данного datastore. Это также создает DispatchInBackgroundDatastore если параллельный пул открыт.

function mbq = createMiniBatchQueue(datastore, numOutputs, params)
if params.DispatchInBackground && isempty(gcp('nocreate'))
    % Start a parallel pool, if DispatchInBackground is true, to dispatch
    % data in the background using the parallel pool.
    c = parcluster('local');
    c.NumWorkers = params.NumWorkers;
    parpool('local',params.NumWorkers);
end
p = gcp('nocreate');
if ~isempty(p)
    datastore = DispatchInBackgroundDatastore(datastore, p.NumWorkers);
end
inputFormat(1:numOutputs-1) = "SSSCB";
outputFormat = "CB";
mbq = minibatchqueue(datastore, numOutputs, ...
    "MiniBatchSize", params.MiniBatchSize, ...
    "MiniBatchFcn", @batchRGBAndFlow, ...
    "MiniBatchFormat", [inputFormat,outputFormat]);
end

batchRGBAndFlow

batchRGBAndFlow функционируйте обрабатывает в пакетном режиме изображение, поток и данные о метке в соответствующий dlarray значения в форматах данных "SSSCB", "SSSCB", и "CB", соответственно.

function [dlX1,dlX2,dlY] = batchRGBAndFlow(images, flows, labels)
% Batch dimension: 5
X1 = cat(5,images{:});
X2 = cat(5,flows{:});

% Batch dimension: 2
labels = cat(2,labels{:});

% Feature dimension: 1
Y = onehotencode(labels,1);

% Cast data to single for processing.
X1 = single(X1);
X2 = single(X2);
Y = single(Y);

% Move data to the GPU if possible.
if canUseGPU
    X1 = gpuArray(X1);
    X2 = gpuArray(X2);
    Y = gpuArray(Y);
end

% Return X and Y as dlarray objects.
dlX1 = dlarray(X1,"SSSCB");
dlX2 = dlarray(X2,"SSSCB");
dlY = dlarray(Y,"CB");
end

shuffleTrainDs

shuffleTrainDs функционируйте переставляет файлы, существующие в учебном datastore dsTrain.

function shuffled = shuffleTrainDs(dsTrain)
shuffled = copy(dsTrain);
n = numel(shuffled.Files);
shuffledIndices = randperm(n);
shuffled.Files = shuffled.Files(shuffledIndices);
reset(shuffled);
end

saveData

saveData функция сохраняет данный dlnetwork объекты и значения точности к файлу MAT.

function saveData(modelFilename, dlnetRGB, dlnetFlow, cmat, accValidation)
dlnetRGB = gatherFromGPUToSave(dlnetRGB);
dlnetFlow = gatherFromGPUToSave(dlnetFlow);
data.ValidationAccuracy = accValidation;
data.cmat = cmat;
data.dlnetRGB = dlnetRGB;
data.dlnetFlow = dlnetFlow;
save(modelFilename, 'data');
end

gatherFromGPUToSave

gatherFromGPUToSave функция собирает данные из графического процессора для того, чтобы сохранить модель на диск.

function dlnet = gatherFromGPUToSave(dlnet)
if ~canUseGPU
    return;
end
dlnet.Learnables = gatherValues(dlnet.Learnables);
dlnet.State = gatherValues(dlnet.State);
    function tbl = gatherValues(tbl)
        for ii = 1:height(tbl)
            tbl.Value{ii} = gather(tbl.Value{ii});
        end
    end
end

checkForHMDB51Folder

checkForHMDB51Folder функционируйте проверки на загруженные данные в папке загрузки.

function classes = checkForHMDB51Folder(dataLoc)
hmdbFolder = fullfile(dataLoc, "hmdb51_org");
if ~exist(hmdbFolder, "dir")
    error("Download 'hmdb51_org.rar' file using the supporting function 'downloadHMDB51' before running the example and extract the RAR file.");    
end

classes = ["brush_hair","cartwheel","catch","chew","clap","climb","climb_stairs",...
    "dive","draw_sword","dribble","drink","eat","fall_floor","fencing",...
    "flic_flac","golf","handstand","hit","hug","jump","kick","kick_ball",...
    "kiss","laugh","pick","pour","pullup","punch","push","pushup","ride_bike",...
    "ride_horse","run","shake_hands","shoot_ball","shoot_bow","shoot_gun",...
    "sit","situp","smile","smoke","somersault","stand","swing_baseball","sword",...
    "sword_exercise","talk","throw","turn","walk","wave"];
expectFolders = fullfile(hmdbFolder, classes);
if ~all(arrayfun(@(x)exist(x,'dir'),expectFolders))
    error("Download hmdb51_org.rar using the supporting function 'downloadHMDB51' before running the example and extract the RAR file.");
end
end

downloadHMDB51

downloadHMDB51 функционируйте загружает набор данных и сохраняет его в директорию.

function downloadHMDB51(dataLoc)

if nargin == 0
    dataLoc = pwd;
end
dataLoc = string(dataLoc);

if ~exist(dataLoc,"dir")
    mkdir(dataLoc);
end

dataUrl     = "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar";
options     = weboptions('Timeout', Inf);
rarFileName = fullfile(dataLoc, 'hmdb51_org.rar');
fileExists  = exist(rarFileName, 'file');

% Download the RAR file and save it to the download folder.
if ~fileExists
    disp("Downloading hmdb51_org.rar (2 GB) to the folder:")
    disp(dataLoc)
    disp("This download can take a few minutes...") 
    websave(rarFileName, dataUrl, options); 
    disp("Download complete.")
    disp("Extract the hmdb51_org.rar file contents to the folder: ") 
    disp(dataLoc)
end
end

initializeTrainingProgressPlot

initializeTrainingProgressPlot функция конфигурирует два графика для отображения учебной потери, учебной точности и точности валидации.

function plotters = initializeTrainingProgressPlot(params)
if params.ProgressPlot
    % Plot the loss, training accuracy, and validation accuracy.
    figure
    
    % Loss plot
    subplot(2,1,1)
    plotters.LossPlotter = animatedline;
    xlabel("Iteration")
    ylabel("Loss")
    
    % Accuracy plot
    subplot(2,1,2)
    plotters.TrainAccPlotter = animatedline('Color','b');
    plotters.ValAccPlotter = animatedline('Color','g');
    legend('Training Accuracy','Validation Accuracy','Location','northwest');
    xlabel("Iteration")
    ylabel("Accuracy")
else
    plotters = [];
end
end

initializeVerboseOutput

initializeVerboseOutput функционируйте отображает заголовки столбцов для таблицы учебных значений, которая показывает эпоху, мини-пакетную точность и другие учебные значения.

function initializeVerboseOutput(params)
if params.Verbose
    disp(" ")
    if canUseGPU
        disp("Training on GPU.")
    else
        disp("Training on CPU.")
    end
    p = gcp('nocreate');
    if ~isempty(p)
        disp("Training on parallel cluster '" + p.Cluster.Profile + "'. ")
    end
    disp("NumIterations:" + string(params.NumIterations));
    disp("MiniBatchSize:" + string(params.MiniBatchSize));
    disp("Classes:" + join(string(params.Classes), ","));    
    disp("|=======================================================================================================================================================================|")
    disp("| Epoch | Iteration | Time Elapsed |     Mini-Batch Accuracy    |    Validation Accuracy     | Mini-Batch | Validation |  Base Learning  | Train Time | Validation Time |")
    disp("|       |           |  (hh:mm:ss)  |       (Avg:RGB:Flow)       |       (Avg:RGB:Flow)       |    Loss    |    Loss    |      Rate       | (hh:mm:ss) |   (hh:mm:ss)    |")
    disp("|=======================================================================================================================================================================|")
end
end

displayVerboseOutputEveryEpoch

displayVerboseOutputEveryEpoch функция отображает многословный вывод учебных значений, таких как эпоха, мини-пакетная точность, точность валидации и мини-пакетная потеря.

function displayVerboseOutputEveryEpoch(params,start,learnRate,epoch,iteration,...
        accTrain,accTrainRGB,accTrainFlow,accValidation,accValidationRGB,accValidationFlow,lossTrain,lossValidation,trainTime,validationTime)
    if params.Verbose
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        trainTime = duration(0,0,trainTime,'Format','hh:mm:ss');
        validationTime = duration(0,0,validationTime,'Format','hh:mm:ss');

        lossValidation = gather(extractdata(lossValidation));
        lossValidation = compose('%.4f',lossValidation);

        accValidation = composePadAccuracy(accValidation);
        accValidationRGB = composePadAccuracy(accValidationRGB);
        accValidationFlow = composePadAccuracy(accValidationFlow);

        accVal = join([accValidation,accValidationRGB,accValidationFlow], " : ");

        lossTrain = gather(extractdata(lossTrain));
        lossTrain = compose('%.4f',lossTrain);

        accTrain = composePadAccuracy(accTrain);
        accTrainRGB = composePadAccuracy(accTrainRGB);
        accTrainFlow = composePadAccuracy(accTrainFlow);

        accTrain = join([accTrain,accTrainRGB,accTrainFlow], " : ");
        learnRate = compose('%.13f',learnRate);

        disp("| " + ...
            pad(string(epoch),5,'both') + " | " + ...
            pad(string(iteration),9,'both') + " | " + ...
            pad(string(D),12,'both') + " | " + ...
            pad(string(accTrain),26,'both') + " | " + ...
            pad(string(accVal),26,'both') + " | " + ...
            pad(string(lossTrain),10,'both') + " | " + ...
            pad(string(lossValidation),10,'both') + " | " + ...
            pad(string(learnRate),13,'both') + " | " + ...
            pad(string(trainTime),10,'both') + " | " + ...
            pad(string(validationTime),15,'both') + " |")
    end

end

function acc = composePadAccuracy(acc)
    acc = compose('%.2f',acc*100) + "%";
    acc = pad(string(acc),6,'left');
end

endVerboseOutput

endVerboseOutput функционируйте отображает конец многословного выхода во время обучения.

function endVerboseOutput(params)
if params.Verbose
    disp("|=======================================================================================================================================================================|")        
end
end

updateProgressPlot

updateProgressPlot функционируйте обновляет график прогресса с информацией о потере и точности во время обучения.

function updateProgressPlot(params,plotters,epoch,iteration,start,lossTrain,accuracyTrain,accuracyValidation)
if params.ProgressPlot
    
    % Update the training progress.
    D = duration(0,0,toc(start),"Format","hh:mm:ss");
    title(plotters.LossPlotter.Parent,"Epoch: " + epoch + ", Elapsed: " + string(D));
    addpoints(plotters.LossPlotter,iteration,double(gather(extractdata(lossTrain))));
    addpoints(plotters.TrainAccPlotter,iteration,accuracyTrain);
    addpoints(plotters.ValAccPlotter,iteration,accuracyValidation);
    drawnow
end
end

Ссылки

[1] Carreira, Жоао и Эндрю Зиссермен. "Quo Vadis, распознавание действия? Новая модель и набор данных кинетики". Продолжения конференции по IEEE по компьютерному зрению и распознаванию образов (CVPR): 6299?? 6308. Гонолулу, HI: IEEE, 2017.

[2] Симонян, Карен и Эндрю Зиссермен. "2D поток сверточные сети для распознавания действия в видео". Усовершенствования в нейронных системах обработки информации 27, Лонг-Бич, CA: ЗАЖИМЫ, 2017.

[3] Лощилов, Илья и Франк Хуттер. "SGDR: стохастический градиентный спуск с горячими перезапусками". Международный Conferencee на изучении представлений 2017. Тулон, Франция: ICLR, 2017.

Для просмотра документации необходимо авторизоваться на сайте