Распознавание активности от видео и данных оптического потока с помощью глубокого обучения

В этом примере показано, как обучить двухпоточную сверточную нейронную сеть 3-D (I3D) для распознавания активности с помощью RGB и данных оптического потока из видео [1].

Распознавание активности на основе зрения включает предсказание действия объекта, такого как ходьба, плавание или сидение, с помощью набора видеокадров. Распознавание активности из видео имеет много приложений, таких как взаимодействие человека с компьютером, обучение робота, обнаружение аномалий, наблюдение и обнаружение объектов. Для примера онлайн- предсказания нескольких действий для входящих видео с нескольких камер могут быть важны для обучения роботов. По сравнению с классификацией изображений распознавание действий с использованием видео трудно смоделировать из-за шумных меток в наборах данных видео, разнообразия действий, которые могут выполнять актёры в видео, которые являются сильно несбалансированными по классам, и вычислительной неэффективности предварительного обучения на больших наборах данных видео. Некоторые методы глубокого обучения, такие как I3D двухпотоковых сверточных сетей [1], показали улучшенную производительность, используя предварительное обучение на больших наборах данных классификации изображений.

Загрузка данных

Этот пример обучает I3D сеть, используя HMDB51 набор данных. Используйте downloadHMDB51 вспомогательная функция, перечисленная в конце этого примера, для загрузки HMDB51 набора данных в папку с именем hmdb51.

downloadFolder = fullfile(tempdir,"hmdb51");
downloadHMDB51(downloadFolder);

После завершения загрузки извлеките файл RAR hmdb51_org.rar на hmdb51 папка. Далее используйте checkForHMDB51Folder вспомогательная функция, перечисленная в конце этого примера, чтобы подтвердить, что загруженные и извлеченные файлы на месте.

allClasses = checkForHMDB51Folder(downloadFolder);

Набор данных содержит около 2 ГБ видеоданных для 7000 клипов в 51 классе, таких как напиток, беги и пожимай руки. Каждый видеокадр имеет высоту 240 пикселей и минимальную ширину 176 пикселей. Количество систем координат варьируется от 18 до приблизительно 1000.

Чтобы уменьшить время обучения, этот пример обучает сеть распознавания активности классифицировать 5 классов действий вместо всех 51 классов в наборе данных. Задайте useAllData на true для обучения со всеми 51 классами.

useAllData = false;

if useAllData
    classes = allClasses;
else
    classes = ["kiss","laugh","pick","pour","pushup"];
end
dataFolder = fullfile(downloadFolder, "hmdb51_org");

Разделите набор данных на набор обучающих данных для обучения сети и тестовый набор для оценки сети. Используйте 80% данных для набора обучающих данных, а остальное для тестового набора. Использование imageDatastore разделить данные, основанные на каждой метке, на наборы обучающих и тестовых данных путем случайного выбора части файлов из каждой метки.

imds = imageDatastore(fullfile(dataFolder,classes),...
    'IncludeSubfolders', true,...
    'LabelSource', 'foldernames',...
    'FileExtensions', '.avi');

[trainImds,testImds] = splitEachLabel(imds,0.8,'randomized');

trainFilenames = trainImds.Files;
testFilenames  = testImds.Files;

Для нормализации входных данных для сети в файле MAT приведены минимальное и максимальное значения для набора данных inputStatistics.mat, прилагаемый к этому примеру. Чтобы найти минимальное и максимальное значения для другого набора данных, используйте inputStatistics вспомогательная функция, перечисленная в конце этого примера.

inputStatsFilename = 'inputStatistics.mat';
if ~exist(inputStatsFilename, 'file')
    disp("Reading all the training data for input statistics...")
    inputStats = inputStatistics(dataFolder);
else
    d = load(inputStatsFilename);
    inputStats = d.inputStats;    
end

Создайте хранилища данных для обучения сетей

Создайте два FileDatastore объекты для обучения и валидации при помощи createFileDatastore вспомогательная функция, заданная в конце этого примера. Каждый datastore считывает видео- файл, чтобы предоставить данные RGB, данные оптического потока и соответствующую информацию о метке.

Укажите количество систем координат для каждого чтения datastore. Типичными значениями являются 16, 32, 64 или 128. Использование большего количества систем координат помогает захватывать больше временной информации, но требует большей памяти для обучения и предсказания. Установите количество систем координат равное 64, чтобы сбалансировать использование памяти с эффективностью. Может потребоваться снизить это значение в зависимости от системных ресурсов.

numFrames = 64;

Задайте высоту и ширину систем координат для считывания datastore. Фиксация высоты и ширины к одному и тому же значению облегчает пакетную обработку данных для сети. Типичными значениями являются [112, 112], [224, 224] и [256, 256]. Минимальные высота и ширина видеокадров в HMDB51 наборе данных составляют 240 и 176, соответственно. Определите [112, 112], чтобы захватить большее число систем координат за счет пространственной информации. Если вы хотите задать формат кадра для чтения datastore, который больше минимальных значений, таких как [256, 256], сначала измените размер систем координат с помощью imesize.

frameSize = [112,112];

Задайте inputSize на inputStats структура, таким образом, функция read от fileDatastore можно считать заданный размер входного сигнала.

inputSize = [frameSize, numFrames];
inputStats.inputSize = inputSize;
inputStats.Classes = classes;

Создайте два FileDatastore объекты, один для обучения и другой для валидации.

isDataForValidation = false;
dsTrain = createFileDatastore(trainFilenames,inputStats,isDataForValidation);

isDataForValidation = true;
dsVal = createFileDatastore(testFilenames,inputStats,isDataForValidation);

disp("Training data size: " + string(numel(dsTrain.Files)))
Training data size: 436
disp("Validation data size: " + string(numel(dsVal.Files)))
Validation data size: 109

Определение сетевой архитектуры

I3D сеть

Использование 3-D CNN является естественным подходом к извлечению пространственно-временных функций из видео. Можно создать I3D сеть из предварительно обученной сети классификации 2-D изображений, такой как Inception v1 или ResNet-50, путем расширения 2-D фильтров и объединения ядер в 3-D. Эта процедура повторно использует веса, полученные из задачи классификации изображений, чтобы загрузить задачу распознавания видео.

Следующий рисунок является выборкой, показывающим, как надуть слой свертки 2-D на слой свертки 3-D. Инфляция включает расширение размера, весов и смещения фильтра путем добавления третьего измерения (временной размерности).

Двухпоточная I3D сеть

