В этом примере показано, как обучить двухпотоковую сверточную нейронную сеть Fluted 3-D (I3D) распознаванию активности с использованием RGB и данных оптического потока из видео [1].
Распознавание активности на основе зрения включает в себя прогнозирование действия объекта, такого как ходьба, плавание или сидение, с использованием набора видеокадров. Распознавание активности из видео имеет множество приложений, таких как взаимодействие человека с компьютером, обучение роботов, обнаружение аномалий, наблюдение и обнаружение объектов. Например, онлайн-прогнозирование нескольких действий для входящих видео с нескольких камер может быть важным для обучения роботов. По сравнению с классификацией изображений, распознавание действий с использованием видео является сложной задачей для моделирования из-за шумных меток в наборах видеоданных, разнообразия действий, которые актеры в видео могут выполнять, которые сильно разбалансированы классом, и неэффективности вычислений при предварительной подготовке на больших наборах видеоданных. Некоторые методы глубокого обучения, такие как I3D двухпотоковые сверточные сети [1], показали улучшенную производительность за счет использования предварительной подготовки на больших наборах данных классификации изображений.
В этом примере выполняется обучение сети I3D с использованием набора данных HMDB51. Используйте downloadHMDB51 вспомогательная функция, перечисленная в конце этого примера, для загрузки набора данных HMDB51 в папку с именем hmdb51.
downloadFolder = fullfile(tempdir,"hmdb51");
downloadHMDB51(downloadFolder);После завершения загрузки извлеките файл RAR. hmdb51_org.rar в hmdb51 папка. Далее используйте checkForHMDB51Folder вспомогательная функция, перечисленная в конце этого примера, для подтверждения наличия загруженных и извлеченных файлов.
allClasses = checkForHMDB51Folder(downloadFolder);
Набор данных содержит около 2 ГБ видеоданных для 7000 клипов более 51 класса, таких как выпивка, бег и пожимание рук. Каждый видеокадр имеет высоту 240 пикселей и минимальную ширину 176 пикселей. Количество кадров колеблется от 18 до приблизительно 1000.
Чтобы сократить время обучения, в этом примере обучается сеть распознавания действий для классификации 5 классов действий вместо всех 51 класса в наборе данных. Набор useAllData кому true тренироваться со всеми 51 классами.
useAllData = false; if useAllData classes = allClasses; else classes = ["kiss","laugh","pick","pour","pushup"]; end dataFolder = fullfile(downloadFolder, "hmdb51_org");
Разбейте набор данных на обучающий набор для обучения сети и тестовый набор для оценки сети. Используйте 80% данных для обучающего набора, а остальные - для тестового набора. Использовать imageDatastore разделить данные на основе каждой метки на наборы обучающих и тестовых данных путем случайного выбора доли файлов из каждой метки.
imds = imageDatastore(fullfile(dataFolder,classes),... 'IncludeSubfolders', true,... 'LabelSource', 'foldernames',... 'FileExtensions', '.avi'); [trainImds,testImds] = splitEachLabel(imds,0.8,'randomized'); trainFilenames = trainImds.Files; testFilenames = testImds.Files;
Для нормализации входных данных для сети в файле MAT предоставляются минимальное и максимальное значения для набора данных. inputStatistics.mat, прилагается к этому примеру. Чтобы найти минимальное и максимальное значения для другого набора данных, используйте inputStatistics вспомогательная функция, перечисленная в конце этого примера.
inputStatsFilename = 'inputStatistics.mat'; if ~exist(inputStatsFilename, 'file') disp("Reading all the training data for input statistics...") inputStats = inputStatistics(dataFolder); else d = load(inputStatsFilename); inputStats = d.inputStats; end
Создать два FileDatastore объекты для обучения и проверки с использованием createFileDatastore вспомогательная функция, определенная в конце этого примера. Каждое хранилище данных считывает видеофайл для предоставления данных RGB, данных оптического потока и соответствующей информации метки.
Укажите количество кадров для каждого чтения хранилищем данных. Типичными значениями являются 16, 32, 64 или 128. Использование большего количества кадров помогает захватить больше временной информации, но требует больше памяти для обучения и прогнозирования. Установите для количества кадров значение 64, чтобы сбалансировать использование памяти с производительностью. Возможно, потребуется снизить это значение в зависимости от системных ресурсов.
numFrames = 64;
Укажите высоту и ширину кадров для считываемого хранилища данных. Установка одинаковых значений высоты и ширины упрощает пакетирование данных для сети. Типичными значениями являются [112, 112], [224, 224] и [256, 256]. Минимальная высота и ширина видеокадров в наборе данных HMDB51 составляют 240 и 176 соответственно. Определите [112, 112], чтобы захватить большее число структур за счет пространственной информации. Если требуется указать размер кадра для считываемого хранилища данных, превышающий минимальные значения, например [256, 256], сначала измените размер кадров с помощью imresize.
frameSize = [112,112];
Набор inputSize в inputStats структура таким образом функция считывания fileDatastore может считывать указанный размер ввода.
inputSize = [frameSize, numFrames]; inputStats.inputSize = inputSize; inputStats.Classes = classes;
Создать два FileDatastore объекты, один для обучения и другой для проверки.
isDataForValidation = false;
dsTrain = createFileDatastore(trainFilenames,inputStats,isDataForValidation);
isDataForValidation = true;
dsVal = createFileDatastore(testFilenames,inputStats,isDataForValidation);
disp("Training data size: " + string(numel(dsTrain.Files)))Training data size: 436
disp("Validation data size: " + string(numel(dsVal.Files)))Validation data size: 109
Использование 3-D CNN является естественным подходом к извлечению пространственно-временных особенностей из видео. Вы можете создать сеть I3D из предварительно обученной 2-й сети классификации изображений, такой как Начало v1 или ResNet-50, расширив 2-е фильтры и объединив ядра в 3D. Эта процедура повторно использует веса, полученные из задачи классификации образов, для начальной загрузки задачи распознавания видео.
На следующем рисунке показан пример раздувания слоя свертки 2-D до слоя свертки 3-D. Инфляция включает в себя расширение размера фильтра, весов и смещения путем добавления третьего измерения (временного измерения).

