exponenta event banner

Распознавание активности из видео и оптических потоковых данных с помощью глубокого обучения

В этом примере показано, как обучить двухпотоковую сверточную нейронную сеть Fluted 3-D (I3D) распознаванию активности с использованием RGB и данных оптического потока из видео [1].

Распознавание активности на основе зрения включает в себя прогнозирование действия объекта, такого как ходьба, плавание или сидение, с использованием набора видеокадров. Распознавание активности из видео имеет множество приложений, таких как взаимодействие человека с компьютером, обучение роботов, обнаружение аномалий, наблюдение и обнаружение объектов. Например, онлайн-прогнозирование нескольких действий для входящих видео с нескольких камер может быть важным для обучения роботов. По сравнению с классификацией изображений, распознавание действий с использованием видео является сложной задачей для моделирования из-за шумных меток в наборах видеоданных, разнообразия действий, которые актеры в видео могут выполнять, которые сильно разбалансированы классом, и неэффективности вычислений при предварительной подготовке на больших наборах видеоданных. Некоторые методы глубокого обучения, такие как I3D двухпотоковые сверточные сети [1], показали улучшенную производительность за счет использования предварительной подготовки на больших наборах данных классификации изображений.

Загрузить данные

В этом примере выполняется обучение сети I3D с использованием набора данных HMDB51. Используйте downloadHMDB51 вспомогательная функция, перечисленная в конце этого примера, для загрузки набора данных HMDB51 в папку с именем hmdb51.

downloadFolder = fullfile(tempdir,"hmdb51");
downloadHMDB51(downloadFolder);

После завершения загрузки извлеките файл RAR. hmdb51_org.rar в hmdb51 папка. Далее используйте checkForHMDB51Folder вспомогательная функция, перечисленная в конце этого примера, для подтверждения наличия загруженных и извлеченных файлов.

allClasses = checkForHMDB51Folder(downloadFolder);

Набор данных содержит около 2 ГБ видеоданных для 7000 клипов более 51 класса, таких как выпивка, бег и пожимание рук. Каждый видеокадр имеет высоту 240 пикселей и минимальную ширину 176 пикселей. Количество кадров колеблется от 18 до приблизительно 1000.

Чтобы сократить время обучения, в этом примере обучается сеть распознавания действий для классификации 5 классов действий вместо всех 51 класса в наборе данных. Набор useAllData кому true тренироваться со всеми 51 классами.

useAllData = false;

if useAllData
    classes = allClasses;
else
    classes = ["kiss","laugh","pick","pour","pushup"];
end
dataFolder = fullfile(downloadFolder, "hmdb51_org");

Разбейте набор данных на обучающий набор для обучения сети и тестовый набор для оценки сети. Используйте 80% данных для обучающего набора, а остальные - для тестового набора. Использовать imageDatastore разделить данные на основе каждой метки на наборы обучающих и тестовых данных путем случайного выбора доли файлов из каждой метки.

imds = imageDatastore(fullfile(dataFolder,classes),...
    'IncludeSubfolders', true,...
    'LabelSource', 'foldernames',...
    'FileExtensions', '.avi');

[trainImds,testImds] = splitEachLabel(imds,0.8,'randomized');

trainFilenames = trainImds.Files;
testFilenames  = testImds.Files;

Для нормализации входных данных для сети в файле MAT предоставляются минимальное и максимальное значения для набора данных. inputStatistics.mat, прилагается к этому примеру. Чтобы найти минимальное и максимальное значения для другого набора данных, используйте inputStatistics вспомогательная функция, перечисленная в конце этого примера.

inputStatsFilename = 'inputStatistics.mat';
if ~exist(inputStatsFilename, 'file')
    disp("Reading all the training data for input statistics...")
    inputStats = inputStatistics(dataFolder);
else
    d = load(inputStatsFilename);
    inputStats = d.inputStats;    
end

Создание хранилищ данных для учебных сетей

Создать два FileDatastore объекты для обучения и проверки с использованием createFileDatastore вспомогательная функция, определенная в конце этого примера. Каждое хранилище данных считывает видеофайл для предоставления данных RGB, данных оптического потока и соответствующей информации метки.

Укажите количество кадров для каждого чтения хранилищем данных. Типичными значениями являются 16, 32, 64 или 128. Использование большего количества кадров помогает захватить больше временной информации, но требует больше памяти для обучения и прогнозирования. Установите для количества кадров значение 64, чтобы сбалансировать использование памяти с производительностью. Возможно, потребуется снизить это значение в зависимости от системных ресурсов.

numFrames = 64;

Укажите высоту и ширину кадров для считываемого хранилища данных. Установка одинаковых значений высоты и ширины упрощает пакетирование данных для сети. Типичными значениями являются [112, 112], [224, 224] и [256, 256]. Минимальная высота и ширина видеокадров в наборе данных HMDB51 составляют 240 и 176 соответственно. Определите [112, 112], чтобы захватить большее число структур за счет пространственной информации. Если требуется указать размер кадра для считываемого хранилища данных, превышающий минимальные значения, например [256, 256], сначала измените размер кадров с помощью imresize.

frameSize = [112,112];

Набор inputSize в inputStats структура таким образом функция считывания fileDatastore может считывать указанный размер ввода.

inputSize = [frameSize, numFrames];
inputStats.inputSize = inputSize;
inputStats.Classes = classes;

Создать два FileDatastore объекты, один для обучения и другой для проверки.

isDataForValidation = false;
dsTrain = createFileDatastore(trainFilenames,inputStats,isDataForValidation);

isDataForValidation = true;
dsVal = createFileDatastore(testFilenames,inputStats,isDataForValidation);

disp("Training data size: " + string(numel(dsTrain.Files)))
Training data size: 436
disp("Validation data size: " + string(numel(dsVal.Files)))
Validation data size: 109

Определение сетевой архитектуры

I3D сеть

Использование 3-D CNN является естественным подходом к извлечению пространственно-временных особенностей из видео. Вы можете создать сеть I3D из предварительно обученной 2-й сети классификации изображений, такой как Начало v1 или ResNet-50, расширив 2-е фильтры и объединив ядра в 3D. Эта процедура повторно использует веса, полученные из задачи классификации образов, для начальной загрузки задачи распознавания видео.

На следующем рисунке показан пример раздувания слоя свертки 2-D до слоя свертки 3-D. Инфляция включает в себя расширение размера фильтра, весов и смещения путем добавления третьего измерения (временного измерения).