Видео данных может быть рассмотрено как имеющее две части: пространственный компонент и временный компонент.

  • Пространственный компонент содержит информацию о форме, текстуре и цвете объектов в видео. Данные RGB содержат эту информацию.

  • Временной компонент содержит информацию о движении объектов через системы координат и изображает важные перемещения между камерой и объектами в сцене. Вычисление оптического потока является общим методом для извлечения временной информации из видео.

Двухпотоковый CNN включает пространственную подсеть и временную подсеть [2]. Сверточная нейронная сеть, обученная на плотном оптическом потоке и потоке видеоданных, может достичь лучшей эффективности с ограниченными обучающими данными, чем с необработанными сложенными системами координат RGB. Следующий рисунок показывает типовую двухпоточную I3D сеть.

Создайте двухпоточную I3D сеть

В этом примере вы создаете I3D сеть с помощью GoogLeNet, сети, предварительно обученной в базе данных ImageNet.

Укажите количество каналов следующим 3 для подсети RGB и 2 для подсети оптического потока. Два канала для данных оптического потока являются x и y составляющие скорости, Vx и Vy, соответственно.

rgbChannels = 3;
flowChannels = 2;

Получите минимальное и максимальное значения для RGB и данных оптического потока от inputStats структура, загруженная из inputStatistics.mat файл. Эти значения необходимы для image3dInputLayer сетей I3D для нормализации входных данных.

rgbInputSize = [frameSize, numFrames, rgbChannels];
flowInputSize = [frameSize, numFrames, flowChannels];

rgbMin = inputStats.rgbMin;
rgbMax = inputStats.rgbMax;
oflowMin = inputStats.oflowMin(:,:,1:2);
oflowMax = inputStats.oflowMax(:,:,1:2);

rgbMin = reshape(rgbMin,[1,size(rgbMin)]);
rgbMax = reshape(rgbMax,[1,size(rgbMax)]);
oflowMin = reshape(oflowMin,[1,size(oflowMin)]);
oflowMax = reshape(oflowMax,[1,size(oflowMax)]);

Укажите количество классов для обучения сети.

numClasses = numel(classes);

Создайте I3D подсети RGB и оптического потока при помощи Inflated3D вспомогательная функция, которая присоединена к этому примеру. Подсети создаются из GoogLeNet.

cnnNet = googlenet;

netRGB = Inflated3D(numClasses,rgbInputSize,rgbMin,rgbMax,cnnNet);
netFlow = Inflated3D(numClasses,flowInputSize,oflowMin,oflowMax,cnnNet);

Создайте dlnetwork объект из графика слоев каждой из I3D сетей.

dlnetRGB = dlnetwork(netRGB);
dlnetFlow = dlnetwork(netFlow);

Задайте функцию градиентов модели

Создайте вспомогательную функцию modelGradients, перечисленный в конце этого примера. The modelGradients функция принимает за вход подсеть RGB dlnetRGB, подсеть оптического потока dlnetFlowмини-пакет входных данных dlRGB и dlFlowи мини-пакет основных истин данных о метках dlY. Функция возвращает значение потерь обучения, градиенты потерь относительно настраиваемых параметров соответствующих подсетей и мини-пакетную точность подсетей.

Потеря вычисляется путем вычисления среднего значения потерь перекрестной энтропии предсказаний от каждой из подсетей. Выходы предсказания сети являются вероятностями от 0 до 1 для каждого из классов.

rgbLoss=crossentropy(rgbPrediction)

flowLoss=crossentropy(flowPrediction)

loss=mean([rgbLoss,flowLoss])

Точность каждой из подсетей вычисляется путем взятия среднего значения RGB и оптического потока предсказаний и сравнения его с основной истиной меткой входов.

Настройка опций обучения

Обучите с мини-пакетом размером 20 для 1500 итераций. Задайте итерацию, после которой можно сохранить модель с лучшей точностью валидации при помощи SaveBestAfterIteration параметр.

Задайте параметры расписания скорости обучения с отжигом косинуса [3]. Для обеих сетей используйте:

  • Минимальная скорость обучения 1e-4.

  • Максимальная скорость обучения 1e-3.

  • Количество косинусов итераций 300, 500 и 700, после чего цикл расписания скорости обучения перезапускается. Опция CosineNumIterations задает ширину каждого цикла косинуса.

Задайте параметры для оптимизации SGDM. Инициализируйте параметры оптимизации SGDM в начале обучения для каждой из сетей RGB и оптического потока. Для обеих сетей используйте:

  • Импульс 0,9.

  • Начальный параметр скорости инициализирован как [].

  • Коэффициент регуляризации L2 0,0005.

Задайте, чтобы отправлять данные в фоновом режиме с помощью параллельного пула. Если DispatchInBackground установлено значение true, откройте параллельный пул с заданным количеством параллельных рабочих мест и создайте DispatchInBackgroundDatastore, представленный в качестве части этого примера, который отправляет данные в фоновом режиме, чтобы ускорить обучение с использованием асинхронной загрузки данных и предварительной обработки. По умолчанию этот пример использует графический процессор, если он доступен. В противном случае используется центральный процессор. Для использования GPU требуется Parallel Computing Toolbox™ и графический процессор с поддержкой CUDA ® NVIDIA ®. Для получения информации о поддерживаемых вычислительных возможностях смотрите Поддержку GPU by Release (Parallel Computing Toolbox).

params.Classes = classes;
params.MiniBatchSize = 20;
params.NumIterations = 1500;
params.SaveBestAfterIteration = 900;
params.CosineNumIterations = [300, 500, 700];
params.MinLearningRate = 1e-4;
params.MaxLearningRate = 1e-3;
params.Momentum = 0.9;
params.VelocityRGB = [];
params.VelocityFlow = [];
params.L2Regularization = 0.0005;
params.ProgressPlot = false;
params.Verbose = true;
params.ValidationData = dsVal;
params.DispatchInBackground = false;
params.NumWorkers = 4;

Обучите сеть

Обучите подсети с помощью данных RGB и данных оптического потока. Установите doTraining переменная в false загрузить предварительно обученные подсети, не дожидаясь завершения обучения. Кроме того, если необходимо обучить подсети, установите doTraining переменная в true.

doTraining = false;