Можно считать, что видеоданные имеют две части: пространственную и временную.
Пространственный компонент содержит информацию о форме, текстуре и цвете объектов в видео. Данные RGB содержат эту информацию.
Временная составляющая содержит информацию о движении объектов по кадрам и изображает важные перемещения между камерой и объектами в сцене. Вычисление оптического потока является обычным способом извлечения временной информации из видео.
Двухпоточный CNN включает пространственную подсеть и временную подсеть [2]. Сверточная нейронная сеть, обученная плотному оптическому потоку и потоку видеоданных, может достичь лучшей производительности при ограниченных обучающих данных, чем при необработанных кадрах RGB. На следующем рисунке показана типичная двухпотоковая сеть I3D.

В этом примере создается I3D сеть с использованием GoogLeNet, сети, предварительно подготовленной в базе данных ImageNet.
Укажите количество каналов как 3 для подсети RGB, и 2 для подсети оптического потока. Два канала для данных оптического потока являются x и компонентами скорости, и .
rgbChannels = 3; flowChannels = 2;
Получение минимального и максимального значений для RGB и данных оптического потока из inputStats структура загружена из inputStatistics.mat файл. Эти значения необходимы для image3dInputLayer I3D сетей для нормализации входных данных.
rgbInputSize = [frameSize, numFrames, rgbChannels]; flowInputSize = [frameSize, numFrames, flowChannels]; rgbMin = inputStats.rgbMin; rgbMax = inputStats.rgbMax; oflowMin = inputStats.oflowMin(:,:,1:2); oflowMax = inputStats.oflowMax(:,:,1:2); rgbMin = reshape(rgbMin,[1,size(rgbMin)]); rgbMax = reshape(rgbMax,[1,size(rgbMax)]); oflowMin = reshape(oflowMin,[1,size(oflowMin)]); oflowMax = reshape(oflowMax,[1,size(oflowMax)]);
Укажите количество занятий для обучения сети.
numClasses = numel(classes);
Создайте подсети I3D RGB и оптического потока с помощью Inflated3D вспомогательная функция, которая присоединена к этому примеру. Подсети создаются из GoogLeNet.
cnnNet = googlenet; netRGB = Inflated3D(numClasses,rgbInputSize,rgbMin,rgbMax,cnnNet); netFlow = Inflated3D(numClasses,flowInputSize,oflowMin,oflowMax,cnnNet);
Создать dlnetwork объект из графа уровней каждой из I3D сетей.
dlnetRGB = dlnetwork(netRGB); dlnetFlow = dlnetwork(netFlow);
Создание вспомогательной функции modelGradients, перечисленных в конце этого примера. modelGradients функция принимает в качестве входных данных подсеть RGB dlnetRGB, подсеть оптического потока dlnetFlow, мини-пакет входных данных dlRGB и dlFlowи мини-пакет данных метки истинности земли dlY. Функция возвращает значение обучающих потерь, градиенты потерь относительно обучаемых параметров соответствующих подсетей и точность мини-пакета подсетей.
Потери вычисляются путем вычисления среднего значения потерь перекрестной энтропии предсказаний из каждой из подсетей. Выходные прогнозы сети являются вероятностями между 0 и 1 для каждого из классов.
reyPrediction)
flowPrediction)
flowLoss])
Точность каждой из подсетей вычисляется путем взятия среднего значения предсказаний RGB и оптического потока и сравнения его с нулевой меткой истинности входных данных.
Поезд с размером мини-партии 20 на 1500 итераций. Укажите итерацию, после которой сохранить модель с наилучшей точностью проверки с помощью SaveBestAfterIteration параметр.
Укажите параметры графика скорости обучения косинусному отжигу [3]. Для обеих сетей используйте:
Минимальная скорость обучения 1e-4.
Максимальная скорость обучения 1e-3.
Косинусное число итераций 300, 500 и 700, после чего цикл планирования скорости обучения перезапускается. Выбор CosineNumIterations определяет ширину каждого косинусного цикла.
Укажите параметры для оптимизации SGDM. Инициализируйте параметры оптимизации SGDM в начале обучения для каждой из сетей RGB и оптических потоков. Для обеих сетей используйте:
Импульс 0,9.
Начальный параметр скорости, инициализированный как [].
Коэффициент регуляции L2 0,0005.
Используется для отправки данных в фоновом режиме с использованием параллельного пула. Если DispatchInBackground имеет значение true, открывает параллельный пул с указанным числом параллельных работников и создает DispatchInBackgroundDatastore, предоставленный в качестве части этого примера, который отправляет данные в фоновом режиме для ускорения обучения с использованием асинхронной загрузки и предварительной обработки данных. По умолчанию в этом примере используется графический процессор, если он доступен. В противном случае используется ЦП. Для использования графического процессора требуются параллельные вычислительные Toolbox™ и графический процессор NVIDIA ® с поддержкой CUDA ®. Сведения о поддерживаемых вычислительных возможностях см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox).
params.Classes = classes; params.MiniBatchSize = 20; params.NumIterations = 1500; params.SaveBestAfterIteration = 900; params.CosineNumIterations = [300, 500, 700]; params.MinLearningRate = 1e-4; params.MaxLearningRate = 1e-3; params.Momentum = 0.9; params.VelocityRGB = []; params.VelocityFlow = []; params.L2Regularization = 0.0005; params.ProgressPlot = false; params.Verbose = true; params.ValidationData = dsVal; params.DispatchInBackground = false; params.NumWorkers = 4;
Обучение подсетей с использованием данных RGB и данных оптического потока. Установите doTraining переменная для false для загрузки предварительно подготовленных подсетей без необходимости ожидания завершения обучения. Кроме того, если требуется обучить подсети, установите doTraining переменная для true.
doTraining = false;
Для каждой эпохи:
Перетасовка данных перед закольцовыванием по мини-пакетам данных.
Использовать minibatchqueue для закольцовывания мини-партий. Вспомогательная функция createMiniBatchQueue, перечисленных в конце этого примера, использует данное хранилище данных обучения для создания minibatchqueue.
Использовать данные проверки dsVal для проверки сетей.
Отображение результатов потерь и точности для каждой эпохи с помощью вспомогательной функции displayVerboseOutputEveryEpoch, перечисленных в конце этого примера.
Для каждой мини-партии:
Преобразование данных изображения или данных оптического потока и меток в dlarray объекты с базовым типом одиночный.
Рассматривать временную размерность видеоданных и данных оптического потока как одну из пространственных размерностей для обеспечения возможности обработки с использованием 3-D CNN. Указание меток размеров "SSSCB" (пространственные, пространственные, пространственные, канальные, пакетные) для данных RGB или оптического потока, и "CB" для данных метки.
minibatchqueue объект использует вспомогательную функцию batchRGBAndFlow, перечисленных в конце этого примера, для пакетной обработки данных RGB и оптического потока.
modelFilename = "I3D-RGBFlow-" + numClasses + "Classes-hmdb51.mat"; if doTraining epoch = 1; bestValAccuracy = 0; accTrain = []; accTrainRGB = []; accTrainFlow = []; lossTrain = []; iteration = 1; shuffled = shuffleTrainDs(dsTrain); % Number of outputs is three: One for RGB frames, one for optical flow % data, and one for ground truth labels. numOutputs = 3; mbq = createMiniBatchQueue(shuffled, numOutputs, params); start = tic; trainTime = start; % Use the initializeTrainingProgressPlot and initializeVerboseOutput % supporting functions, listed at the end of the example, to initialize % the training progress plot and verbose output to display the training % loss, training accuracy, and validation accuracy. plotters = initializeTrainingProgressPlot(params); initializeVerboseOutput(params); while iteration <= params.NumIterations % Iterate through the data set. [dlX1,dlX2,dlY] = next(mbq); % Evaluate the model gradients and loss using dlfeval. [gradRGB,gradFlow,loss,acc,accRGB,accFlow,stateRGB,stateFlow] = ... dlfeval(@modelGradients,dlnetRGB,dlnetFlow,dlX1,dlX2,dlY); % Accumulate the loss and accuracies. lossTrain = [lossTrain, loss]; accTrain = [accTrain, acc]; accTrainRGB = [accTrainRGB, accRGB]; accTrainFlow = [accTrainFlow, accFlow]; % Update the network state. dlnetRGB.State = stateRGB; dlnetFlow.State = stateFlow; % Update the gradients and parameters for the RGB and optical flow % subnetworks using the SGDM optimizer. [dlnetRGB,gradRGB,params.VelocityRGB,learnRate] = ... updateDlNetwork(dlnetRGB,gradRGB,params,params.VelocityRGB,iteration); [dlnetFlow,gradFlow,params.VelocityFlow] = ... updateDlNetwork(dlnetFlow,gradFlow,params,params.VelocityFlow,iteration); if ~hasdata(mbq) || iteration == params.NumIterations % Current epoch is complete. Do validation and update progress. trainTime = toc(trainTime); [validationTime,cmat,lossValidation,accValidation,accValidationRGB,accValidationFlow] = ... doValidation(params, dlnetRGB, dlnetFlow); % Update the training progress. displayVerboseOutputEveryEpoch(params,start,learnRate,epoch,iteration,... mean(accTrain),mean(accTrainRGB),mean(accTrainFlow),... accValidation,accValidationRGB,accValidationFlow,... mean(lossTrain),lossValidation,trainTime,validationTime); updateProgressPlot(params,plotters,epoch,iteration,start,mean(lossTrain),mean(accTrain),accValidation); % Save model with the trained dlnetwork and accuracy values. % Use the saveData supporting function, listed at the % end of this example. if iteration >= params.SaveBestAfterIteration if accValidation > bestValAccuracy bestValAccuracy = accValidation; saveData(modelFilename, dlnetRGB, dlnetFlow, cmat, accValidation); end end end if ~hasdata(mbq) && iteration < params.NumIterations % Current epoch is complete. Initialize the training loss, accuracy % values, and minibatchqueue for the next epoch. accTrain = []; accTrainRGB = []; accTrainFlow = []; lossTrain = []; trainTime = tic; epoch = epoch + 1; shuffled = shuffleTrainDs(dsTrain); numOutputs = 3; mbq = createMiniBatchQueue(shuffled, numOutputs, params); end iteration = iteration + 1; end % Display a message when training is complete. endVerboseOutput(params); disp("Model saved to: " + modelFilename); end % Download the pretrained model and video file for prediction. filename = "activityRecognition-I3D-HMDB51.zip"; downloadURL = "https://ssd.mathworks.com/supportfiles/vision/data/" + filename; filename = fullfile(downloadFolder,filename); if ~exist(filename,'file') disp('Downloading the pretrained network...'); websave(filename,downloadURL); end % Unzip the contents to the download folder. unzip(filename,downloadFolder); if ~doTraining modelFilename = fullfile(downloadFolder, modelFilename); end
Используйте набор тестовых данных для оценки точности обученных подсетей.
Загрузите лучшую модель, сохраненную во время обучения.
d = load(modelFilename); dlnetRGB = d.data.dlnetRGB; dlnetFlow = d.data.dlnetFlow;
Создать minibatchqueue объект для загрузки пакетов тестовых данных.
numOutputs = 3; mbq = createMiniBatchQueue(params.ValidationData, numOutputs, params);
Для каждой партии тестовых данных сделайте прогнозы с использованием RGB и оптических потоковых сетей, возьмите среднее значение прогнозов и вычислите точность прогнозирования с использованием матрицы путаницы.
cmat = sparse(numClasses,numClasses); while hasdata(mbq) [dlRGB, dlFlow, dlY] = next(mbq); % Pass the video input as RGB and optical flow data through the % two-stream subnetworks to get the separate predictions. dlYPredRGB = predict(dlnetRGB,dlRGB); dlYPredFlow = predict(dlnetFlow,dlFlow); % Fuse the predictions by calculating the average of the predictions. dlYPred = (dlYPredRGB + dlYPredFlow)/2; % Calculate the accuracy of the predictions. [~,YTest] = max(dlY,[],1); [~,YPred] = max(dlYPred,[],1); cmat = aggregateConfusionMetric(cmat,YTest,YPred); end
Вычислите среднюю точность классификации для обученных сетей.
accuracyEval = sum(diag(cmat))./sum(cmat,"all")accuracyEval =
0.60909
Отображение матрицы путаницы.
figure chart = confusionchart(cmat,classes);