Двухпотоковая сеть I3D

Можно считать, что видеоданные имеют две части: пространственную и временную.

  • Пространственный компонент содержит информацию о форме, текстуре и цвете объектов в видео. Данные RGB содержат эту информацию.

  • Временная составляющая содержит информацию о движении объектов по кадрам и изображает важные перемещения между камерой и объектами в сцене. Вычисление оптического потока является обычным способом извлечения временной информации из видео.

Двухпоточный CNN включает пространственную подсеть и временную подсеть [2]. Сверточная нейронная сеть, обученная плотному оптическому потоку и потоку видеоданных, может достичь лучшей производительности при ограниченных обучающих данных, чем при необработанных кадрах RGB. На следующем рисунке показана типичная двухпотоковая сеть I3D.

Создание двухпотоковой сети I3D

В этом примере создается I3D сеть с использованием GoogLeNet, сети, предварительно подготовленной в базе данных ImageNet.

Укажите количество каналов как 3 для подсети RGB, и 2 для подсети оптического потока. Два канала для данных оптического потока являются x и y компонентами скорости, Vx и Vy соответственно.

rgbChannels = 3;
flowChannels = 2;

Получение минимального и максимального значений для RGB и данных оптического потока из inputStats структура загружена из inputStatistics.mat файл. Эти значения необходимы для image3dInputLayer I3D сетей для нормализации входных данных.

rgbInputSize = [frameSize, numFrames, rgbChannels];
flowInputSize = [frameSize, numFrames, flowChannels];

rgbMin = inputStats.rgbMin;
rgbMax = inputStats.rgbMax;
oflowMin = inputStats.oflowMin(:,:,1:2);
oflowMax = inputStats.oflowMax(:,:,1:2);

rgbMin = reshape(rgbMin,[1,size(rgbMin)]);
rgbMax = reshape(rgbMax,[1,size(rgbMax)]);
oflowMin = reshape(oflowMin,[1,size(oflowMin)]);
oflowMax = reshape(oflowMax,[1,size(oflowMax)]);

Укажите количество занятий для обучения сети.

numClasses = numel(classes);

Создайте подсети I3D RGB и оптического потока с помощью Inflated3D вспомогательная функция, которая присоединена к этому примеру. Подсети создаются из GoogLeNet.

cnnNet = googlenet;

netRGB = Inflated3D(numClasses,rgbInputSize,rgbMin,rgbMax,cnnNet);
netFlow = Inflated3D(numClasses,flowInputSize,oflowMin,oflowMax,cnnNet);

Создать dlnetwork объект из графа уровней каждой из I3D сетей.

dlnetRGB = dlnetwork(netRGB);
dlnetFlow = dlnetwork(netFlow);

Определение функции градиентов модели

Создание вспомогательной функции modelGradients, перечисленных в конце этого примера. modelGradients функция принимает в качестве входных данных подсеть RGB dlnetRGB, подсеть оптического потока dlnetFlow, мини-пакет входных данных dlRGB и dlFlowи мини-пакет данных метки истинности земли dlY. Функция возвращает значение обучающих потерь, градиенты потерь относительно обучаемых параметров соответствующих подсетей и точность мини-пакета подсетей.

Потери вычисляются путем вычисления среднего значения потерь перекрестной энтропии предсказаний из каждой из подсетей. Выходные прогнозы сети являются вероятностями между 0 и 1 для каждого из классов.

raseLoss = перекрестная энтропия (reyPrediction)

flowLoss = перекрестная энтропия (flowPrediction)

loss = среднее значение ([rureLoss, flowLoss])

Точность каждой из подсетей вычисляется путем взятия среднего значения предсказаний RGB и оптического потока и сравнения его с нулевой меткой истинности входных данных.

Укажите параметры обучения

Поезд с размером мини-партии 20 на 1500 итераций. Укажите итерацию, после которой сохранить модель с наилучшей точностью проверки с помощью SaveBestAfterIteration параметр.

Укажите параметры графика скорости обучения косинусному отжигу [3]. Для обеих сетей используйте:

  • Минимальная скорость обучения 1e-4.

  • Максимальная скорость обучения 1e-3.

  • Косинусное число итераций 300, 500 и 700, после чего цикл планирования скорости обучения перезапускается. Выбор CosineNumIterations определяет ширину каждого косинусного цикла.

Укажите параметры для оптимизации SGDM. Инициализируйте параметры оптимизации SGDM в начале обучения для каждой из сетей RGB и оптических потоков. Для обеих сетей используйте:

  • Импульс 0,9.

  • Начальный параметр скорости, инициализированный как [].

  • Коэффициент регуляции L2 0,0005.

Используется для отправки данных в фоновом режиме с использованием параллельного пула. Если DispatchInBackground имеет значение true, открывает параллельный пул с указанным числом параллельных работников и создает DispatchInBackgroundDatastore, предоставленный в качестве части этого примера, который отправляет данные в фоновом режиме для ускорения обучения с использованием асинхронной загрузки и предварительной обработки данных. По умолчанию в этом примере используется графический процессор, если он доступен. В противном случае используется ЦП. Для использования графического процессора требуются параллельные вычислительные Toolbox™ и графический процессор NVIDIA ® с поддержкой CUDA ®. Сведения о поддерживаемых вычислительных возможностях см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox).

params.Classes = classes;
params.MiniBatchSize = 20;
params.NumIterations = 1500;
params.SaveBestAfterIteration = 900;
params.CosineNumIterations = [300, 500, 700];
params.MinLearningRate = 1e-4;
params.MaxLearningRate = 1e-3;
params.Momentum = 0.9;
params.VelocityRGB = [];
params.VelocityFlow = [];
params.L2Regularization = 0.0005;
params.ProgressPlot = false;
params.Verbose = true;
params.ValidationData = dsVal;
params.DispatchInBackground = false;
params.NumWorkers = 4;

Железнодорожная сеть

Обучение подсетей с использованием данных RGB и данных оптического потока. Установите doTraining переменная для false для загрузки предварительно подготовленных подсетей без необходимости ожидания завершения обучения. Кроме того, если требуется обучить подсети, установите doTraining переменная для true.

doTraining = false;

Для каждой эпохи:

  • Перетасовка данных перед закольцовыванием по мини-пакетам данных.

  • Использовать minibatchqueue для закольцовывания мини-партий. Вспомогательная функция createMiniBatchQueue, перечисленных в конце этого примера, использует данное хранилище данных обучения для создания minibatchqueue.

  • Использовать данные проверки dsVal для проверки сетей.

  • Отображение результатов потерь и точности для каждой эпохи с помощью вспомогательной функции displayVerboseOutputEveryEpoch, перечисленных в конце этого примера.