Для каждой эпохи:

  • Перетащите данные перед закольцовыванием по мини-пакетам данных.

  • Использование minibatchqueue для закольцовывания мини-пакетов. Вспомогательная функция createMiniBatchQueue, перечисленный в конце этого примера, использует данный обучающий datastore для создания minibatchqueue.

  • Используйте данные валидации dsVal для проверки сетей.

  • Отобразите результаты потерь и точности для каждой эпохи с помощью вспомогательной функции displayVerboseOutputEveryEpoch, перечисленный в конце этого примера.

Для каждого мини-пакета:

  • Преобразуйте данные изображения или данные оптического потока и метки в dlarray объекты с базовым типом single.

  • Обработайте временную размерность видео и данных оптического потока как одну из пространственных размерностей, чтобы позволить обработку с помощью 3-D CNN. Задайте метки размерности "SSSCB" (пространственный, пространственный, пространственный, канал, пакет) для RGB или данных оптического потока и "CB" для данных о метке.

The minibatchqueue объект использует вспомогательную функцию batchRGBAndFlow, перечисленный в конце этого примера, чтобы упаковать данные RGB и оптического потока.

modelFilename = "I3D-RGBFlow-" + numClasses + "Classes-hmdb51.mat";
if doTraining 
    epoch = 1;
    bestValAccuracy = 0;
    accTrain = [];
    accTrainRGB = [];
    accTrainFlow = [];
    lossTrain = [];
        
    iteration = 1;
    shuffled = shuffleTrainDs(dsTrain);
    
    % Number of outputs is three: One for RGB frames, one for optical flow
    % data, and one for ground truth labels.
    numOutputs = 3;
    mbq = createMiniBatchQueue(shuffled, numOutputs, params);
    start = tic;
    trainTime = start;
    
    % Use the initializeTrainingProgressPlot and initializeVerboseOutput
    % supporting functions, listed at the end of the example, to initialize
    % the training progress plot and verbose output to display the training
    % loss, training accuracy, and validation accuracy.
    plotters = initializeTrainingProgressPlot(params);
    initializeVerboseOutput(params);
    
    while iteration <= params.NumIterations

        % Iterate through the data set.
        [dlX1,dlX2,dlY] = next(mbq);

        % Evaluate the model gradients and loss using dlfeval.
        [gradRGB,gradFlow,loss,acc,accRGB,accFlow,stateRGB,stateFlow] = ...
            dlfeval(@modelGradients,dlnetRGB,dlnetFlow,dlX1,dlX2,dlY);
        
        % Accumulate the loss and accuracies.
        lossTrain = [lossTrain, loss];
        accTrain = [accTrain, acc];
        accTrainRGB = [accTrainRGB, accRGB];
        accTrainFlow = [accTrainFlow, accFlow];
        % Update the network state.
        dlnetRGB.State = stateRGB;
        dlnetFlow.State = stateFlow;
        
        % Update the gradients and parameters for the RGB and optical flow
        % subnetworks using the SGDM optimizer.
        [dlnetRGB,gradRGB,params.VelocityRGB,learnRate] = ...
            updateDlNetwork(dlnetRGB,gradRGB,params,params.VelocityRGB,iteration);
        [dlnetFlow,gradFlow,params.VelocityFlow] = ...
            updateDlNetwork(dlnetFlow,gradFlow,params,params.VelocityFlow,iteration);
        
        if ~hasdata(mbq) || iteration == params.NumIterations
            % Current epoch is complete. Do validation and update progress.
            trainTime = toc(trainTime);

            [validationTime,cmat,lossValidation,accValidation,accValidationRGB,accValidationFlow] = ...
                doValidation(params, dlnetRGB, dlnetFlow);

            % Update the training progress.
            displayVerboseOutputEveryEpoch(params,start,learnRate,epoch,iteration,...
                mean(accTrain),mean(accTrainRGB),mean(accTrainFlow),...
                accValidation,accValidationRGB,accValidationFlow,...
                mean(lossTrain),lossValidation,trainTime,validationTime);
            updateProgressPlot(params,plotters,epoch,iteration,start,mean(lossTrain),mean(accTrain),accValidation);
            
            % Save model with the trained dlnetwork and accuracy values.
            % Use the saveData supporting function, listed at the
            % end of this example.
            if iteration >= params.SaveBestAfterIteration
                if accValidation > bestValAccuracy
                    bestValAccuracy = accValidation;
                    saveData(modelFilename, dlnetRGB, dlnetFlow, cmat, accValidation);
                end
            end
        end
        
        if ~hasdata(mbq) && iteration < params.NumIterations
            % Current epoch is complete. Initialize the training loss, accuracy
            % values, and minibatchqueue for the next epoch.
            accTrain = [];
            accTrainRGB = [];
            accTrainFlow = [];
            lossTrain = [];
        
            trainTime = tic;
            epoch = epoch + 1;
            shuffled = shuffleTrainDs(dsTrain);
            numOutputs = 3;
            mbq = createMiniBatchQueue(shuffled, numOutputs, params);
            
        end 
        
        iteration = iteration + 1;
    end
    
    % Display a message when training is complete.
    endVerboseOutput(params);
    
    disp("Model saved to: " + modelFilename);
end

% Download the pretrained model and video file for prediction.
filename = "activityRecognition-I3D-HMDB51.zip";
downloadURL = "https://ssd.mathworks.com/supportfiles/vision/data/" + filename;

filename = fullfile(downloadFolder,filename);
if ~exist(filename,'file')
    disp('Downloading the pretrained network...');
    websave(filename,downloadURL);
end
% Unzip the contents to the download folder.
unzip(filename,downloadFolder);
if ~doTraining
    modelFilename = fullfile(downloadFolder, modelFilename);
end

Оценка обученной сети

Используйте набор тестовых данных для оценки точности обученных подсетей.

Загрузите лучшую модель, сохраненную во время обучения.

d = load(modelFilename);
dlnetRGB = d.data.dlnetRGB;
dlnetFlow = d.data.dlnetFlow;

Создайте minibatchqueue объект для загрузки пакетов тестовых данных.

numOutputs = 3;
mbq = createMiniBatchQueue(params.ValidationData, numOutputs, params);

Для каждого пакета тестовых данных сделайте предсказания с помощью сетей RGB и оптического потока, возьмите среднее значение предсказаний и вычислите точность предсказания с помощью матрицы неточностей.