Из-за ограниченного количества обучающих образцов повышение точности выше 61% является сложной задачей. Для повышения надежности сети требуется дополнительное обучение с большим набором данных. Кроме того, предварительная подготовка большего набора данных, например, Kinetics [1], может помочь улучшить результаты.
Теперь можно использовать обученные сети для прогнозирования действий в новых видео. Чтение и отображение видео pour.avi использование VideoReader и vision.VideoPlayer.
videoFilename = fullfile(downloadFolder, "pour.avi"); videoReader = VideoReader(videoFilename); videoPlayer = vision.VideoPlayer; videoPlayer.Name = "pour"; while hasFrame(videoReader) frame = readFrame(videoReader); step(videoPlayer,frame); end release(videoPlayer);

Используйте readRGBAndFlow поддерживающая функция, перечисленная в конце этого примера, для считывания данных RGB и оптического потока.
isDataForValidation = true; readFcn = @(f,u)readRGBAndFlow(f,u,inputStats,isDataForValidation);
Функция считывания возвращает логическое значение isDone значение, указывающее на наличие дополнительных данных для чтения из файла. Используйте batchRGBAndFlow поддерживающая функция, определенная в конце этого примера, для пакетной передачи данных через двухпотоковые подсети для получения прогнозов.
hasdata = true; userdata = []; YPred = []; while hasdata [data,userdata,isDone] = readFcn(videoFilename,userdata); [dlRGB, dlFlow] = batchRGBAndFlow(data(:,1),data(:,2),data(:,3)); % Pass video input as RGB and optical flow data through the two-stream % subnetworks to get the separate predictions. dlYPredRGB = predict(dlnetRGB,dlRGB); dlYPredFlow = predict(dlnetFlow,dlFlow); % Fuse the predictions by calculating the average of the predictions. dlYPred = (dlYPredRGB + dlYPredFlow)/2; [~,YPredCurr] = max(dlYPred,[],1); YPred = horzcat(YPred,YPredCurr); hasdata = ~isDone; end YPred = extractdata(YPred);
Подсчитать количество правильных прогнозов с помощью histcountsи получают предсказанное действие, используя максимальное количество правильных прогнозов.
classes = params.Classes; counts = histcounts(YPred,1:numel(classes)); [~,clsIdx] = max(counts); action = classes(clsIdx)
action = "pour"
inputStatistics inputStatistics функция принимает в качестве входных данных имя папки, содержащей HMDB51 данные, и вычисляет минимальное и максимальное значения для данных RGB и данных оптического потока. Минимальное и максимальное значения используются в качестве входных данных нормализации для входного уровня сетей. Эта функция также получает количество кадров в каждом из видеофайлов для последующего использования во время обучения и тестирования сети. Чтобы найти минимальное и максимальное значения для другого набора данных, используйте эту функцию с именем папки, содержащей набор данных.
function inputStats = inputStatistics(dataFolder) ds = createDatastore(dataFolder); ds.ReadFcn = @getMinMax; tic; tt = tall(ds); varnames = {'rgbMax','rgbMin','oflowMax','oflowMin'}; stats = gather(groupsummary(tt,[],{'max','min'}, varnames)); inputStats.Filename = gather(tt.Filename); inputStats.NumFrames = gather(tt.NumFrames); inputStats.rgbMax = stats.max_rgbMax; inputStats.rgbMin = stats.min_rgbMin; inputStats.oflowMax = stats.max_oflowMax; inputStats.oflowMin = stats.min_oflowMin; save('inputStatistics.mat','inputStats'); toc; end function data = getMinMax(filename) reader = VideoReader(filename); opticFlow = opticalFlowFarneback; data = []; while hasFrame(reader) frame = readFrame(reader); [rgb,oflow] = findMinMax(frame,opticFlow); data = assignMinMax(data, rgb, oflow); end totalFrames = floor(reader.Duration * reader.FrameRate); totalFrames = min(totalFrames, reader.NumFrames); [labelName, filename] = getLabelFilename(filename); data.Filename = fullfile(labelName, filename); data.NumFrames = totalFrames; data = struct2table(data,'AsArray',true); end function data = assignMinMax(data, rgb, oflow) if isempty(data) data.rgbMax = rgb.Max; data.rgbMin = rgb.Min; data.oflowMax = oflow.Max; data.oflowMin = oflow.Min; return; end data.rgbMax = max(data.rgbMax, rgb.Max); data.rgbMin = min(data.rgbMin, rgb.Min); data.oflowMax = max(data.oflowMax, oflow.Max); data.oflowMin = min(data.oflowMin, oflow.Min); end function [rgbMinMax,oflowMinMax] = findMinMax(rgb, opticFlow) rgbMinMax.Max = max(rgb,[],[1,2]); rgbMinMax.Min = min(rgb,[],[1,2]); gray = rgb2gray(rgb); flow = estimateFlow(opticFlow,gray); oflow = cat(3,flow.Vx,flow.Vy,flow.Magnitude); oflowMinMax.Max = max(oflow,[],[1,2]); oflowMinMax.Min = min(oflow,[],[1,2]); end function ds = createDatastore(folder) ds = fileDatastore(folder,... 'IncludeSubfolders', true,... 'FileExtensions', '.avi',... 'UniformRead', true,... 'ReadFcn', @getMinMax); disp("NumFiles: " + numel(ds.Files)); end
createFileDatastore createFileDatastore функция создает FileDatastore с использованием заданных имен файлов. FileDatastore объект считывает данные в 'partialfile' режим, так что каждое чтение может возвращать частично считанные кадры из видео. Эта функция помогает при чтении больших видеофайлов, если все кадры не помещаются в память.
function datastore = createFileDatastore(filenames,inputStats,isDataForValidation) readFcn = @(f,u)readRGBAndFlow(f,u,inputStats,isDataForValidation); datastore = fileDatastore(filenames,... 'ReadFcn',readFcn,... 'ReadMode','partialfile'); end
readRGBAndFlow readRGBAndFlow функция считывает кадры RGB, соответствующие данные оптического потока и значения меток для данного видеофайла. Во время обучения функция считывания считывает определенное количество кадров в соответствии с размером сетевого ввода с произвольно выбранным начальным кадром. Данные оптического потока вычисляются с начала видеофайла, но пропускаются до достижения начального кадра. Во время тестирования все кадры последовательно считываются и вычисляются соответствующие данные оптического потока. Кадры RGB и данные оптического потока случайным образом обрезаются до требуемого размера сетевого ввода для обучения, а центр обрезается для тестирования и проверки.
function [data,userdata,done] = readRGBAndFlow(filename,userdata,inputStats,isDataForValidation) if isempty(userdata) userdata.reader = VideoReader(filename); userdata.batchesRead = 0; userdata.opticalFlow = opticalFlowFarneback; [totalFrames,userdata.label] = getTotalFramesAndLabel(inputStats,filename); if isempty(totalFrames) totalFrames = floor(userdata.reader.Duration * userdata.reader.FrameRate); totalFrames = min(totalFrames, userdata.reader.NumFrames); end userdata.totalFrames = totalFrames; end reader = userdata.reader; totalFrames = userdata.totalFrames; label = userdata.label; batchesRead = userdata.batchesRead; opticalFlow = userdata.opticalFlow; inputSize = inputStats.inputSize; H = inputSize(1); W = inputSize(2); rgbC = 3; flowC = 2; numFrames = inputSize(3); if numFrames > totalFrames numBatches = 1; else numBatches = floor(totalFrames/numFrames); end imH = userdata.reader.Height; imW = userdata.reader.Width; imsz = [imH,imW]; if ~isDataForValidation augmentFcn = augmentTransform([imsz,3]); cropWindow = randomCropWindow2d(imsz, inputSize(1:2)); % 1. Randomly select required number of frames, % starting randomly at a specific frame. if numFrames >= totalFrames idx = 1:totalFrames; % Add more frames to fill in the network input size. additional = ceil(numFrames/totalFrames); idx = repmat(idx,1,additional); idx = idx(1:numFrames); else startIdx = randperm(totalFrames - numFrames); startIdx = startIdx(1); endIdx = startIdx + numFrames - 1; idx = startIdx:endIdx; end video = zeros(H,W,rgbC,numFrames); oflow = zeros(H,W,flowC,numFrames); i = 1; % Discard the first set of frames to initialize the optical flow. for ii = 1:idx(1)-1 frame = read(reader,ii); getRGBAndFlow(frame,opticalFlow,augmentFcn,cropWindow); end % Read the next set of required number of frames for training. for ii = idx frame = read(reader,ii); [rgb,vxvy] = getRGBAndFlow(frame,opticalFlow,augmentFcn,cropWindow); video(:,:,:,i) = rgb; oflow(:,:,:,i) = vxvy; i = i + 1; end else augmentFcn = @(data)(data); cropWindow = centerCropWindow2d(imsz, inputSize(1:2)); toRead = min([numFrames,totalFrames]); video = zeros(H,W,rgbC,toRead); oflow = zeros(H,W,flowC,toRead); i = 1; while hasFrame(reader) && i <= numFrames frame = readFrame(reader); [rgb,vxvy] = getRGBAndFlow(frame,opticalFlow,augmentFcn,cropWindow); video(:,:,:,i) = rgb; oflow(:,:,:,i) = vxvy; i = i + 1; end if numFrames > totalFrames additional = ceil(numFrames/totalFrames); video = repmat(video,1,1,1,additional); oflow = repmat(oflow,1,1,1,additional); video = video(:,:,:,1:numFrames); oflow = oflow(:,:,:,1:numFrames); end end % The network expects the video and optical flow input in % the following dlarray format: % "SSSCB" ==> Height x Width x Frames x Channels x Batch % % Permute the data % from % Height x Width x Channels x Frames % to % Height x Width x Frames x Channels video = permute(video, [1,2,4,3]); oflow = permute(oflow, [1,2,4,3]); data = {video, oflow, label}; batchesRead = batchesRead + 1; userdata.batchesRead = batchesRead; % Set the done flag to true, if the reader has read all the frames or % if it is training. done = batchesRead == numBatches || ~isDataForValidation; end function [rgb,vxvy] = getRGBAndFlow(rgb,opticalFlow,augmentFcn,cropWindow) rgb = augmentFcn(rgb); gray = rgb2gray(rgb); flow = estimateFlow(opticalFlow,gray); vxvy = cat(3,flow.Vx,flow.Vy,flow.Vy); rgb = imcrop(rgb, cropWindow); vxvy = imcrop(vxvy, cropWindow); vxvy = vxvy(:,:,1:2); end function [label,fname] = getLabelFilename(filename) [folder,name,ext] = fileparts(string(filename)); [~,label] = fileparts(folder); fname = name + ext; label = string(label); fname = string(fname); end function [totalFrames,label] = getTotalFramesAndLabel(info, filename) filenames = info.Filename; frames = info.NumFrames; [labelName, fname] = getLabelFilename(filename); idx = strcmp(filenames, fullfile(labelName,fname)); totalFrames = frames(idx); label = categorical(string(labelName), string(info.Classes)); end
augmentTransform augmentTransform функция создает метод увеличения со случайными коэффициентами сдвига влево-вправо и масштабирования.
function augmentFcn = augmentTransform(sz) % Randomly flip and scale the image. tform = randomAffine2d('XReflection',true,'Scale',[1 1.1]); rout = affineOutputView(sz,tform,'BoundsStyle','CenterOutput'); augmentFcn = @(data)augmentData(data,tform,rout); function data = augmentData(data,tform,rout) data = imwarp(data,tform,'OutputView',rout); end end
modelGradients modelGradients функция принимает в качестве входных данных мини-пакет данных RGB dlRGB, соответствующие данные оптического потока dlFlowи соответствующая цель dlYи возвращает соответствующие потери, градиенты потерь относительно обучаемых параметров и точность обучения. Чтобы вычислить градиенты, вычислите modelGradients с помощью функции dlfeval функция в обучающем цикле.
function [gradientsRGB,gradientsFlow,loss,acc,accRGB,accFlow,stateRGB,stateFlow] = modelGradients(dlnetRGB,dlnetFlow,dlRGB,dlFlow,Y) % Pass video input as RGB and optical flow data through the two-stream % network. [dlYPredRGB,stateRGB] = forward(dlnetRGB,dlRGB); [dlYPredFlow,stateFlow] = forward(dlnetFlow,dlFlow); % Calculate fused loss, gradients, and accuracy for the two-stream % predictions. rgbLoss = crossentropy(dlYPredRGB,Y); flowLoss = crossentropy(dlYPredFlow,Y); % Fuse the losses. loss = mean([rgbLoss,flowLoss]); gradientsRGB = dlgradient(loss,dlnetRGB.Learnables); gradientsFlow = dlgradient(loss,dlnetFlow.Learnables); % Fuse the predictions by calculating the average of the predictions. dlYPred = (dlYPredRGB + dlYPredFlow)/2; % Calculate the accuracy of the predictions. [~,YTest] = max(Y,[],1); [~,YPred] = max(dlYPred,[],1); acc = gather(extractdata(sum(YTest == YPred)./numel(YTest))); % Calculate the accuracy of the RGB and flow predictions. [~,YTest] = max(Y,[],1); [~,YPredRGB] = max(dlYPredRGB,[],1); [~,YPredFlow] = max(dlYPredFlow,[],1); accRGB = gather(extractdata(sum(YTest == YPredRGB)./numel(YTest))); accFlow = gather(extractdata(sum(YTest == YPredFlow)./numel(YTest))); end
doValidation doValidation функция проверяет сеть с использованием данных проверки.
function [validationTime, cmat, lossValidation, accValidation, accValidationRGB, accValidationFlow] = doValidation(params, dlnetRGB, dlnetFlow) validationTime = tic; numOutputs = 3; mbq = createMiniBatchQueue(params.ValidationData, numOutputs, params); lossValidation = []; numClasses = numel(params.Classes); cmat = sparse(numClasses,numClasses); cmatRGB = sparse(numClasses,numClasses); cmatFlow = sparse(numClasses,numClasses); while hasdata(mbq) [dlX1,dlX2,dlY] = next(mbq); [loss,YTest,YPred,YPredRGB,YPredFlow] = predictValidation(dlnetRGB,dlnetFlow,dlX1,dlX2,dlY); lossValidation = [lossValidation,loss]; cmat = aggregateConfusionMetric(cmat,YTest,YPred); cmatRGB = aggregateConfusionMetric(cmatRGB,YTest,YPredRGB); cmatFlow = aggregateConfusionMetric(cmatFlow,YTest,YPredFlow); end lossValidation = mean(lossValidation); accValidation = sum(diag(cmat))./sum(cmat,"all"); accValidationRGB = sum(diag(cmatRGB))./sum(cmatRGB,"all"); accValidationFlow = sum(diag(cmatFlow))./sum(cmatFlow,"all"); validationTime = toc(validationTime); end
predictValidation predictValidation функция вычисляет значения потерь и прогнозирования, используя предоставленные dlnetwork объекты для данных RGB и оптического потока.
function [loss,YTest,YPred,YPredRGB,YPredFlow] = predictValidation(dlnetRGB,dlnetFlow,dlRGB,dlFlow,Y) % Pass the video input through the two-stream % network. dlYPredRGB = predict(dlnetRGB,dlRGB); dlYPredFlow = predict(dlnetFlow,dlFlow); % Calculate the cross-entropy separately for the two-stream % outputs. rgbLoss = crossentropy(dlYPredRGB,Y); flowLoss = crossentropy(dlYPredFlow,Y); % Fuse the losses. loss = mean([rgbLoss,flowLoss]); % Fuse the predictions by calculating the average of the predictions. dlYPred = (dlYPredRGB + dlYPredFlow)/2; % Calculate the accuracy of the predictions. [~,YTest] = max(Y,[],1); [~,YPred] = max(dlYPred,[],1); [~,YPredRGB] = max(dlYPredRGB,[],1); [~,YPredFlow] = max(dlYPredFlow,[],1); end
updateDlnetwork updateDlnetwork функция обновляет предоставленную dlnetwork объект с градиентами и другими параметрами с использованием функции оптимизации SGDM sgdmupdate.
function [dlnet,gradients,velocity,learnRate] = updateDlNetwork(dlnet,gradients,params,velocity,iteration) % Determine the learning rate using the cosine-annealing learning rate schedule. learnRate = cosineAnnealingLearnRate(iteration, params); % Apply L2 regularization to the weights. idx = dlnet.Learnables.Parameter == "Weights"; gradients(idx,:) = dlupdate(@(g,w) g + params.L2Regularization*w, gradients(idx,:), dlnet.Learnables(idx,:)); % Update the network parameters using the SGDM optimizer. [dlnet, velocity] = sgdmupdate(dlnet, gradients, velocity, learnRate, params.Momentum); end
cosineAnnealingLearnRate cosineAnnealingLearnRate функция вычисляет скорость обучения на основе текущего числа итераций, минимальной скорости обучения, максимальной скорости обучения и количества итераций для отжига [3].
function lr = cosineAnnealingLearnRate(iteration, params) if iteration == params.NumIterations lr = params.MinLearningRate; return; end cosineNumIter = [0, params.CosineNumIterations]; csum = cumsum(cosineNumIter); block = find(csum >= iteration, 1,'first'); cosineIter = iteration - csum(block - 1); annealingIteration = mod(cosineIter, cosineNumIter(block)); cosineIteration = cosineNumIter(block); minR = params.MinLearningRate; maxR = params.MaxLearningRate; cosMult = 1 + cos(pi * annealingIteration / cosineIteration); lr = minR + ((maxR - minR) * cosMult / 2); end
aggregateConfusionMetric aggregateConfusionMetric функция постепенно заполняет матрицу путаницы на основе прогнозируемых результатов YPred и ожидаемые результаты YTest.
function cmat = aggregateConfusionMetric(cmat,YTest,YPred) YTest = gather(extractdata(YTest)); YPred = gather(extractdata(YPred)); [m,n] = size(cmat); cmat = cmat + full(sparse(YTest,YPred,1,m,n)); end
createMiniBatchQueue createMiniBatchQueue функция создает minibatchqueue объект, обеспечивающий miniBatchSize объем данных из данного хранилища данных. Он также создает DispatchInBackgroundDatastore если открыт параллельный пул.
function mbq = createMiniBatchQueue(datastore, numOutputs, params) if params.DispatchInBackground && isempty(gcp('nocreate')) % Start a parallel pool, if DispatchInBackground is true, to dispatch % data in the background using the parallel pool. c = parcluster('local'); c.NumWorkers = params.NumWorkers; parpool('local',params.NumWorkers); end p = gcp('nocreate'); if ~isempty(p) datastore = DispatchInBackgroundDatastore(datastore, p.NumWorkers); end inputFormat(1:numOutputs-1) = "SSSCB"; outputFormat = "CB"; mbq = minibatchqueue(datastore, numOutputs, ... "MiniBatchSize", params.MiniBatchSize, ... "MiniBatchFcn", @batchRGBAndFlow, ... "MiniBatchFormat", [inputFormat,outputFormat]); end
batchRGBAndFlow batchRGBAndFlow функция группирует данные изображения, потока и метки в соответствующие dlarray значения в форматах данных "SSSCB", "SSSCB", и "CB"соответственно.
function [dlX1,dlX2,dlY] = batchRGBAndFlow(images, flows, labels) % Batch dimension: 5 X1 = cat(5,images{:}); X2 = cat(5,flows{:}); % Batch dimension: 2 labels = cat(2,labels{:}); % Feature dimension: 1 Y = onehotencode(labels,1); % Cast data to single for processing. X1 = single(X1); X2 = single(X2); Y = single(Y); % Move data to the GPU if possible. if canUseGPU X1 = gpuArray(X1); X2 = gpuArray(X2); Y = gpuArray(Y); end % Return X and Y as dlarray objects. dlX1 = dlarray(X1,"SSSCB"); dlX2 = dlarray(X2,"SSSCB"); dlY = dlarray(Y,"CB"); end
shuffleTrainDs shuffleTrainDs функция выполняет тасование файлов, имеющихся в хранилище данных обучения dsTrain.
function shuffled = shuffleTrainDs(dsTrain) shuffled = copy(dsTrain); n = numel(shuffled.Files); shuffledIndices = randperm(n); shuffled.Files = shuffled.Files(shuffledIndices); reset(shuffled); end
saveData saveData функция сохраняет заданное dlnetwork объекты и значения точности для файла MAT.
function saveData(modelFilename, dlnetRGB, dlnetFlow, cmat, accValidation) dlnetRGB = gatherFromGPUToSave(dlnetRGB); dlnetFlow = gatherFromGPUToSave(dlnetFlow); data.ValidationAccuracy = accValidation; data.cmat = cmat; data.dlnetRGB = dlnetRGB; data.dlnetFlow = dlnetFlow; save(modelFilename, 'data'); end
gatherFromGPUToSave gatherFromGPUToSave собирает данные из графического процессора для сохранения модели на диске.
function dlnet = gatherFromGPUToSave(dlnet) if ~canUseGPU return; end dlnet.Learnables = gatherValues(dlnet.Learnables); dlnet.State = gatherValues(dlnet.State); function tbl = gatherValues(tbl) for ii = 1:height(tbl) tbl.Value{ii} = gather(tbl.Value{ii}); end end end
checkForHMDB51Folder checkForHMDB51Folder функция проверяет загруженные данные в папке загрузки.
function classes = checkForHMDB51Folder(dataLoc) hmdbFolder = fullfile(dataLoc, "hmdb51_org"); if ~exist(hmdbFolder, "dir") error("Download 'hmdb51_org.rar' file using the supporting function 'downloadHMDB51' before running the example and extract the RAR file."); end classes = ["brush_hair","cartwheel","catch","chew","clap","climb","climb_stairs",... "dive","draw_sword","dribble","drink","eat","fall_floor","fencing",... "flic_flac","golf","handstand","hit","hug","jump","kick","kick_ball",... "kiss","laugh","pick","pour","pullup","punch","push","pushup","ride_bike",... "ride_horse","run","shake_hands","shoot_ball","shoot_bow","shoot_gun",... "sit","situp","smile","smoke","somersault","stand","swing_baseball","sword",... "sword_exercise","talk","throw","turn","walk","wave"]; expectFolders = fullfile(hmdbFolder, classes); if ~all(arrayfun(@(x)exist(x,'dir'),expectFolders)) error("Download hmdb51_org.rar using the supporting function 'downloadHMDB51' before running the example and extract the RAR file."); end end
downloadHMDB51 downloadHMDB51 функция загружает набор данных и сохраняет его в каталог.
function downloadHMDB51(dataLoc) if nargin == 0 dataLoc = pwd; end dataLoc = string(dataLoc); if ~exist(dataLoc,"dir") mkdir(dataLoc); end dataUrl = "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar"; options = weboptions('Timeout', Inf); rarFileName = fullfile(dataLoc, 'hmdb51_org.rar'); fileExists = exist(rarFileName, 'file'); % Download the RAR file and save it to the download folder. if ~fileExists disp("Downloading hmdb51_org.rar (2 GB) to the folder:") disp(dataLoc) disp("This download can take a few minutes...") websave(rarFileName, dataUrl, options); disp("Download complete.") disp("Extract the hmdb51_org.rar file contents to the folder: ") disp(dataLoc) end end
initializeTrainingProgressPlot initializeTrainingProgressPlot функция настраивает два графика для отображения потерь при обучении, точности обучения и точности проверки.
function plotters = initializeTrainingProgressPlot(params) if params.ProgressPlot % Plot the loss, training accuracy, and validation accuracy. figure % Loss plot subplot(2,1,1) plotters.LossPlotter = animatedline; xlabel("Iteration") ylabel("Loss") % Accuracy plot subplot(2,1,2) plotters.TrainAccPlotter = animatedline('Color','b'); plotters.ValAccPlotter = animatedline('Color','g'); legend('Training Accuracy','Validation Accuracy','Location','northwest'); xlabel("Iteration") ylabel("Accuracy") else plotters = []; end end
initializeVerboseOutput initializeVerboseOutput отображает заголовки столбцов для таблицы учебных значений, в которой показаны эпоха, точность мини-партии и другие учебные значения.
function initializeVerboseOutput(params) if params.Verbose disp(" ") if canUseGPU disp("Training on GPU.") else disp("Training on CPU.") end p = gcp('nocreate'); if ~isempty(p) disp("Training on parallel cluster '" + p.Cluster.Profile + "'. ") end disp("NumIterations:" + string(params.NumIterations)); disp("MiniBatchSize:" + string(params.MiniBatchSize)); disp("Classes:" + join(string(params.Classes), ",")); disp("|=======================================================================================================================================================================|") disp("| Epoch | Iteration | Time Elapsed | Mini-Batch Accuracy | Validation Accuracy | Mini-Batch | Validation | Base Learning | Train Time | Validation Time |") disp("| | | (hh:mm:ss) | (Avg:RGB:Flow) | (Avg:RGB:Flow) | Loss | Loss | Rate | (hh:mm:ss) | (hh:mm:ss) |") disp("|=======================================================================================================================================================================|") end end
displayVerboseOutputEveryEpoch displayVerboseOutputEveryEpoch функция отображает подробный вывод учебных значений, таких как эпоха, точность мини-партии, точность проверки и потеря мини-партии.
function displayVerboseOutputEveryEpoch(params,start,learnRate,epoch,iteration,... accTrain,accTrainRGB,accTrainFlow,accValidation,accValidationRGB,accValidationFlow,lossTrain,lossValidation,trainTime,validationTime) if params.Verbose D = duration(0,0,toc(start),'Format','hh:mm:ss'); trainTime = duration(0,0,trainTime,'Format','hh:mm:ss'); validationTime = duration(0,0,validationTime,'Format','hh:mm:ss'); lossValidation = gather(extractdata(lossValidation)); lossValidation = compose('%.4f',lossValidation); accValidation = composePadAccuracy(accValidation); accValidationRGB = composePadAccuracy(accValidationRGB); accValidationFlow = composePadAccuracy(accValidationFlow); accVal = join([accValidation,accValidationRGB,accValidationFlow], " : "); lossTrain = gather(extractdata(lossTrain)); lossTrain = compose('%.4f',lossTrain); accTrain = composePadAccuracy(accTrain); accTrainRGB = composePadAccuracy(accTrainRGB); accTrainFlow = composePadAccuracy(accTrainFlow); accTrain = join([accTrain,accTrainRGB,accTrainFlow], " : "); learnRate = compose('%.13f',learnRate); disp("| " + ... pad(string(epoch),5,'both') + " | " + ... pad(string(iteration),9,'both') + " | " + ... pad(string(D),12,'both') + " | " + ... pad(string(accTrain),26,'both') + " | " + ... pad(string(accVal),26,'both') + " | " + ... pad(string(lossTrain),10,'both') + " | " + ... pad(string(lossValidation),10,'both') + " | " + ... pad(string(learnRate),13,'both') + " | " + ... pad(string(trainTime),10,'both') + " | " + ... pad(string(validationTime),15,'both') + " |") end end function acc = composePadAccuracy(acc) acc = compose('%.2f',acc*100) + "%"; acc = pad(string(acc),6,'left'); end
endVerboseOutput endVerboseOutput отображает конец подробных выходных данных во время обучения.
function endVerboseOutput(params) if params.Verbose disp("|=======================================================================================================================================================================|") end end
updateProgressPlot updateProgressPlot функция обновляет график хода выполнения с информацией о потерях и точности во время обучения.
function updateProgressPlot(params,plotters,epoch,iteration,start,lossTrain,accuracyTrain,accuracyValidation) if params.ProgressPlot % Update the training progress. D = duration(0,0,toc(start),"Format","hh:mm:ss"); title(plotters.LossPlotter.Parent,"Epoch: " + epoch + ", Elapsed: " + string(D)); addpoints(plotters.LossPlotter,iteration,double(gather(extractdata(lossTrain)))); addpoints(plotters.TrainAccPlotter,iteration,accuracyTrain); addpoints(plotters.ValAccPlotter,iteration,accuracyValidation); drawnow end end
[1] Каррейра, Жоао и Эндрю Зиссерман. "Кво Вадис, признание действий? Новая модель и набор данных по кинетике. " Материалы Конференции IEEE по компьютерному зрению и распознаванию образов (CVPR): 6299?? 6308. Гонолулу, HI: IEEE, 2017.
[2] Симоньян, Карен и Эндрю Зиссерман. «Двухстримовые сверточные сети для распознавания действий в видео». Достижения в системах обработки нейронной информации 27, Лонг-Бич, Калифорния: NIPS, 2017.
[3] Лошчилов, Илья и Фрэнк Хаттер. «SGDR: Стохастический градиентный спуск с теплыми перезапусками». Международная конференция по учебным представлениям 2017. Тулон, Франция: ICLR, 2017.