Для каждой мини-партии:

  • Преобразование данных изображения или данных оптического потока и меток в dlarray объекты с базовым типом одиночный.

  • Рассматривать временную размерность видеоданных и данных оптического потока как одну из пространственных размерностей для обеспечения возможности обработки с использованием 3-D CNN. Указание меток размеров "SSSCB" (пространственные, пространственные, пространственные, канальные, пакетные) для данных RGB или оптического потока, и "CB" для данных метки.

minibatchqueue объект использует вспомогательную функцию batchRGBAndFlow, перечисленных в конце этого примера, для пакетной обработки данных RGB и оптического потока.

modelFilename = "I3D-RGBFlow-" + numClasses + "Classes-hmdb51.mat";
if doTraining 
    epoch = 1;
    bestValAccuracy = 0;
    accTrain = [];
    accTrainRGB = [];
    accTrainFlow = [];
    lossTrain = [];
        
    iteration = 1;
    shuffled = shuffleTrainDs(dsTrain);
    
    % Number of outputs is three: One for RGB frames, one for optical flow
    % data, and one for ground truth labels.
    numOutputs = 3;
    mbq = createMiniBatchQueue(shuffled, numOutputs, params);
    start = tic;
    trainTime = start;
    
    % Use the initializeTrainingProgressPlot and initializeVerboseOutput
    % supporting functions, listed at the end of the example, to initialize
    % the training progress plot and verbose output to display the training
    % loss, training accuracy, and validation accuracy.
    plotters = initializeTrainingProgressPlot(params);
    initializeVerboseOutput(params);
    
    while iteration <= params.NumIterations

        % Iterate through the data set.
        [dlX1,dlX2,dlY] = next(mbq);

        % Evaluate the model gradients and loss using dlfeval.
        [gradRGB,gradFlow,loss,acc,accRGB,accFlow,stateRGB,stateFlow] = ...
            dlfeval(@modelGradients,dlnetRGB,dlnetFlow,dlX1,dlX2,dlY);
        
        % Accumulate the loss and accuracies.
        lossTrain = [lossTrain, loss];
        accTrain = [accTrain, acc];
        accTrainRGB = [accTrainRGB, accRGB];
        accTrainFlow = [accTrainFlow, accFlow];
        % Update the network state.
        dlnetRGB.State = stateRGB;
        dlnetFlow.State = stateFlow;
        
        % Update the gradients and parameters for the RGB and optical flow
        % subnetworks using the SGDM optimizer.
        [dlnetRGB,gradRGB,params.VelocityRGB,learnRate] = ...
            updateDlNetwork(dlnetRGB,gradRGB,params,params.VelocityRGB,iteration);
        [dlnetFlow,gradFlow,params.VelocityFlow] = ...
            updateDlNetwork(dlnetFlow,gradFlow,params,params.VelocityFlow,iteration);
        
        if ~hasdata(mbq) || iteration == params.NumIterations
            % Current epoch is complete. Do validation and update progress.
            trainTime = toc(trainTime);

            [validationTime,cmat,lossValidation,accValidation,accValidationRGB,accValidationFlow] = ...
                doValidation(params, dlnetRGB, dlnetFlow);

            % Update the training progress.
            displayVerboseOutputEveryEpoch(params,start,learnRate,epoch,iteration,...
                mean(accTrain),mean(accTrainRGB),mean(accTrainFlow),...
                accValidation,accValidationRGB,accValidationFlow,...
                mean(lossTrain),lossValidation,trainTime,validationTime);
            updateProgressPlot(params,plotters,epoch,iteration,start,mean(lossTrain),mean(accTrain),accValidation);
            
            % Save model with the trained dlnetwork and accuracy values.
            % Use the saveData supporting function, listed at the
            % end of this example.
            if iteration >= params.SaveBestAfterIteration
                if accValidation > bestValAccuracy
                    bestValAccuracy = accValidation;
                    saveData(modelFilename, dlnetRGB, dlnetFlow, cmat, accValidation);
                end
            end
        end
        
        if ~hasdata(mbq) && iteration < params.NumIterations
            % Current epoch is complete. Initialize the training loss, accuracy
            % values, and minibatchqueue for the next epoch.
            accTrain = [];
            accTrainRGB = [];
            accTrainFlow = [];
            lossTrain = [];
        
            trainTime = tic;
            epoch = epoch + 1;
            shuffled = shuffleTrainDs(dsTrain);
            numOutputs = 3;
            mbq = createMiniBatchQueue(shuffled, numOutputs, params);
            
        end 
        
        iteration = iteration + 1;
    end
    
    % Display a message when training is complete.
    endVerboseOutput(params);
    
    disp("Model saved to: " + modelFilename);
end

% Download the pretrained model and video file for prediction.
filename = "activityRecognition-I3D-HMDB51.zip";
downloadURL = "https://ssd.mathworks.com/supportfiles/vision/data/" + filename;

filename = fullfile(downloadFolder,filename);
if ~exist(filename,'file')
    disp('Downloading the pretrained network...');
    websave(filename,downloadURL);
end
% Unzip the contents to the download folder.
unzip(filename,downloadFolder);
if ~doTraining
    modelFilename = fullfile(downloadFolder, modelFilename);
end

Оценка обученной сети

Используйте набор тестовых данных для оценки точности обученных подсетей.

Загрузите лучшую модель, сохраненную во время обучения.

d = load(modelFilename);
dlnetRGB = d.data.dlnetRGB;
dlnetFlow = d.data.dlnetFlow;

Создать minibatchqueue объект для загрузки пакетов тестовых данных.

numOutputs = 3;
mbq = createMiniBatchQueue(params.ValidationData, numOutputs, params);

Для каждой партии тестовых данных сделайте прогнозы с использованием RGB и оптических потоковых сетей, возьмите среднее значение прогнозов и вычислите точность прогнозирования с использованием матрицы путаницы.