cmat = sparse(numClasses,numClasses);
while hasdata(mbq)
    [dlRGB, dlFlow, dlY] = next(mbq);
    
    % Pass the video input as RGB and optical flow data through the
    % two-stream subnetworks to get the separate predictions.
    dlYPredRGB = predict(dlnetRGB,dlRGB);
    dlYPredFlow = predict(dlnetFlow,dlFlow);

    % Fuse the predictions by calculating the average of the predictions.
    dlYPred = (dlYPredRGB + dlYPredFlow)/2;
    
    % Calculate the accuracy of the predictions.
    [~,YTest] = max(dlY,[],1);
    [~,YPred] = max(dlYPred,[],1);

    cmat = aggregateConfusionMetric(cmat,YTest,YPred);
end

Вычислите среднюю точность классификации для обученных сетей.

accuracyEval = sum(diag(cmat))./sum(cmat,"all")
accuracyEval = 
      0.60909

Отобразите матрицу неточностей.

figure
chart = confusionchart(cmat,classes);

Из-за ограниченного количества обучающих выборок увеличение точности сверх 61% является сложным. Для повышения робастности сети требуется дополнительное обучение с большим набором данных. В сложение предварительное обучение на большем наборе данных, таком как Kinetics [1], может помочь улучшить результаты.

Предсказать использование нового видео

Теперь можно использовать обученные сети для предсказания действий в новых видео. Чтение и отображение видео pour.avi использование VideoReader и vision.VideoPlayer.

videoFilename = fullfile(downloadFolder, "pour.avi");

videoReader = VideoReader(videoFilename);
videoPlayer = vision.VideoPlayer;
videoPlayer.Name = "pour";

while hasFrame(videoReader)
   frame = readFrame(videoReader);
   step(videoPlayer,frame);
end
release(videoPlayer);

Используйте readRGBAndFlow вспомогательная функция, перечисленная в конце этого примера, для чтения данных RGB и оптического потока.

isDataForValidation = true;
readFcn = @(f,u)readRGBAndFlow(f,u,inputStats,isDataForValidation);

Функция read возвращает логическое isDone значение, которое указывает, существует ли больше данных для чтения из файла. Используйте batchRGBAndFlow вспомогательная функция, заданная в конце этого примера, чтобы упаковать данные, чтобы пройти через двухпотоковые подсети, чтобы получить предсказания.

hasdata = true;
userdata = [];
YPred = [];
while hasdata
    [data,userdata,isDone] = readFcn(videoFilename,userdata);
    
    [dlRGB, dlFlow] = batchRGBAndFlow(data(:,1),data(:,2),data(:,3));
    
    % Pass video input as RGB and optical flow data through the two-stream
    % subnetworks to get the separate predictions.
    dlYPredRGB = predict(dlnetRGB,dlRGB);
    dlYPredFlow = predict(dlnetFlow,dlFlow);

    % Fuse the predictions by calculating the average of the predictions.
    dlYPred = (dlYPredRGB + dlYPredFlow)/2;
    [~,YPredCurr] = max(dlYPred,[],1);
    YPred = horzcat(YPred,YPredCurr);
    hasdata = ~isDone;
end
YPred = extractdata(YPred);

Подсчитайте количество правильных предсказаний, используя histcounts, и получить предсказанное действие, используя максимальное количество правильных предсказаний.

classes = params.Classes;
counts = histcounts(YPred,1:numel(classes));
[~,clsIdx] = max(counts);
action = classes(clsIdx)
action = 
"pour"

Вспомогательные функции

inputStatistics

The inputStatistics функция принимает за вход имя папки, содержащей данные HMDB51, и вычисляет минимальное и максимальное значения для данных RGB и данных оптического потока. Минимальное и максимальное значения используются как нормализация, входы к входу уровню сетей. Эта функция также получает количество систем координат в каждом из файлов видео, которые используются позже во время обучения и проверки сети. В порядок, чтобы найти минимальное и максимальное значения для другого набора данных, используйте эту функцию с именем папки, содержащей набор данных.

function inputStats = inputStatistics(dataFolder)
    ds = createDatastore(dataFolder);
    ds.ReadFcn = @getMinMax;

    tic;
    tt = tall(ds);
    varnames = {'rgbMax','rgbMin','oflowMax','oflowMin'};
    stats = gather(groupsummary(tt,[],{'max','min'}, varnames));
    inputStats.Filename = gather(tt.Filename);
    inputStats.NumFrames = gather(tt.NumFrames);
    inputStats.rgbMax = stats.max_rgbMax;
    inputStats.rgbMin = stats.min_rgbMin;
    inputStats.oflowMax = stats.max_oflowMax;
    inputStats.oflowMin = stats.min_oflowMin;
    save('inputStatistics.mat','inputStats');
    toc;
end

function data = getMinMax(filename)
    reader = VideoReader(filename);
    opticFlow = opticalFlowFarneback;
    data = [];
    while hasFrame(reader)
        frame = readFrame(reader);
        [rgb,oflow] = findMinMax(frame,opticFlow);
        data = assignMinMax(data, rgb, oflow);
    end

    totalFrames = floor(reader.Duration * reader.FrameRate);
    totalFrames = min(totalFrames, reader.NumFrames);
    
    [labelName, filename] = getLabelFilename(filename);
    data.Filename = fullfile(labelName, filename);
    data.NumFrames = totalFrames;

    data = struct2table(data,'AsArray',true);
end

function data = assignMinMax(data, rgb, oflow)
    if isempty(data)
        data.rgbMax = rgb.Max;
        data.rgbMin = rgb.Min;
        data.oflowMax = oflow.Max;
        data.oflowMin = oflow.Min;
        return;
    end
    data.rgbMax = max(data.rgbMax, rgb.Max);
    data.rgbMin = min(data.rgbMin, rgb.Min);

    data.oflowMax = max(data.oflowMax, oflow.Max);
    data.oflowMin = min(data.oflowMin, oflow.Min);
end

function [rgbMinMax,oflowMinMax] = findMinMax(rgb, opticFlow)
    rgbMinMax.Max = max(rgb,[],[1,2]);
    rgbMinMax.Min = min(rgb,[],[1,2]);

    gray = rgb2gray(rgb);
    flow = estimateFlow(opticFlow,gray);
    oflow = cat(3,flow.Vx,flow.Vy,flow.Magnitude);

    oflowMinMax.Max = max(oflow,[],[1,2]);
    oflowMinMax.Min = min(oflow,[],[1,2]);
end

function ds = createDatastore(folder)    
    ds = fileDatastore(folder,...
        'IncludeSubfolders', true,...
        'FileExtensions', '.avi',...
        'UniformRead', true,...
        'ReadFcn', @getMinMax);
    disp("NumFiles: " + numel(ds.Files));
end

createFileDatastore