cmat = sparse(numClasses,numClasses);
while hasdata(mbq)
    [dlRGB, dlFlow, dlY] = next(mbq);
    
    % Pass the video input as RGB and optical flow data through the
    % two-stream subnetworks to get the separate predictions.
    dlYPredRGB = predict(dlnetRGB,dlRGB);
    dlYPredFlow = predict(dlnetFlow,dlFlow);

    % Fuse the predictions by calculating the average of the predictions.
    dlYPred = (dlYPredRGB + dlYPredFlow)/2;
    
    % Calculate the accuracy of the predictions.
    [~,YTest] = max(dlY,[],1);
    [~,YPred] = max(dlYPred,[],1);

    cmat = aggregateConfusionMetric(cmat,YTest,YPred);
end

Вычислите среднюю точность классификации для обученных сетей.

accuracyEval = sum(diag(cmat))./sum(cmat,"all")
accuracyEval = 
      0.60909

Отображение матрицы путаницы.

figure
chart = confusionchart(cmat,classes);

Из-за ограниченного количества обучающих образцов повышение точности выше 61% является сложной задачей. Для повышения надежности сети требуется дополнительное обучение с большим набором данных. Кроме того, предварительная подготовка большего набора данных, например, Kinetics [1], может помочь улучшить результаты.

Прогнозирование с помощью нового видео

Теперь можно использовать обученные сети для прогнозирования действий в новых видео. Чтение и отображение видео pour.avi использование VideoReader и vision.VideoPlayer.

videoFilename = fullfile(downloadFolder, "pour.avi");

videoReader = VideoReader(videoFilename);
videoPlayer = vision.VideoPlayer;
videoPlayer.Name = "pour";

while hasFrame(videoReader)
   frame = readFrame(videoReader);
   step(videoPlayer,frame);
end
release(videoPlayer);

Используйте readRGBAndFlow поддерживающая функция, перечисленная в конце этого примера, для считывания данных RGB и оптического потока.

isDataForValidation = true;
readFcn = @(f,u)readRGBAndFlow(f,u,inputStats,isDataForValidation);

Функция считывания возвращает логическое значение isDone значение, указывающее на наличие дополнительных данных для чтения из файла. Используйте batchRGBAndFlow поддерживающая функция, определенная в конце этого примера, для пакетной передачи данных через двухпотоковые подсети для получения прогнозов.

hasdata = true;
userdata = [];
YPred = [];
while hasdata
    [data,userdata,isDone] = readFcn(videoFilename,userdata);
    
    [dlRGB, dlFlow] = batchRGBAndFlow(data(:,1),data(:,2),data(:,3));
    
    % Pass video input as RGB and optical flow data through the two-stream
    % subnetworks to get the separate predictions.
    dlYPredRGB = predict(dlnetRGB,dlRGB);
    dlYPredFlow = predict(dlnetFlow,dlFlow);

    % Fuse the predictions by calculating the average of the predictions.
    dlYPred = (dlYPredRGB + dlYPredFlow)/2;
    [~,YPredCurr] = max(dlYPred,[],1);
    YPred = horzcat(YPred,YPredCurr);
    hasdata = ~isDone;
end
YPred = extractdata(YPred);

Подсчитать количество правильных прогнозов с помощью histcountsи получают предсказанное действие, используя максимальное количество правильных прогнозов.

classes = params.Classes;
counts = histcounts(YPred,1:numel(classes));
[~,clsIdx] = max(counts);
action = classes(clsIdx)
action = 
"pour"

Вспомогательные функции

inputStatistics

inputStatistics функция принимает в качестве входных данных имя папки, содержащей HMDB51 данные, и вычисляет минимальное и максимальное значения для данных RGB и данных оптического потока. Минимальное и максимальное значения используются в качестве входных данных нормализации для входного уровня сетей. Эта функция также получает количество кадров в каждом из видеофайлов для последующего использования во время обучения и тестирования сети. Чтобы найти минимальное и максимальное значения для другого набора данных, используйте эту функцию с именем папки, содержащей набор данных.

function inputStats = inputStatistics(dataFolder)
    ds = createDatastore(dataFolder);
    ds.ReadFcn = @getMinMax;

    tic;
    tt = tall(ds);
    varnames = {'rgbMax','rgbMin','oflowMax','oflowMin'};
    stats = gather(groupsummary(tt,[],{'max','min'}, varnames));
    inputStats.Filename = gather(tt.Filename);
    inputStats.NumFrames = gather(tt.NumFrames);
    inputStats.rgbMax = stats.max_rgbMax;
    inputStats.rgbMin = stats.min_rgbMin;
    inputStats.oflowMax = stats.max_oflowMax;
    inputStats.oflowMin = stats.min_oflowMin;
    save('inputStatistics.mat','inputStats');
    toc;
end

function data = getMinMax(filename)
    reader = VideoReader(filename);
    opticFlow = opticalFlowFarneback;
    data = [];
    while hasFrame(reader)
        frame = readFrame(reader);
        [rgb,oflow] = findMinMax(frame,opticFlow);
        data = assignMinMax(data, rgb, oflow);
    end

    totalFrames = floor(reader.Duration * reader.FrameRate);
    totalFrames = min(totalFrames, reader.NumFrames);
    
    [labelName, filename] = getLabelFilename(filename);
    data.Filename = fullfile(labelName, filename);
    data.NumFrames = totalFrames;

    data = struct2table(data,'AsArray',true);
end

function data = assignMinMax(data, rgb, oflow)
    if isempty(data)
        data.rgbMax = rgb.Max;
        data.rgbMin = rgb.Min;
        data.oflowMax = oflow.Max;
        data.oflowMin = oflow.Min;
        return;
    end
    data.rgbMax = max(data.rgbMax, rgb.Max);
    data.rgbMin = min(data.rgbMin, rgb.Min);

    data.oflowMax = max(data.oflowMax, oflow.Max);
    data.oflowMin = min(data.oflowMin, oflow.Min);
end

function [rgbMinMax,oflowMinMax] = findMinMax(rgb, opticFlow)
    rgbMinMax.Max = max(rgb,[],[1,2]);
    rgbMinMax.Min = min(rgb,[],[1,2]);

    gray = rgb2gray(rgb);
    flow = estimateFlow(opticFlow,gray);
    oflow = cat(3,flow.Vx,flow.Vy,flow.Magnitude);

    oflowMinMax.Max = max(oflow,[],[1,2]);
    oflowMinMax.Min = min(oflow,[],[1,2]);
end

function ds = createDatastore(folder)    
    ds = fileDatastore(folder,...
        'IncludeSubfolders', true,...
        'FileExtensions', '.avi',...
        'UniformRead', true,...
        'ReadFcn', @getMinMax);
    disp("NumFiles: " + numel(ds.Files));
end

createFileDatastore