The createFileDatastore функция создает FileDatastore объект с указанными именами файлов. The FileDatastore объект считывает данные в 'partialfile' режим, поэтому каждое чтение может вернуть частично считанные системы координат из видео. Эта функция помогает с чтением больших видео файлов, если все системы координат не помещаются в памяти.

function datastore = createFileDatastore(filenames,inputStats,isDataForValidation)
    readFcn = @(f,u)readRGBAndFlow(f,u,inputStats,isDataForValidation);
    datastore = fileDatastore(filenames,...
        'ReadFcn',readFcn,...
        'ReadMode','partialfile');
end

readRGBAndFlow

The readRGBAndFlow функция считывает системы координат RGB, соответствующие данные оптического потока и значения меток для заданного видеосигнала файла. Во время обучения функция read считывает определенное количество систем координат согласно размеру входа сети с случайным выбором стартовой системы координат. Данные оптического потока вычисляются с начала файла видео, но пропускаются до достижения стартовой системы координат. Во время проверки все системы координат последовательно считываются, и вычисляются соответствующие данные оптического потока. Системы координат RGB и данные оптического потока случайным образом обрезаются до необходимого размера входа сети для обучения и обрезаются по центру для проверки и валидации.

function [data,userdata,done] = readRGBAndFlow(filename,userdata,inputStats,isDataForValidation)
    if isempty(userdata)
        userdata.reader      = VideoReader(filename);
        userdata.batchesRead = 0;
        userdata.opticalFlow = opticalFlowFarneback;
        
        [totalFrames,userdata.label] = getTotalFramesAndLabel(inputStats,filename);
        if isempty(totalFrames)
            totalFrames = floor(userdata.reader.Duration * userdata.reader.FrameRate);
            totalFrames = min(totalFrames, userdata.reader.NumFrames);
        end
        userdata.totalFrames = totalFrames;
    end
    reader      = userdata.reader;
    totalFrames = userdata.totalFrames;
    label       = userdata.label;
    batchesRead = userdata.batchesRead;
    opticalFlow = userdata.opticalFlow;

    inputSize = inputStats.inputSize;
    H = inputSize(1);
    W = inputSize(2);
    rgbC = 3;
    flowC = 2;
    numFrames = inputSize(3);

    if numFrames > totalFrames
        numBatches = 1;
    else
        numBatches = floor(totalFrames/numFrames);
    end

    imH = userdata.reader.Height;
    imW = userdata.reader.Width;
    imsz = [imH,imW];

    if ~isDataForValidation

        augmentFcn = augmentTransform([imsz,3]);
        cropWindow = randomCropWindow2d(imsz, inputSize(1:2));
        %  1. Randomly select required number of frames,
        %     starting randomly at a specific frame.
        if numFrames >= totalFrames
            idx = 1:totalFrames;
            % Add more frames to fill in the network input size.
            additional = ceil(numFrames/totalFrames);
            idx = repmat(idx,1,additional);
            idx = idx(1:numFrames);
        else
            startIdx = randperm(totalFrames - numFrames);
            startIdx = startIdx(1);
            endIdx = startIdx + numFrames - 1;
            idx = startIdx:endIdx;
        end

        video = zeros(H,W,rgbC,numFrames);
        oflow = zeros(H,W,flowC,numFrames);
        i = 1;
        % Discard the first set of frames to initialize the optical flow.
        for ii = 1:idx(1)-1
            frame = read(reader,ii);
            getRGBAndFlow(frame,opticalFlow,augmentFcn,cropWindow);
        end
        % Read the next set of required number of frames for training.
        for ii = idx
            frame = read(reader,ii);
            [rgb,vxvy] = getRGBAndFlow(frame,opticalFlow,augmentFcn,cropWindow);
            video(:,:,:,i) = rgb;
            oflow(:,:,:,i) = vxvy;
            i = i + 1;
        end
    else
        augmentFcn = @(data)(data);
        cropWindow = centerCropWindow2d(imsz, inputSize(1:2));
        toRead = min([numFrames,totalFrames]);
        video = zeros(H,W,rgbC,toRead);
        oflow = zeros(H,W,flowC,toRead);
        i = 1;
        while hasFrame(reader) && i <= numFrames
            frame = readFrame(reader);
            [rgb,vxvy] = getRGBAndFlow(frame,opticalFlow,augmentFcn,cropWindow);
            video(:,:,:,i) = rgb;
            oflow(:,:,:,i) = vxvy;
            i = i + 1;
        end
        if numFrames > totalFrames
            additional = ceil(numFrames/totalFrames);
            video = repmat(video,1,1,1,additional);
            oflow = repmat(oflow,1,1,1,additional);
            video = video(:,:,:,1:numFrames);
            oflow = oflow(:,:,:,1:numFrames);
        end
    end

    % The network expects the video and optical flow input in 
    % the following dlarray format: 
    % "SSSCB" ==> Height x Width x Frames x Channels x Batch
    %
    % Permute the data 
    %  from
    %      Height x Width x Channels x Frames
    %  to 
    %      Height x Width x Frames x Channels
    video = permute(video, [1,2,4,3]);
    oflow = permute(oflow, [1,2,4,3]);

    data = {video, oflow, label};

    batchesRead = batchesRead + 1;

    userdata.batchesRead = batchesRead;

    % Set the done flag to true, if the reader has read all the frames or
    % if it is training.
    done = batchesRead == numBatches || ~isDataForValidation;
end

function [rgb,vxvy] = getRGBAndFlow(rgb,opticalFlow,augmentFcn,cropWindow)
    rgb = augmentFcn(rgb);
    gray = rgb2gray(rgb);
    flow = estimateFlow(opticalFlow,gray);
    vxvy = cat(3,flow.Vx,flow.Vy,flow.Vy);

    rgb = imcrop(rgb, cropWindow);
    vxvy = imcrop(vxvy, cropWindow);
    vxvy = vxvy(:,:,1:2);
end

function [label,fname] = getLabelFilename(filename)
    [folder,name,ext] = fileparts(string(filename));
    [~,label] = fileparts(folder);
    fname = name + ext;
    label = string(label);
    fname = string(fname);
end

function [totalFrames,label] = getTotalFramesAndLabel(info, filename)
    filenames = info.Filename;
    frames = info.NumFrames;
    [labelName, fname] = getLabelFilename(filename);
    idx = strcmp(filenames, fullfile(labelName,fname));
    totalFrames = frames(idx);    
    label = categorical(string(labelName), string(info.Classes));
end

augmentTransform

The augmentTransform функция создает метод увеличения со случайными лево-правыми коэффициентами поворота и масштабирования.

function augmentFcn = augmentTransform(sz)
% Randomly flip and scale the image.
tform = randomAffine2d('XReflection',true,'Scale',[1 1.1]);
rout = affineOutputView(sz,tform,'BoundsStyle','CenterOutput');

augmentFcn = @(data)augmentData(data,tform,rout);

    function data = augmentData(data,tform,rout)
        data = imwarp(data,tform,'OutputView',rout);
    end
end

modelGradients

The modelGradients функция принимает за вход мини-пакет данных RGB dlRGBсоответствующие данные оптического потока dlFlow, и соответствующий целевой dlYи возвращает соответствующие потери, градиенты потерь относительно настраиваемых параметров и точность обучения. Чтобы вычислить градиенты, вычислите modelGradients функция, использующая dlfeval функция в цикле обучения.

function [gradientsRGB,gradientsFlow,loss,acc,accRGB,accFlow,stateRGB,stateFlow] = modelGradients(dlnetRGB,dlnetFlow,dlRGB,dlFlow,Y)

% Pass video input as RGB and optical flow data through the two-stream
% network.
[dlYPredRGB,stateRGB] = forward(dlnetRGB,dlRGB);
[dlYPredFlow,stateFlow] = forward(dlnetFlow,dlFlow);

% Calculate fused loss, gradients, and accuracy for the two-stream
% predictions.
rgbLoss = crossentropy(dlYPredRGB,Y);
flowLoss = crossentropy(dlYPredFlow,Y);
% Fuse the losses.
loss = mean([rgbLoss,flowLoss]);

gradientsRGB = dlgradient(loss,dlnetRGB.Learnables);
gradientsFlow = dlgradient(loss,dlnetFlow.Learnables);

% Fuse the predictions by calculating the average of the predictions.
dlYPred = (dlYPredRGB + dlYPredFlow)/2;

% Calculate the accuracy of the predictions.
[~,YTest] = max(Y,[],1);
[~,YPred] = max(dlYPred,[],1);

acc = gather(extractdata(sum(YTest == YPred)./numel(YTest)));

% Calculate the accuracy of the RGB and flow predictions.
[~,YTest] = max(Y,[],1);
[~,YPredRGB] = max(dlYPredRGB,[],1);
[~,YPredFlow] = max(dlYPredFlow,[],1);

accRGB = gather(extractdata(sum(YTest == YPredRGB)./numel(YTest)));
accFlow = gather(extractdata(sum(YTest == YPredFlow)./numel(YTest)));
end

doValidation

The doValidation Функция проверяет сеть с помощью данных валидации.

function [validationTime, cmat, lossValidation, accValidation, accValidationRGB, accValidationFlow] = doValidation(params, dlnetRGB, dlnetFlow)

    validationTime = tic;

    numOutputs = 3;
    mbq = createMiniBatchQueue(params.ValidationData, numOutputs, params);

    lossValidation = [];
    numClasses = numel(params.Classes);
    cmat = sparse(numClasses,numClasses);
    cmatRGB = sparse(numClasses,numClasses);
    cmatFlow = sparse(numClasses,numClasses);
    while hasdata(mbq)

        [dlX1,dlX2,dlY] = next(mbq);

        [loss,YTest,YPred,YPredRGB,YPredFlow] = predictValidation(dlnetRGB,dlnetFlow,dlX1,dlX2,dlY);

        lossValidation = [lossValidation,loss];
        cmat = aggregateConfusionMetric(cmat,YTest,YPred);
        cmatRGB = aggregateConfusionMetric(cmatRGB,YTest,YPredRGB);
        cmatFlow = aggregateConfusionMetric(cmatFlow,YTest,YPredFlow);
    end
    lossValidation = mean(lossValidation);
    accValidation = sum(diag(cmat))./sum(cmat,"all");
    accValidationRGB = sum(diag(cmatRGB))./sum(cmatRGB,"all");
    accValidationFlow = sum(diag(cmatFlow))./sum(cmatFlow,"all");

    validationTime = toc(validationTime);
end

predictValidation

The predictValidation функция вычисляет значения потерь и предсказаний с помощью предоставленной dlnetwork объекты для RGB и данных оптического потока.

function [loss,YTest,YPred,YPredRGB,YPredFlow] = predictValidation(dlnetRGB,dlnetFlow,dlRGB,dlFlow,Y)

% Pass the video input through the two-stream
% network.
dlYPredRGB = predict(dlnetRGB,dlRGB);
dlYPredFlow = predict(dlnetFlow,dlFlow);

% Calculate the cross-entropy separately for the two-stream
% outputs.
rgbLoss = crossentropy(dlYPredRGB,Y);
flowLoss = crossentropy(dlYPredFlow,Y);

% Fuse the losses.
loss = mean([rgbLoss,flowLoss]);

% Fuse the predictions by calculating the average of the predictions.
dlYPred = (dlYPredRGB + dlYPredFlow)/2;

% Calculate the accuracy of the predictions.
[~,YTest] = max(Y,[],1);
[~,YPred] = max(dlYPred,[],1);

[~,YPredRGB] = max(dlYPredRGB,[],1);
[~,YPredFlow] = max(dlYPredFlow,[],1);

end

updateDlnetwork

The updateDlnetwork функция обновляет предоставленные dlnetwork объект с градиентами и другими параметрами, использующими оптимизационную функцию SGDM sgdmupdate.

function [dlnet,gradients,velocity,learnRate] = updateDlNetwork(dlnet,gradients,params,velocity,iteration)
    % Determine the learning rate using the cosine-annealing learning rate schedule.
    learnRate = cosineAnnealingLearnRate(iteration, params);

    % Apply L2 regularization to the weights.
    idx = dlnet.Learnables.Parameter == "Weights";
    gradients(idx,:) = dlupdate(@(g,w) g + params.L2Regularization*w, gradients(idx,:), dlnet.Learnables(idx,:));

    % Update the network parameters using the SGDM optimizer.
    [dlnet, velocity] = sgdmupdate(dlnet, gradients, velocity, learnRate, params.Momentum);
end

cosineAnnealingLearnRate

The cosineAnnealingLearnRate функция вычисляет скорость обучения на основе текущего числа итераций, минимальной скорости обучения, максимальной скорости обучения и количества итераций для отжига [3].