createFileDatastore функция создает FileDatastore с использованием заданных имен файлов. FileDatastore объект считывает данные в 'partialfile' режим, так что каждое чтение может возвращать частично считанные кадры из видео. Эта функция помогает при чтении больших видеофайлов, если все кадры не помещаются в память.

function datastore = createFileDatastore(filenames,inputStats,isDataForValidation)
    readFcn = @(f,u)readRGBAndFlow(f,u,inputStats,isDataForValidation);
    datastore = fileDatastore(filenames,...
        'ReadFcn',readFcn,...
        'ReadMode','partialfile');
end

readRGBAndFlow

readRGBAndFlow функция считывает кадры RGB, соответствующие данные оптического потока и значения меток для данного видеофайла. Во время обучения функция считывания считывает определенное количество кадров в соответствии с размером сетевого ввода с произвольно выбранным начальным кадром. Данные оптического потока вычисляются с начала видеофайла, но пропускаются до достижения начального кадра. Во время тестирования все кадры последовательно считываются и вычисляются соответствующие данные оптического потока. Кадры RGB и данные оптического потока случайным образом обрезаются до требуемого размера сетевого ввода для обучения, а центр обрезается для тестирования и проверки.

function [data,userdata,done] = readRGBAndFlow(filename,userdata,inputStats,isDataForValidation)
    if isempty(userdata)
        userdata.reader      = VideoReader(filename);
        userdata.batchesRead = 0;
        userdata.opticalFlow = opticalFlowFarneback;
        
        [totalFrames,userdata.label] = getTotalFramesAndLabel(inputStats,filename);
        if isempty(totalFrames)
            totalFrames = floor(userdata.reader.Duration * userdata.reader.FrameRate);
            totalFrames = min(totalFrames, userdata.reader.NumFrames);
        end
        userdata.totalFrames = totalFrames;
    end
    reader      = userdata.reader;
    totalFrames = userdata.totalFrames;
    label       = userdata.label;
    batchesRead = userdata.batchesRead;
    opticalFlow = userdata.opticalFlow;

    inputSize = inputStats.inputSize;
    H = inputSize(1);
    W = inputSize(2);
    rgbC = 3;
    flowC = 2;
    numFrames = inputSize(3);

    if numFrames > totalFrames
        numBatches = 1;
    else
        numBatches = floor(totalFrames/numFrames);
    end

    imH = userdata.reader.Height;
    imW = userdata.reader.Width;
    imsz = [imH,imW];

    if ~isDataForValidation

        augmentFcn = augmentTransform([imsz,3]);
        cropWindow = randomCropWindow2d(imsz, inputSize(1:2));
        %  1. Randomly select required number of frames,
        %     starting randomly at a specific frame.
        if numFrames >= totalFrames
            idx = 1:totalFrames;
            % Add more frames to fill in the network input size.
            additional = ceil(numFrames/totalFrames);
            idx = repmat(idx,1,additional);
            idx = idx(1:numFrames);
        else
            startIdx = randperm(totalFrames - numFrames);
            startIdx = startIdx(1);
            endIdx = startIdx + numFrames - 1;
            idx = startIdx:endIdx;
        end

        video = zeros(H,W,rgbC,numFrames);
        oflow = zeros(H,W,flowC,numFrames);
        i = 1;
        % Discard the first set of frames to initialize the optical flow.
        for ii = 1:idx(1)-1
            frame = read(reader,ii);
            getRGBAndFlow(frame,opticalFlow,augmentFcn,cropWindow);
        end
        % Read the next set of required number of frames for training.
        for ii = idx
            frame = read(reader,ii);
            [rgb,vxvy] = getRGBAndFlow(frame,opticalFlow,augmentFcn,cropWindow);
            video(:,:,:,i) = rgb;
            oflow(:,:,:,i) = vxvy;
            i = i + 1;
        end
    else
        augmentFcn = @(data)(data);
        cropWindow = centerCropWindow2d(imsz, inputSize(1:2));
        toRead = min([numFrames,totalFrames]);
        video = zeros(H,W,rgbC,toRead);
        oflow = zeros(H,W,flowC,toRead);
        i = 1;
        while hasFrame(reader) && i <= numFrames
            frame = readFrame(reader);
            [rgb,vxvy] = getRGBAndFlow(frame,opticalFlow,augmentFcn,cropWindow);
            video(:,:,:,i) = rgb;
            oflow(:,:,:,i) = vxvy;
            i = i + 1;
        end
        if numFrames > totalFrames
            additional = ceil(numFrames/totalFrames);
            video = repmat(video,1,1,1,additional);
            oflow = repmat(oflow,1,1,1,additional);
            video = video(:,:,:,1:numFrames);
            oflow = oflow(:,:,:,1:numFrames);
        end
    end

    % The network expects the video and optical flow input in 
    % the following dlarray format: 
    % "SSSCB" ==> Height x Width x Frames x Channels x Batch
    %
    % Permute the data 
    %  from
    %      Height x Width x Channels x Frames
    %  to 
    %      Height x Width x Frames x Channels
    video = permute(video, [1,2,4,3]);
    oflow = permute(oflow, [1,2,4,3]);

    data = {video, oflow, label};

    batchesRead = batchesRead + 1;

    userdata.batchesRead = batchesRead;

    % Set the done flag to true, if the reader has read all the frames or
    % if it is training.
    done = batchesRead == numBatches || ~isDataForValidation;
end

function [rgb,vxvy] = getRGBAndFlow(rgb,opticalFlow,augmentFcn,cropWindow)
    rgb = augmentFcn(rgb);
    gray = rgb2gray(rgb);
    flow = estimateFlow(opticalFlow,gray);
    vxvy = cat(3,flow.Vx,flow.Vy,flow.Vy);

    rgb = imcrop(rgb, cropWindow);
    vxvy = imcrop(vxvy, cropWindow);
    vxvy = vxvy(:,:,1:2);
end

function [label,fname] = getLabelFilename(filename)
    [folder,name,ext] = fileparts(string(filename));
    [~,label] = fileparts(folder);
    fname = name + ext;
    label = string(label);
    fname = string(fname);
end

function [totalFrames,label] = getTotalFramesAndLabel(info, filename)
    filenames = info.Filename;
    frames = info.NumFrames;
    [labelName, fname] = getLabelFilename(filename);
    idx = strcmp(filenames, fullfile(labelName,fname));
    totalFrames = frames(idx);    
    label = categorical(string(labelName), string(info.Classes));