function lr = cosineAnnealingLearnRate(iteration, params)
    if iteration == params.NumIterations
        lr = params.MinLearningRate;
        return;
    end
    cosineNumIter = [0, params.CosineNumIterations];
    csum = cumsum(cosineNumIter);
    block = find(csum >= iteration, 1,'first');
    cosineIter = iteration - csum(block - 1);
    annealingIteration = mod(cosineIter, cosineNumIter(block));
    cosineIteration = cosineNumIter(block);
    minR = params.MinLearningRate;
    maxR = params.MaxLearningRate;
    cosMult = 1 + cos(pi * annealingIteration / cosineIteration);
    lr = minR + ((maxR - minR) *  cosMult / 2);
end

aggregateConfusionMetric

The aggregateConfusionMetric функция пошагово заполняет матрицу неточностей на основе предсказанных результатов YPred и ожидаемые результаты YTest.

function cmat = aggregateConfusionMetric(cmat,YTest,YPred)
YTest = gather(extractdata(YTest));
YPred = gather(extractdata(YPred));
[m,n] = size(cmat);
cmat = cmat + full(sparse(YTest,YPred,1,m,n));
end

createMiniBatchQueue

The createMiniBatchQueue функция создает minibatchqueue объект, который обеспечивает miniBatchSize объем данных из данного datastore. Это также создает DispatchInBackgroundDatastore если параллельный пул открыт.

function mbq = createMiniBatchQueue(datastore, numOutputs, params)
if params.DispatchInBackground && isempty(gcp('nocreate'))
    % Start a parallel pool, if DispatchInBackground is true, to dispatch
    % data in the background using the parallel pool.
    c = parcluster('local');
    c.NumWorkers = params.NumWorkers;
    parpool('local',params.NumWorkers);
end
p = gcp('nocreate');
if ~isempty(p)
    datastore = DispatchInBackgroundDatastore(datastore, p.NumWorkers);
end
inputFormat(1:numOutputs-1) = "SSSCB";
outputFormat = "CB";
mbq = minibatchqueue(datastore, numOutputs, ...
    "MiniBatchSize", params.MiniBatchSize, ...
    "MiniBatchFcn", @batchRGBAndFlow, ...
    "MiniBatchFormat", [inputFormat,outputFormat]);
end

batchRGBAndFlow

The batchRGBAndFlow функция пакетирует изображение, поток и данные о метках в соответствующие dlarray значения в форматах данных "SSSCB", "SSSCB", и "CB", соответственно.

function [dlX1,dlX2,dlY] = batchRGBAndFlow(images, flows, labels)
% Batch dimension: 5
X1 = cat(5,images{:});
X2 = cat(5,flows{:});

% Batch dimension: 2
labels = cat(2,labels{:});

% Feature dimension: 1
Y = onehotencode(labels,1);

% Cast data to single for processing.
X1 = single(X1);
X2 = single(X2);
Y = single(Y);

% Move data to the GPU if possible.
if canUseGPU
    X1 = gpuArray(X1);
    X2 = gpuArray(X2);
    Y = gpuArray(Y);
end

% Return X and Y as dlarray objects.
dlX1 = dlarray(X1,"SSSCB");
dlX2 = dlarray(X2,"SSSCB");
dlY = dlarray(Y,"CB");
end

shuffleTrainDs

The shuffleTrainDs функция тасует файлы, существующие в обучающем datastore dsTrain.

function shuffled = shuffleTrainDs(dsTrain)
shuffled = copy(dsTrain);
n = numel(shuffled.Files);
shuffledIndices = randperm(n);
shuffled.Files = shuffled.Files(shuffledIndices);
reset(shuffled);
end

saveData

The saveData функция сохраняет заданную dlnetwork объекты и значения точности в MAT файла.

function saveData(modelFilename, dlnetRGB, dlnetFlow, cmat, accValidation)
dlnetRGB = gatherFromGPUToSave(dlnetRGB);
dlnetFlow = gatherFromGPUToSave(dlnetFlow);
data.ValidationAccuracy = accValidation;
data.cmat = cmat;
data.dlnetRGB = dlnetRGB;
data.dlnetFlow = dlnetFlow;
save(modelFilename, 'data');
end

gatherFromGPUToSave

The gatherFromGPUToSave функция собирает данные из графического процессора в порядок, чтобы сохранить модель на диск.

function dlnet = gatherFromGPUToSave(dlnet)
if ~canUseGPU
    return;
end
dlnet.Learnables = gatherValues(dlnet.Learnables);
dlnet.State = gatherValues(dlnet.State);
    function tbl = gatherValues(tbl)
        for ii = 1:height(tbl)
            tbl.Value{ii} = gather(tbl.Value{ii});
        end
    end
end

checkForHMDB51Folder

The checkForHMDB51Folder функция проверяет загруженные данные в папке загрузки.

function classes = checkForHMDB51Folder(dataLoc)
hmdbFolder = fullfile(dataLoc, "hmdb51_org");
if ~exist(hmdbFolder, "dir")
    error("Download 'hmdb51_org.rar' file using the supporting function 'downloadHMDB51' before running the example and extract the RAR file.");    
end

classes = ["brush_hair","cartwheel","catch","chew","clap","climb","climb_stairs",...
    "dive","draw_sword","dribble","drink","eat","fall_floor","fencing",...
    "flic_flac","golf","handstand","hit","hug","jump","kick","kick_ball",...
    "kiss","laugh","pick","pour","pullup","punch","push","pushup","ride_bike",...
    "ride_horse","run","shake_hands","shoot_ball","shoot_bow","shoot_gun",...
    "sit","situp","smile","smoke","somersault","stand","swing_baseball","sword",...
    "sword_exercise","talk","throw","turn","walk","wave"];
expectFolders = fullfile(hmdbFolder, classes);
if ~all(arrayfun(@(x)exist(x,'dir'),expectFolders))
    error("Download hmdb51_org.rar using the supporting function 'downloadHMDB51' before running the example and extract the RAR file.");
end
end

downloadHMDB51

The downloadHMDB51 функция загружает набор данных и сохраняет его в директорию.

function downloadHMDB51(dataLoc)

if nargin == 0
    dataLoc = pwd;
end
dataLoc = string(dataLoc);

if ~exist(dataLoc,"dir")
    mkdir(dataLoc);
end

dataUrl     = "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar";
options     = weboptions('Timeout', Inf);
rarFileName = fullfile(dataLoc, 'hmdb51_org.rar');
fileExists  = exist(rarFileName, 'file');

% Download the RAR file and save it to the download folder.
if ~fileExists
    disp("Downloading hmdb51_org.rar (2 GB) to the folder:")
    disp(dataLoc)
    disp("This download can take a few minutes...") 
    websave(rarFileName, dataUrl, options); 
    disp("Download complete.")
    disp("Extract the hmdb51_org.rar file contents to the folder: ") 
    disp(dataLoc)
end
end

initializeTrainingProgressPlot

The initializeTrainingProgressPlot функция конфигурирует два графика для отображения потерь обучения, точности обучения и точности валидации.

function plotters = initializeTrainingProgressPlot(params)
if params.ProgressPlot
    % Plot the loss, training accuracy, and validation accuracy.
    figure
    
    % Loss plot
    subplot(2,1,1)
    plotters.LossPlotter = animatedline;
    xlabel("Iteration")
    ylabel("Loss")
    
    % Accuracy plot
    subplot(2,1,2)
    plotters.TrainAccPlotter = animatedline('Color','b');
    plotters.ValAccPlotter = animatedline('Color','g');
    legend('Training Accuracy','Validation Accuracy','Location','northwest');
    xlabel("Iteration")
    ylabel("Accuracy")
else
    plotters = [];
end
end

initializeVerboseOutput

The initializeVerboseOutput Функция отображает заголовки столбца для таблицы обучающих значений, которая показывает эпоху, точность мини-пакетов и другие значения обучения.

function initializeVerboseOutput(params)
if params.Verbose
    disp(" ")
    if canUseGPU
        disp("Training on GPU.")
    else
        disp("Training on CPU.")
    end
    p = gcp('nocreate');
    if ~isempty(p)
        disp("Training on parallel cluster '" + p.Cluster.Profile + "'. ")
    end
    disp("NumIterations:" + string(params.NumIterations));
    disp("MiniBatchSize:" + string(params.MiniBatchSize));
    disp("Classes:" + join(string(params.Classes), ","));    
    disp("|=======================================================================================================================================================================|")
    disp("| Epoch | Iteration | Time Elapsed |     Mini-Batch Accuracy    |    Validation Accuracy     | Mini-Batch | Validation |  Base Learning  | Train Time | Validation Time |")
    disp("|       |           |  (hh:mm:ss)  |       (Avg:RGB:Flow)       |       (Avg:RGB:Flow)       |    Loss    |    Loss    |      Rate       | (hh:mm:ss) |   (hh:mm:ss)    |")
    disp("|=======================================================================================================================================================================|")
end
end

displayVerboseOutputEveryEpoch

The displayVerboseOutputEveryEpoch функция отображает подробные выходы обучающих значений, таких как эпоха, точность мини-пакета, точность валидации и потеря мини-пакета.

function displayVerboseOutputEveryEpoch(params,start,learnRate,epoch,iteration,...
        accTrain,accTrainRGB,accTrainFlow,accValidation,accValidationRGB,accValidationFlow,lossTrain,lossValidation,trainTime,validationTime)
    if params.Verbose
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        trainTime = duration(0,0,trainTime,'Format','hh:mm:ss');
        validationTime = duration(0,0,validationTime,'Format','hh:mm:ss');

        lossValidation = gather(extractdata(lossValidation));
        lossValidation = compose('%.4f',lossValidation);

        accValidation = composePadAccuracy(accValidation);
        accValidationRGB = composePadAccuracy(accValidationRGB);
        accValidationFlow = composePadAccuracy(accValidationFlow);

        accVal = join([accValidation,accValidationRGB,accValidationFlow], " : ");

        lossTrain = gather(extractdata(lossTrain));
        lossTrain = compose('%.4f',lossTrain);

        accTrain = composePadAccuracy(accTrain);
        accTrainRGB = composePadAccuracy(accTrainRGB);
        accTrainFlow = composePadAccuracy(accTrainFlow);

        accTrain = join([accTrain,accTrainRGB,accTrainFlow], " : ");
        learnRate = compose('%.13f',learnRate);

        disp("| " + ...
            pad(string(epoch),5,'both') + " | " + ...
            pad(string(iteration),9,'both') + " | " + ...
            pad(string(D),12,'both') + " | " + ...
            pad(string(accTrain),26,'both') + " | " + ...
            pad(string(accVal),26,'both') + " | " + ...
            pad(string(lossTrain),10,'both') + " | " + ...
            pad(string(lossValidation),10,'both') + " | " + ...
            pad(string(learnRate),13,'both') + " | " + ...
            pad(string(trainTime),10,'both') + " | " + ...
            pad(string(validationTime),15,'both') + " |")
    end

end

function acc = composePadAccuracy(acc)
    acc = compose('%.2f',acc*100) + "%";
    acc = pad(string(acc),6,'left');
end

endVerboseOutput

The endVerboseOutput функция отображает конец подробного выхода во время обучения.

function endVerboseOutput(params)
if params.Verbose
    disp("|=======================================================================================================================================================================|")        
end
end

updateProgressPlot

The updateProgressPlot функция обновляет график прогресса с информацией о потерях и точности во время обучения.

function updateProgressPlot(params,plotters,epoch,iteration,start,lossTrain,accuracyTrain,accuracyValidation)
if params.ProgressPlot
    
    % Update the training progress.
    D = duration(0,0,toc(start),"Format","hh:mm:ss");
    title(plotters.LossPlotter.Parent,"Epoch: " + epoch + ", Elapsed: " + string(D));
    addpoints(plotters.LossPlotter,iteration,double(gather(extractdata(lossTrain))));
    addpoints(plotters.TrainAccPlotter,iteration,accuracyTrain);
    addpoints(plotters.ValAccPlotter,iteration,accuracyValidation);
    drawnow
end
end

Ссылки

[1] Каррейра, Жоао и Эндрю Зиссерман. "Кво Вадис, признание действий? Новая модель и набор данных кинетики ". Материалы Конференции IEEE по компьютерному зрению и распознаванию шаблонов (CVPR): 6299?? 6308. Гонолулу, HI: IEEE, 2017.

[2] Симоньян, Карен и Эндрю Зиссерман. Двухпоточные сверточные сети для распознавания действий в видео. Усовершенствования в системах обработки нейронной информации 27, Лонг-Бич, Калифорния: NIPS, 2017.

[3] Лощилов, Илья и Фрэнк Хаттер. SGDR: Стохастический градиентный спуск с теплыми перезапусками. Международная конференция по учебным представлениям 2017. Тулон, Франция: ICLR, 2017.