end

augmentTransform

augmentTransform функция создает метод увеличения со случайными коэффициентами сдвига влево-вправо и масштабирования.

function augmentFcn = augmentTransform(sz)
% Randomly flip and scale the image.
tform = randomAffine2d('XReflection',true,'Scale',[1 1.1]);
rout = affineOutputView(sz,tform,'BoundsStyle','CenterOutput');

augmentFcn = @(data)augmentData(data,tform,rout);

    function data = augmentData(data,tform,rout)
        data = imwarp(data,tform,'OutputView',rout);
    end
end

modelGradients

modelGradients функция принимает в качестве входных данных мини-пакет данных RGB dlRGB, соответствующие данные оптического потока dlFlowи соответствующая цель dlYи возвращает соответствующие потери, градиенты потерь относительно обучаемых параметров и точность обучения. Чтобы вычислить градиенты, вычислите modelGradients с помощью функции dlfeval функция в обучающем цикле.

function [gradientsRGB,gradientsFlow,loss,acc,accRGB,accFlow,stateRGB,stateFlow] = modelGradients(dlnetRGB,dlnetFlow,dlRGB,dlFlow,Y)

% Pass video input as RGB and optical flow data through the two-stream
% network.
[dlYPredRGB,stateRGB] = forward(dlnetRGB,dlRGB);
[dlYPredFlow,stateFlow] = forward(dlnetFlow,dlFlow);

% Calculate fused loss, gradients, and accuracy for the two-stream
% predictions.
rgbLoss = crossentropy(dlYPredRGB,Y);
flowLoss = crossentropy(dlYPredFlow,Y);
% Fuse the losses.
loss = mean([rgbLoss,flowLoss]);

gradientsRGB = dlgradient(loss,dlnetRGB.Learnables);
gradientsFlow = dlgradient(loss,dlnetFlow.Learnables);

% Fuse the predictions by calculating the average of the predictions.
dlYPred = (dlYPredRGB + dlYPredFlow)/2;

% Calculate the accuracy of the predictions.
[~,YTest] = max(Y,[],1);
[~,YPred] = max(dlYPred,[],1);

acc = gather(extractdata(sum(YTest == YPred)./numel(YTest)));

% Calculate the accuracy of the RGB and flow predictions.
[~,YTest] = max(Y,[],1);
[~,YPredRGB] = max(dlYPredRGB,[],1);
[~,YPredFlow] = max(dlYPredFlow,[],1);

accRGB = gather(extractdata(sum(YTest == YPredRGB)./numel(YTest)));
accFlow = gather(extractdata(sum(YTest == YPredFlow)./numel(YTest)));
end

doValidation

doValidation функция проверяет сеть с использованием данных проверки.

function [validationTime, cmat, lossValidation, accValidation, accValidationRGB, accValidationFlow] = doValidation(params, dlnetRGB, dlnetFlow)

    validationTime = tic;

    numOutputs = 3;
    mbq = createMiniBatchQueue(params.ValidationData, numOutputs, params);

    lossValidation = [];
    numClasses = numel(params.Classes);
    cmat = sparse(numClasses,numClasses);
    cmatRGB = sparse(numClasses,numClasses);
    cmatFlow = sparse(numClasses,numClasses);
    while hasdata(mbq)

        [dlX1,dlX2,dlY] = next(mbq);

        [loss,YTest,YPred,YPredRGB,YPredFlow] = predictValidation(dlnetRGB,dlnetFlow,dlX1,dlX2,dlY);

        lossValidation = [lossValidation,loss];
        cmat = aggregateConfusionMetric(cmat,YTest,YPred);
        cmatRGB = aggregateConfusionMetric(cmatRGB,YTest,YPredRGB);
        cmatFlow = aggregateConfusionMetric(cmatFlow,YTest,YPredFlow);
    end
    lossValidation = mean(lossValidation);
    accValidation = sum(diag(cmat))./sum(cmat,"all");
    accValidationRGB = sum(diag(cmatRGB))./sum(cmatRGB,"all");
    accValidationFlow = sum(diag(cmatFlow))./sum(cmatFlow,"all");

    validationTime = toc(validationTime);
end

predictValidation

predictValidation функция вычисляет значения потерь и прогнозирования, используя предоставленные dlnetwork объекты для данных RGB и оптического потока.

function [loss,YTest,YPred,YPredRGB,YPredFlow] = predictValidation(dlnetRGB,dlnetFlow,dlRGB,dlFlow,Y)

% Pass the video input through the two-stream
% network.
dlYPredRGB = predict(dlnetRGB,dlRGB);
dlYPredFlow = predict(dlnetFlow,dlFlow);

% Calculate the cross-entropy separately for the two-stream
% outputs.
rgbLoss = crossentropy(dlYPredRGB,Y);
flowLoss = crossentropy(dlYPredFlow,Y);

% Fuse the losses.
loss = mean([rgbLoss,flowLoss]);

% Fuse the predictions by calculating the average of the predictions.
dlYPred = (dlYPredRGB + dlYPredFlow)/2;

% Calculate the accuracy of the predictions.
[~,YTest] = max(Y,[],1);
[~,YPred] = max(dlYPred,[],1);

[~,YPredRGB] = max(dlYPredRGB,[],1);
[~,YPredFlow] = max(dlYPredFlow,[],1);

end

updateDlnetwork

updateDlnetwork функция обновляет предоставленную dlnetwork объект с градиентами и другими параметрами с использованием функции оптимизации SGDM sgdmupdate.

function [dlnet,gradients,velocity,learnRate] = updateDlNetwork(dlnet,gradients,params,velocity,iteration)
    % Determine the learning rate using the cosine-annealing learning rate schedule.
    learnRate = cosineAnnealingLearnRate(iteration, params);

    % Apply L2 regularization to the weights.
    idx = dlnet.Learnables.Parameter == "Weights";
    gradients(idx,:) = dlupdate(@(g,w) g + params.L2Regularization*w, gradients(idx,:), dlnet.Learnables(idx,:));

    % Update the network parameters using the SGDM optimizer.
    [dlnet, velocity] = sgdmupdate(dlnet, gradients, velocity, learnRate, params.Momentum);
end

cosineAnnealingLearnRate

cosineAnnealingLearnRate функция вычисляет скорость обучения на основе текущего числа итераций, минимальной скорости обучения, максимальной скорости обучения и количества итераций для отжига [3].

function lr = cosineAnnealingLearnRate(iteration, params)
    if iteration == params.NumIterations
        lr = params.MinLearningRate;
        return;
    end
    cosineNumIter = [0, params.CosineNumIterations];
    csum = cumsum(cosineNumIter);
    block = find(csum >= iteration, 1,'first');
    cosineIter = iteration - csum(block - 1);
    annealingIteration = mod(cosineIter, cosineNumIter(block));
    cosineIteration = cosineNumIter(block);
    minR = params.MinLearningRate;
    maxR = params.MaxLearningRate;
    cosMult = 1 + cos(pi * annealingIteration / cosineIteration);
    lr = minR + ((maxR - minR) *  cosMult / 2);
end

aggregateConfusionMetric

aggregateConfusionMetric функция постепенно заполняет матрицу путаницы на основе прогнозируемых результатов YPred и ожидаемые результаты YTest.

function cmat = aggregateConfusionMetric(cmat,YTest,YPred)
YTest = gather(extractdata(YTest));
YPred = gather(extractdata(YPred));
[m,n] = size(cmat);
cmat = cmat + full(sparse(YTest,YPred,1,m,n));
end

createMiniBatchQueue

createMiniBatchQueue функция создает minibatchqueue объект, обеспечивающий miniBatchSize объем данных из данного хранилища данных. Он также создает DispatchInBackgroundDatastore если открыт параллельный пул.

function mbq = createMiniBatchQueue(datastore, numOutputs, params)
if params.DispatchInBackground && isempty(gcp('nocreate'))
    % Start a parallel pool, if DispatchInBackground is true, to dispatch
    % data in the background using the parallel pool.
    c = parcluster('local');
    c.NumWorkers = params.NumWorkers;
    parpool('local',params.NumWorkers);
end
p = gcp('nocreate');
if ~isempty(p)
    datastore = DispatchInBackgroundDatastore(datastore, p.NumWorkers);
end
inputFormat(1:numOutputs-1) = "SSSCB";
outputFormat = "CB";
mbq = minibatchqueue(datastore, numOutputs, ...
    "MiniBatchSize", params.MiniBatchSize, ...
    "MiniBatchFcn", @batchRGBAndFlow, ...
    "MiniBatchFormat", [inputFormat,outputFormat]);
end

batchRGBAndFlow

batchRGBAndFlow функция группирует данные изображения, потока и метки в соответствующие dlarray значения в форматах данных "SSSCB", "SSSCB", и "CB"соответственно.

function [dlX1,dlX2,dlY] = batchRGBAndFlow(images, flows, labels)
% Batch dimension: 5
X1 = cat(5,images{:});
X2 = cat(5,flows{:});

% Batch dimension: 2
labels = cat(2,labels{:});

% Feature dimension: 1
Y = onehotencode(labels,1);

% Cast data to single for processing.
X1 = single(X1);
X2 = single(X2);
Y = single(Y);

% Move data to the GPU if possible.
if canUseGPU
    X1 = gpuArray(X1);
    X2 = gpuArray(X2);
    Y = gpuArray(Y);
end

% Return X and Y as dlarray objects.
dlX1 = dlarray(X1,"SSSCB");
dlX2 = dlarray(X2,"SSSCB");
dlY = dlarray(Y,"CB");
end

shuffleTrainDs

shuffleTrainDs функция выполняет тасование файлов, имеющихся в хранилище данных обучения dsTrain.

function shuffled = shuffleTrainDs(dsTrain)
shuffled = copy(dsTrain);
n = numel(shuffled.Files);
shuffledIndices = randperm(n);
shuffled.Files = shuffled.Files(shuffledIndices);
reset(shuffled);
end

saveData

saveData функция сохраняет заданное dlnetwork объекты и значения точности для файла MAT.

function saveData(modelFilename, dlnetRGB, dlnetFlow, cmat, accValidation)
dlnetRGB = gatherFromGPUToSave(dlnetRGB);
dlnetFlow = gatherFromGPUToSave(dlnetFlow);
data.ValidationAccuracy = accValidation;
data.cmat = cmat;
data.dlnetRGB = dlnetRGB;
data.dlnetFlow = dlnetFlow;
save(modelFilename, 'data');
end

gatherFromGPUToSave

gatherFromGPUToSave собирает данные из графического процессора для сохранения модели на диске.

function dlnet = gatherFromGPUToSave(dlnet)
if ~canUseGPU
    return;
end
dlnet.Learnables = gatherValues(dlnet.Learnables);
dlnet.State = gatherValues(dlnet.State);
    function tbl = gatherValues(tbl)
        for ii = 1:height(tbl)
            tbl.Value{ii} = gather(tbl.Value{ii});
        end
    end
end

checkForHMDB51Folder

checkForHMDB51Folder функция проверяет загруженные данные в папке загрузки.

function classes = checkForHMDB51Folder(dataLoc)
hmdbFolder = fullfile(dataLoc, "hmdb51_org");
if ~exist(hmdbFolder, "dir")
    error("Download 'hmdb51_org.rar' file using the supporting function 'downloadHMDB51' before running the example and extract the RAR file.");    
end

classes = ["brush_hair","cartwheel","catch","chew","clap","climb","climb_stairs",...
    "dive","draw_sword","dribble","drink","eat","fall_floor","fencing",...
    "flic_flac","golf","handstand","hit","hug","jump","kick","kick_ball",...
    "kiss","laugh","pick","pour","pullup","punch","push","pushup","ride_bike",...
    "ride_horse","run","shake_hands","shoot_ball","shoot_bow","shoot_gun",...
    "sit","situp","smile","smoke","somersault","stand","swing_baseball","sword",...
    "sword_exercise","talk","throw","turn","walk","wave"];
expectFolders = fullfile(hmdbFolder, classes);
if ~all(arrayfun(@(x)exist(x,'dir'),expectFolders))
    error("Download hmdb51_org.rar using the supporting function 'downloadHMDB51' before running the example and extract the RAR file.");
end
end

downloadHMDB51

downloadHMDB51 функция загружает набор данных и сохраняет его в каталог.

function downloadHMDB51(dataLoc)

if nargin == 0
    dataLoc = pwd;
end
dataLoc = string(dataLoc);

if ~exist(dataLoc,"dir")
    mkdir(dataLoc);
end

dataUrl     = "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar";
options     = weboptions('Timeout', Inf);
rarFileName = fullfile(dataLoc, 'hmdb51_org.rar');
fileExists  = exist(rarFileName, 'file');

% Download the RAR file and save it to the download folder.
if ~fileExists
    disp("Downloading hmdb51_org.rar (2 GB) to the folder:")
    disp(dataLoc)
    disp("This download can take a few minutes...") 
    websave(rarFileName, dataUrl, options); 
    disp("Download complete.")
    disp("Extract the hmdb51_org.rar file contents to the folder: ") 
    disp(dataLoc)
end
end

initializeTrainingProgressPlot

initializeTrainingProgressPlot функция настраивает два графика для отображения потерь при обучении, точности обучения и точности проверки.

function plotters = initializeTrainingProgressPlot(params)
if params.ProgressPlot
    % Plot the loss, training accuracy, and validation accuracy.
    figure
    
    % Loss plot
    subplot(2,1,1)
    plotters.LossPlotter = animatedline;
    xlabel("Iteration")
    ylabel("Loss")
    
    % Accuracy plot
    subplot(2,1,2)
    plotters.TrainAccPlotter = animatedline('Color','b');
    plotters.ValAccPlotter = animatedline('Color','g');
    legend('Training Accuracy','Validation Accuracy','Location','northwest');
    xlabel("Iteration")
    ylabel("Accuracy")
else
    plotters = [];
end
end

initializeVerboseOutput

initializeVerboseOutput отображает заголовки столбцов для таблицы учебных значений, в которой показаны эпоха, точность мини-партии и другие учебные значения.

function initializeVerboseOutput(params)
if params.Verbose
    disp(" ")
    if canUseGPU
        disp("Training on GPU.")
    else
        disp("Training on CPU.")
    end
    p = gcp('nocreate');
    if ~isempty(p)
        disp("Training on parallel cluster '" + p.Cluster.Profile + "'. ")
    end
    disp("NumIterations:" + string(params.NumIterations));
    disp("MiniBatchSize:" + string(params.MiniBatchSize));
    disp("Classes:" + join(string(params.Classes), ","));    
    disp("|=======================================================================================================================================================================|")
    disp("| Epoch | Iteration | Time Elapsed |     Mini-Batch Accuracy    |    Validation Accuracy     | Mini-Batch | Validation |  Base Learning  | Train Time | Validation Time |")
    disp("|       |           |  (hh:mm:ss)  |       (Avg:RGB:Flow)       |       (Avg:RGB:Flow)       |    Loss    |    Loss    |      Rate       | (hh:mm:ss) |   (hh:mm:ss)    |")
    disp("|=======================================================================================================================================================================|")
end
end

displayVerboseOutputEveryEpoch

displayVerboseOutputEveryEpoch функция отображает подробный вывод учебных значений, таких как эпоха, точность мини-партии, точность проверки и потеря мини-партии.

function displayVerboseOutputEveryEpoch(params,start,learnRate,epoch,iteration,...
        accTrain,accTrainRGB,accTrainFlow,accValidation,accValidationRGB,accValidationFlow,lossTrain,lossValidation,trainTime,validationTime)
    if params.Verbose
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        trainTime = duration(0,0,trainTime,'Format','hh:mm:ss');
        validationTime = duration(0,0,validationTime,'Format','hh:mm:ss');

        lossValidation = gather(extractdata(lossValidation));
        lossValidation = compose('%.4f',lossValidation);

        accValidation = composePadAccuracy(accValidation);
        accValidationRGB = composePadAccuracy(accValidationRGB);
        accValidationFlow = composePadAccuracy(accValidationFlow);

        accVal = join([accValidation,accValidationRGB,accValidationFlow], " : ");

        lossTrain = gather(extractdata(lossTrain));
        lossTrain = compose('%.4f',lossTrain);

        accTrain = composePadAccuracy(accTrain);
        accTrainRGB = composePadAccuracy(accTrainRGB);
        accTrainFlow = composePadAccuracy(accTrainFlow);

        accTrain = join([accTrain,accTrainRGB,accTrainFlow], " : ");
        learnRate = compose('%.13f',learnRate);

        disp("| " + ...
            pad(string(epoch),5,'both') + " | " + ...
            pad(string(iteration),9,'both') + " | " + ...
            pad(string(D),12,'both') + " | " + ...
            pad(string(accTrain),26,'both') + " | " + ...
            pad(string(accVal),26,'both') + " | " + ...
            pad(string(lossTrain),10,'both') + " | " + ...
            pad(string(lossValidation),10,'both') + " | " + ...
            pad(string(learnRate),13,'both') + " | " + ...
            pad(string(trainTime),10,'both') + " | " + ...
            pad(string(validationTime),15,'both') + " |")
    end

end

function acc = composePadAccuracy(acc)
    acc = compose('%.2f',acc*100) + "%";
    acc = pad(string(acc),6,'left');
end

endVerboseOutput

endVerboseOutput отображает конец подробных выходных данных во время обучения.

function endVerboseOutput(params)
if params.Verbose
    disp("|=======================================================================================================================================================================|")        
end
end

updateProgressPlot

updateProgressPlot функция обновляет график хода выполнения с информацией о потерях и точности во время обучения.

function updateProgressPlot(params,plotters,epoch,iteration,start,lossTrain,accuracyTrain,accuracyValidation)
if params.ProgressPlot
    
    % Update the training progress.
    D = duration(0,0,toc(start),"Format","hh:mm:ss");
    title(plotters.LossPlotter.Parent,"Epoch: " + epoch + ", Elapsed: " + string(D));
    addpoints(plotters.LossPlotter,iteration,double(gather(extractdata(lossTrain))));
    addpoints(plotters.TrainAccPlotter,iteration,accuracyTrain);
    addpoints(plotters.ValAccPlotter,iteration,accuracyValidation);
    drawnow
end
end

Ссылки

[1] Каррейра, Жоао и Эндрю Зиссерман. "Кво Вадис, признание действий? Новая модель и набор данных по кинетике. " Материалы Конференции IEEE по компьютерному зрению и распознаванию образов (CVPR): 6299?? 6308. Гонолулу, HI: IEEE, 2017.

[2] Симоньян, Карен и Эндрю Зиссерман. «Двухстримовые сверточные сети для распознавания действий в видео». Достижения в системах обработки нейронной информации 27, Лонг-Бич, Калифорния: NIPS, 2017.

[3] Лошчилов, Илья и Фрэнк Хаттер. «SGDR: Стохастический градиентный спуск с теплыми перезапусками». Международная конференция по учебным представлениям 2017. Тулон, Франция: ICLR, 2017.