В этом примере показано, как обучить двухпоточную сверточную нейронную сеть 3-D (I3D) для распознавания активности с помощью RGB и данных оптического потока из видео [1].
Распознавание активности на основе зрения включает предсказание действия объекта, такого как ходьба, плавание или сидение, с помощью набора видеокадров. Распознавание активности из видео имеет много приложений, таких как взаимодействие человека с компьютером, обучение робота, обнаружение аномалий, наблюдение и обнаружение объектов. Для примера онлайн- предсказания нескольких действий для входящих видео с нескольких камер могут быть важны для обучения роботов. По сравнению с классификацией изображений распознавание действий с использованием видео трудно смоделировать из-за шумных меток в наборах данных видео, разнообразия действий, которые могут выполнять актёры в видео, которые являются сильно несбалансированными по классам, и вычислительной неэффективности предварительного обучения на больших наборах данных видео. Некоторые методы глубокого обучения, такие как I3D двухпотоковых сверточных сетей [1], показали улучшенную производительность, используя предварительное обучение на больших наборах данных классификации изображений.
Этот пример обучает I3D сеть, используя HMDB51 набор данных. Используйте downloadHMDB51
вспомогательная функция, перечисленная в конце этого примера, для загрузки HMDB51 набора данных в папку с именем hmdb51
.
downloadFolder = fullfile(tempdir,"hmdb51");
downloadHMDB51(downloadFolder);
После завершения загрузки извлеките файл RAR hmdb51_org.rar
на hmdb51
папка. Далее используйте checkForHMDB51Folder
вспомогательная функция, перечисленная в конце этого примера, чтобы подтвердить, что загруженные и извлеченные файлы на месте.
allClasses = checkForHMDB51Folder(downloadFolder);
Набор данных содержит около 2 ГБ видеоданных для 7000 клипов в 51 классе, таких как напиток, беги и пожимай руки. Каждый видеокадр имеет высоту 240 пикселей и минимальную ширину 176 пикселей. Количество систем координат варьируется от 18 до приблизительно 1000.
Чтобы уменьшить время обучения, этот пример обучает сеть распознавания активности классифицировать 5 классов действий вместо всех 51 классов в наборе данных. Задайте useAllData
на true
для обучения со всеми 51 классами.
useAllData = false; if useAllData classes = allClasses; else classes = ["kiss","laugh","pick","pour","pushup"]; end dataFolder = fullfile(downloadFolder, "hmdb51_org");
Разделите набор данных на набор обучающих данных для обучения сети и тестовый набор для оценки сети. Используйте 80% данных для набора обучающих данных, а остальное для тестового набора. Использование imageDatastore
разделить данные, основанные на каждой метке, на наборы обучающих и тестовых данных путем случайного выбора части файлов из каждой метки.
imds = imageDatastore(fullfile(dataFolder,classes),... 'IncludeSubfolders', true,... 'LabelSource', 'foldernames',... 'FileExtensions', '.avi'); [trainImds,testImds] = splitEachLabel(imds,0.8,'randomized'); trainFilenames = trainImds.Files; testFilenames = testImds.Files;
Для нормализации входных данных для сети в файле MAT приведены минимальное и максимальное значения для набора данных inputStatistics.mat
, прилагаемый к этому примеру. Чтобы найти минимальное и максимальное значения для другого набора данных, используйте inputStatistics
вспомогательная функция, перечисленная в конце этого примера.
inputStatsFilename = 'inputStatistics.mat'; if ~exist(inputStatsFilename, 'file') disp("Reading all the training data for input statistics...") inputStats = inputStatistics(dataFolder); else d = load(inputStatsFilename); inputStats = d.inputStats; end
Создайте два FileDatastore
объекты для обучения и валидации при помощи createFileDatastore
вспомогательная функция, заданная в конце этого примера. Каждый datastore считывает видео- файл, чтобы предоставить данные RGB, данные оптического потока и соответствующую информацию о метке.
Укажите количество систем координат для каждого чтения datastore. Типичными значениями являются 16, 32, 64 или 128. Использование большего количества систем координат помогает захватывать больше временной информации, но требует большей памяти для обучения и предсказания. Установите количество систем координат равное 64, чтобы сбалансировать использование памяти с эффективностью. Может потребоваться снизить это значение в зависимости от системных ресурсов.
numFrames = 64;
Задайте высоту и ширину систем координат для считывания datastore. Фиксация высоты и ширины к одному и тому же значению облегчает пакетную обработку данных для сети. Типичными значениями являются [112, 112], [224, 224] и [256, 256]. Минимальные высота и ширина видеокадров в HMDB51 наборе данных составляют 240 и 176, соответственно. Определите [112, 112], чтобы захватить большее число систем координат за счет пространственной информации. Если вы хотите задать формат кадра для чтения datastore, который больше минимальных значений, таких как [256, 256], сначала измените размер систем координат с помощью imesize.
frameSize = [112,112];
Задайте inputSize
на inputStats
структура, таким образом, функция read от fileDatastore
можно считать заданный размер входного сигнала.
inputSize = [frameSize, numFrames]; inputStats.inputSize = inputSize; inputStats.Classes = classes;
Создайте два FileDatastore
объекты, один для обучения и другой для валидации.
isDataForValidation = false;
dsTrain = createFileDatastore(trainFilenames,inputStats,isDataForValidation);
isDataForValidation = true;
dsVal = createFileDatastore(testFilenames,inputStats,isDataForValidation);
disp("Training data size: " + string(numel(dsTrain.Files)))
Training data size: 436
disp("Validation data size: " + string(numel(dsVal.Files)))
Validation data size: 109
Использование 3-D CNN является естественным подходом к извлечению пространственно-временных функций из видео. Можно создать I3D сеть из предварительно обученной сети классификации 2-D изображений, такой как Inception v1 или ResNet-50, путем расширения 2-D фильтров и объединения ядер в 3-D. Эта процедура повторно использует веса, полученные из задачи классификации изображений, чтобы загрузить задачу распознавания видео.
Следующий рисунок является выборкой, показывающим, как надуть слой свертки 2-D на слой свертки 3-D. Инфляция включает расширение размера, весов и смещения фильтра путем добавления третьего измерения (временной размерности).
Видео данных может быть рассмотрено как имеющее две части: пространственный компонент и временный компонент.
Пространственный компонент содержит информацию о форме, текстуре и цвете объектов в видео. Данные RGB содержат эту информацию.
Временной компонент содержит информацию о движении объектов через системы координат и изображает важные перемещения между камерой и объектами в сцене. Вычисление оптического потока является общим методом для извлечения временной информации из видео.
Двухпотоковый CNN включает пространственную подсеть и временную подсеть [2]. Сверточная нейронная сеть, обученная на плотном оптическом потоке и потоке видеоданных, может достичь лучшей эффективности с ограниченными обучающими данными, чем с необработанными сложенными системами координат RGB. Следующий рисунок показывает типовую двухпоточную I3D сеть.
В этом примере вы создаете I3D сеть с помощью GoogLeNet, сети, предварительно обученной в базе данных ImageNet.
Укажите количество каналов следующим 3
для подсети RGB и 2
для подсети оптического потока. Два канала для данных оптического потока являются и составляющие скорости, и , соответственно.
rgbChannels = 3; flowChannels = 2;
Получите минимальное и максимальное значения для RGB и данных оптического потока от inputStats
структура, загруженная из inputStatistics.mat
файл. Эти значения необходимы для image3dInputLayer
сетей I3D для нормализации входных данных.
rgbInputSize = [frameSize, numFrames, rgbChannels]; flowInputSize = [frameSize, numFrames, flowChannels]; rgbMin = inputStats.rgbMin; rgbMax = inputStats.rgbMax; oflowMin = inputStats.oflowMin(:,:,1:2); oflowMax = inputStats.oflowMax(:,:,1:2); rgbMin = reshape(rgbMin,[1,size(rgbMin)]); rgbMax = reshape(rgbMax,[1,size(rgbMax)]); oflowMin = reshape(oflowMin,[1,size(oflowMin)]); oflowMax = reshape(oflowMax,[1,size(oflowMax)]);
Укажите количество классов для обучения сети.
numClasses = numel(classes);
Создайте I3D подсети RGB и оптического потока при помощи Inflated3D
вспомогательная функция, которая присоединена к этому примеру. Подсети создаются из GoogLeNet.
cnnNet = googlenet; netRGB = Inflated3D(numClasses,rgbInputSize,rgbMin,rgbMax,cnnNet); netFlow = Inflated3D(numClasses,flowInputSize,oflowMin,oflowMax,cnnNet);
Создайте dlnetwork
объект из графика слоев каждой из I3D сетей.
dlnetRGB = dlnetwork(netRGB); dlnetFlow = dlnetwork(netFlow);
Создайте вспомогательную функцию modelGradients
, перечисленный в конце этого примера. The modelGradients
функция принимает за вход подсеть RGB dlnetRGB
, подсеть оптического потока dlnetFlow
мини-пакет входных данных dlRGB
и dlFlow
и мини-пакет основных истин данных о метках dlY
. Функция возвращает значение потерь обучения, градиенты потерь относительно настраиваемых параметров соответствующих подсетей и мини-пакетную точность подсетей.
Потеря вычисляется путем вычисления среднего значения потерь перекрестной энтропии предсказаний от каждой из подсетей. Выходы предсказания сети являются вероятностями от 0 до 1 для каждого из классов.
Точность каждой из подсетей вычисляется путем взятия среднего значения RGB и оптического потока предсказаний и сравнения его с основной истиной меткой входов.
Обучите с мини-пакетом размером 20 для 1500 итераций. Задайте итерацию, после которой можно сохранить модель с лучшей точностью валидации при помощи SaveBestAfterIteration
параметр.
Задайте параметры расписания скорости обучения с отжигом косинуса [3]. Для обеих сетей используйте:
Минимальная скорость обучения 1e-4.
Максимальная скорость обучения 1e-3.
Количество косинусов итераций 300, 500 и 700, после чего цикл расписания скорости обучения перезапускается. Опция CosineNumIterations
задает ширину каждого цикла косинуса.
Задайте параметры для оптимизации SGDM. Инициализируйте параметры оптимизации SGDM в начале обучения для каждой из сетей RGB и оптического потока. Для обеих сетей используйте:
Импульс 0,9.
Начальный параметр скорости инициализирован как []
.
Коэффициент регуляризации L2 0,0005.
Задайте, чтобы отправлять данные в фоновом режиме с помощью параллельного пула. Если DispatchInBackground
установлено значение true, откройте параллельный пул с заданным количеством параллельных рабочих мест и создайте DispatchInBackgroundDatastore
, представленный в качестве части этого примера, который отправляет данные в фоновом режиме, чтобы ускорить обучение с использованием асинхронной загрузки данных и предварительной обработки. По умолчанию этот пример использует графический процессор, если он доступен. В противном случае используется центральный процессор. Для использования GPU требуется Parallel Computing Toolbox™ и графический процессор с поддержкой CUDA ® NVIDIA ®. Для получения информации о поддерживаемых вычислительных возможностях смотрите Поддержку GPU by Release (Parallel Computing Toolbox).
params.Classes = classes; params.MiniBatchSize = 20; params.NumIterations = 1500; params.SaveBestAfterIteration = 900; params.CosineNumIterations = [300, 500, 700]; params.MinLearningRate = 1e-4; params.MaxLearningRate = 1e-3; params.Momentum = 0.9; params.VelocityRGB = []; params.VelocityFlow = []; params.L2Regularization = 0.0005; params.ProgressPlot = false; params.Verbose = true; params.ValidationData = dsVal; params.DispatchInBackground = false; params.NumWorkers = 4;
Обучите подсети с помощью данных RGB и данных оптического потока. Установите doTraining
переменная в false
загрузить предварительно обученные подсети, не дожидаясь завершения обучения. Кроме того, если необходимо обучить подсети, установите doTraining
переменная в true
.
doTraining = false;
Для каждой эпохи:
Перетащите данные перед закольцовыванием по мини-пакетам данных.
Использование minibatchqueue
для закольцовывания мини-пакетов. Вспомогательная функция createMiniBatchQueue
, перечисленный в конце этого примера, использует данный обучающий datastore для создания minibatchqueue
.
Используйте данные валидации dsVal
для проверки сетей.
Отобразите результаты потерь и точности для каждой эпохи с помощью вспомогательной функции displayVerboseOutputEveryEpoch
, перечисленный в конце этого примера.
Для каждого мини-пакета:
Преобразуйте данные изображения или данные оптического потока и метки в dlarray
объекты с базовым типом single.
Обработайте временную размерность видео и данных оптического потока как одну из пространственных размерностей, чтобы позволить обработку с помощью 3-D CNN. Задайте метки размерности "SSSCB"
(пространственный, пространственный, пространственный, канал, пакет) для RGB или данных оптического потока и "CB"
для данных о метке.
The minibatchqueue
объект использует вспомогательную функцию batchRGBAndFlow
, перечисленный в конце этого примера, чтобы упаковать данные RGB и оптического потока.
modelFilename = "I3D-RGBFlow-" + numClasses + "Classes-hmdb51.mat"; if doTraining epoch = 1; bestValAccuracy = 0; accTrain = []; accTrainRGB = []; accTrainFlow = []; lossTrain = []; iteration = 1; shuffled = shuffleTrainDs(dsTrain); % Number of outputs is three: One for RGB frames, one for optical flow % data, and one for ground truth labels. numOutputs = 3; mbq = createMiniBatchQueue(shuffled, numOutputs, params); start = tic; trainTime = start; % Use the initializeTrainingProgressPlot and initializeVerboseOutput % supporting functions, listed at the end of the example, to initialize % the training progress plot and verbose output to display the training % loss, training accuracy, and validation accuracy. plotters = initializeTrainingProgressPlot(params); initializeVerboseOutput(params); while iteration <= params.NumIterations % Iterate through the data set. [dlX1,dlX2,dlY] = next(mbq); % Evaluate the model gradients and loss using dlfeval. [gradRGB,gradFlow,loss,acc,accRGB,accFlow,stateRGB,stateFlow] = ... dlfeval(@modelGradients,dlnetRGB,dlnetFlow,dlX1,dlX2,dlY); % Accumulate the loss and accuracies. lossTrain = [lossTrain, loss]; accTrain = [accTrain, acc]; accTrainRGB = [accTrainRGB, accRGB]; accTrainFlow = [accTrainFlow, accFlow]; % Update the network state. dlnetRGB.State = stateRGB; dlnetFlow.State = stateFlow; % Update the gradients and parameters for the RGB and optical flow % subnetworks using the SGDM optimizer. [dlnetRGB,gradRGB,params.VelocityRGB,learnRate] = ... updateDlNetwork(dlnetRGB,gradRGB,params,params.VelocityRGB,iteration); [dlnetFlow,gradFlow,params.VelocityFlow] = ... updateDlNetwork(dlnetFlow,gradFlow,params,params.VelocityFlow,iteration); if ~hasdata(mbq) || iteration == params.NumIterations % Current epoch is complete. Do validation and update progress. trainTime = toc(trainTime); [validationTime,cmat,lossValidation,accValidation,accValidationRGB,accValidationFlow] = ... doValidation(params, dlnetRGB, dlnetFlow); % Update the training progress. displayVerboseOutputEveryEpoch(params,start,learnRate,epoch,iteration,... mean(accTrain),mean(accTrainRGB),mean(accTrainFlow),... accValidation,accValidationRGB,accValidationFlow,... mean(lossTrain),lossValidation,trainTime,validationTime); updateProgressPlot(params,plotters,epoch,iteration,start,mean(lossTrain),mean(accTrain),accValidation); % Save model with the trained dlnetwork and accuracy values. % Use the saveData supporting function, listed at the % end of this example. if iteration >= params.SaveBestAfterIteration if accValidation > bestValAccuracy bestValAccuracy = accValidation; saveData(modelFilename, dlnetRGB, dlnetFlow, cmat, accValidation); end end end if ~hasdata(mbq) && iteration < params.NumIterations % Current epoch is complete. Initialize the training loss, accuracy % values, and minibatchqueue for the next epoch. accTrain = []; accTrainRGB = []; accTrainFlow = []; lossTrain = []; trainTime = tic; epoch = epoch + 1; shuffled = shuffleTrainDs(dsTrain); numOutputs = 3; mbq = createMiniBatchQueue(shuffled, numOutputs, params); end iteration = iteration + 1; end % Display a message when training is complete. endVerboseOutput(params); disp("Model saved to: " + modelFilename); end % Download the pretrained model and video file for prediction. filename = "activityRecognition-I3D-HMDB51.zip"; downloadURL = "https://ssd.mathworks.com/supportfiles/vision/data/" + filename; filename = fullfile(downloadFolder,filename); if ~exist(filename,'file') disp('Downloading the pretrained network...'); websave(filename,downloadURL); end % Unzip the contents to the download folder. unzip(filename,downloadFolder); if ~doTraining modelFilename = fullfile(downloadFolder, modelFilename); end
Используйте набор тестовых данных для оценки точности обученных подсетей.
Загрузите лучшую модель, сохраненную во время обучения.
d = load(modelFilename); dlnetRGB = d.data.dlnetRGB; dlnetFlow = d.data.dlnetFlow;
Создайте minibatchqueue
объект для загрузки пакетов тестовых данных.
numOutputs = 3; mbq = createMiniBatchQueue(params.ValidationData, numOutputs, params);
Для каждого пакета тестовых данных сделайте предсказания с помощью сетей RGB и оптического потока, возьмите среднее значение предсказаний и вычислите точность предсказания с помощью матрицы неточностей.
cmat = sparse(numClasses,numClasses); while hasdata(mbq) [dlRGB, dlFlow, dlY] = next(mbq); % Pass the video input as RGB and optical flow data through the % two-stream subnetworks to get the separate predictions. dlYPredRGB = predict(dlnetRGB,dlRGB); dlYPredFlow = predict(dlnetFlow,dlFlow); % Fuse the predictions by calculating the average of the predictions. dlYPred = (dlYPredRGB + dlYPredFlow)/2; % Calculate the accuracy of the predictions. [~,YTest] = max(dlY,[],1); [~,YPred] = max(dlYPred,[],1); cmat = aggregateConfusionMetric(cmat,YTest,YPred); end
Вычислите среднюю точность классификации для обученных сетей.
accuracyEval = sum(diag(cmat))./sum(cmat,"all")
accuracyEval = 0.60909
Отобразите матрицу неточностей.
figure chart = confusionchart(cmat,classes);
Из-за ограниченного количества обучающих выборок увеличение точности сверх 61% является сложным. Для повышения робастности сети требуется дополнительное обучение с большим набором данных. В сложение предварительное обучение на большем наборе данных, таком как Kinetics [1], может помочь улучшить результаты.
Теперь можно использовать обученные сети для предсказания действий в новых видео. Чтение и отображение видео pour.avi
использование VideoReader
и vision.VideoPlayer
.
videoFilename = fullfile(downloadFolder, "pour.avi"); videoReader = VideoReader(videoFilename); videoPlayer = vision.VideoPlayer; videoPlayer.Name = "pour"; while hasFrame(videoReader) frame = readFrame(videoReader); step(videoPlayer,frame); end release(videoPlayer);
Используйте readRGBAndFlow
вспомогательная функция, перечисленная в конце этого примера, для чтения данных RGB и оптического потока.
isDataForValidation = true; readFcn = @(f,u)readRGBAndFlow(f,u,inputStats,isDataForValidation);
Функция read возвращает логическое isDone
значение, которое указывает, существует ли больше данных для чтения из файла. Используйте batchRGBAndFlow
вспомогательная функция, заданная в конце этого примера, чтобы упаковать данные, чтобы пройти через двухпотоковые подсети, чтобы получить предсказания.
hasdata = true; userdata = []; YPred = []; while hasdata [data,userdata,isDone] = readFcn(videoFilename,userdata); [dlRGB, dlFlow] = batchRGBAndFlow(data(:,1),data(:,2),data(:,3)); % Pass video input as RGB and optical flow data through the two-stream % subnetworks to get the separate predictions. dlYPredRGB = predict(dlnetRGB,dlRGB); dlYPredFlow = predict(dlnetFlow,dlFlow); % Fuse the predictions by calculating the average of the predictions. dlYPred = (dlYPredRGB + dlYPredFlow)/2; [~,YPredCurr] = max(dlYPred,[],1); YPred = horzcat(YPred,YPredCurr); hasdata = ~isDone; end YPred = extractdata(YPred);
Подсчитайте количество правильных предсказаний, используя histcounts
, и получить предсказанное действие, используя максимальное количество правильных предсказаний.
classes = params.Classes; counts = histcounts(YPred,1:numel(classes)); [~,clsIdx] = max(counts); action = classes(clsIdx)
action = "pour"
inputStatistics
The inputStatistics
функция принимает за вход имя папки, содержащей данные HMDB51, и вычисляет минимальное и максимальное значения для данных RGB и данных оптического потока. Минимальное и максимальное значения используются как нормализация, входы к входу уровню сетей. Эта функция также получает количество систем координат в каждом из файлов видео, которые используются позже во время обучения и проверки сети. В порядок, чтобы найти минимальное и максимальное значения для другого набора данных, используйте эту функцию с именем папки, содержащей набор данных.
function inputStats = inputStatistics(dataFolder) ds = createDatastore(dataFolder); ds.ReadFcn = @getMinMax; tic; tt = tall(ds); varnames = {'rgbMax','rgbMin','oflowMax','oflowMin'}; stats = gather(groupsummary(tt,[],{'max','min'}, varnames)); inputStats.Filename = gather(tt.Filename); inputStats.NumFrames = gather(tt.NumFrames); inputStats.rgbMax = stats.max_rgbMax; inputStats.rgbMin = stats.min_rgbMin; inputStats.oflowMax = stats.max_oflowMax; inputStats.oflowMin = stats.min_oflowMin; save('inputStatistics.mat','inputStats'); toc; end function data = getMinMax(filename) reader = VideoReader(filename); opticFlow = opticalFlowFarneback; data = []; while hasFrame(reader) frame = readFrame(reader); [rgb,oflow] = findMinMax(frame,opticFlow); data = assignMinMax(data, rgb, oflow); end totalFrames = floor(reader.Duration * reader.FrameRate); totalFrames = min(totalFrames, reader.NumFrames); [labelName, filename] = getLabelFilename(filename); data.Filename = fullfile(labelName, filename); data.NumFrames = totalFrames; data = struct2table(data,'AsArray',true); end function data = assignMinMax(data, rgb, oflow) if isempty(data) data.rgbMax = rgb.Max; data.rgbMin = rgb.Min; data.oflowMax = oflow.Max; data.oflowMin = oflow.Min; return; end data.rgbMax = max(data.rgbMax, rgb.Max); data.rgbMin = min(data.rgbMin, rgb.Min); data.oflowMax = max(data.oflowMax, oflow.Max); data.oflowMin = min(data.oflowMin, oflow.Min); end function [rgbMinMax,oflowMinMax] = findMinMax(rgb, opticFlow) rgbMinMax.Max = max(rgb,[],[1,2]); rgbMinMax.Min = min(rgb,[],[1,2]); gray = rgb2gray(rgb); flow = estimateFlow(opticFlow,gray); oflow = cat(3,flow.Vx,flow.Vy,flow.Magnitude); oflowMinMax.Max = max(oflow,[],[1,2]); oflowMinMax.Min = min(oflow,[],[1,2]); end function ds = createDatastore(folder) ds = fileDatastore(folder,... 'IncludeSubfolders', true,... 'FileExtensions', '.avi',... 'UniformRead', true,... 'ReadFcn', @getMinMax); disp("NumFiles: " + numel(ds.Files)); end
createFileDatastore
The createFileDatastore
функция создает FileDatastore
объект с указанными именами файлов. The FileDatastore
объект считывает данные в 'partialfile'
режим, поэтому каждое чтение может вернуть частично считанные системы координат из видео. Эта функция помогает с чтением больших видео файлов, если все системы координат не помещаются в памяти.
function datastore = createFileDatastore(filenames,inputStats,isDataForValidation) readFcn = @(f,u)readRGBAndFlow(f,u,inputStats,isDataForValidation); datastore = fileDatastore(filenames,... 'ReadFcn',readFcn,... 'ReadMode','partialfile'); end
readRGBAndFlow
The readRGBAndFlow
функция считывает системы координат RGB, соответствующие данные оптического потока и значения меток для заданного видеосигнала файла. Во время обучения функция read считывает определенное количество систем координат согласно размеру входа сети с случайным выбором стартовой системы координат. Данные оптического потока вычисляются с начала файла видео, но пропускаются до достижения стартовой системы координат. Во время проверки все системы координат последовательно считываются, и вычисляются соответствующие данные оптического потока. Системы координат RGB и данные оптического потока случайным образом обрезаются до необходимого размера входа сети для обучения и обрезаются по центру для проверки и валидации.
function [data,userdata,done] = readRGBAndFlow(filename,userdata,inputStats,isDataForValidation) if isempty(userdata) userdata.reader = VideoReader(filename); userdata.batchesRead = 0; userdata.opticalFlow = opticalFlowFarneback; [totalFrames,userdata.label] = getTotalFramesAndLabel(inputStats,filename); if isempty(totalFrames) totalFrames = floor(userdata.reader.Duration * userdata.reader.FrameRate); totalFrames = min(totalFrames, userdata.reader.NumFrames); end userdata.totalFrames = totalFrames; end reader = userdata.reader; totalFrames = userdata.totalFrames; label = userdata.label; batchesRead = userdata.batchesRead; opticalFlow = userdata.opticalFlow; inputSize = inputStats.inputSize; H = inputSize(1); W = inputSize(2); rgbC = 3; flowC = 2; numFrames = inputSize(3); if numFrames > totalFrames numBatches = 1; else numBatches = floor(totalFrames/numFrames); end imH = userdata.reader.Height; imW = userdata.reader.Width; imsz = [imH,imW]; if ~isDataForValidation augmentFcn = augmentTransform([imsz,3]); cropWindow = randomCropWindow2d(imsz, inputSize(1:2)); % 1. Randomly select required number of frames, % starting randomly at a specific frame. if numFrames >= totalFrames idx = 1:totalFrames; % Add more frames to fill in the network input size. additional = ceil(numFrames/totalFrames); idx = repmat(idx,1,additional); idx = idx(1:numFrames); else startIdx = randperm(totalFrames - numFrames); startIdx = startIdx(1); endIdx = startIdx + numFrames - 1; idx = startIdx:endIdx; end video = zeros(H,W,rgbC,numFrames); oflow = zeros(H,W,flowC,numFrames); i = 1; % Discard the first set of frames to initialize the optical flow. for ii = 1:idx(1)-1 frame = read(reader,ii); getRGBAndFlow(frame,opticalFlow,augmentFcn,cropWindow); end % Read the next set of required number of frames for training. for ii = idx frame = read(reader,ii); [rgb,vxvy] = getRGBAndFlow(frame,opticalFlow,augmentFcn,cropWindow); video(:,:,:,i) = rgb; oflow(:,:,:,i) = vxvy; i = i + 1; end else augmentFcn = @(data)(data); cropWindow = centerCropWindow2d(imsz, inputSize(1:2)); toRead = min([numFrames,totalFrames]); video = zeros(H,W,rgbC,toRead); oflow = zeros(H,W,flowC,toRead); i = 1; while hasFrame(reader) && i <= numFrames frame = readFrame(reader); [rgb,vxvy] = getRGBAndFlow(frame,opticalFlow,augmentFcn,cropWindow); video(:,:,:,i) = rgb; oflow(:,:,:,i) = vxvy; i = i + 1; end if numFrames > totalFrames additional = ceil(numFrames/totalFrames); video = repmat(video,1,1,1,additional); oflow = repmat(oflow,1,1,1,additional); video = video(:,:,:,1:numFrames); oflow = oflow(:,:,:,1:numFrames); end end % The network expects the video and optical flow input in % the following dlarray format: % "SSSCB" ==> Height x Width x Frames x Channels x Batch % % Permute the data % from % Height x Width x Channels x Frames % to % Height x Width x Frames x Channels video = permute(video, [1,2,4,3]); oflow = permute(oflow, [1,2,4,3]); data = {video, oflow, label}; batchesRead = batchesRead + 1; userdata.batchesRead = batchesRead; % Set the done flag to true, if the reader has read all the frames or % if it is training. done = batchesRead == numBatches || ~isDataForValidation; end function [rgb,vxvy] = getRGBAndFlow(rgb,opticalFlow,augmentFcn,cropWindow) rgb = augmentFcn(rgb); gray = rgb2gray(rgb); flow = estimateFlow(opticalFlow,gray); vxvy = cat(3,flow.Vx,flow.Vy,flow.Vy); rgb = imcrop(rgb, cropWindow); vxvy = imcrop(vxvy, cropWindow); vxvy = vxvy(:,:,1:2); end function [label,fname] = getLabelFilename(filename) [folder,name,ext] = fileparts(string(filename)); [~,label] = fileparts(folder); fname = name + ext; label = string(label); fname = string(fname); end function [totalFrames,label] = getTotalFramesAndLabel(info, filename) filenames = info.Filename; frames = info.NumFrames; [labelName, fname] = getLabelFilename(filename); idx = strcmp(filenames, fullfile(labelName,fname)); totalFrames = frames(idx); label = categorical(string(labelName), string(info.Classes)); end
augmentTransform
The augmentTransform
функция создает метод увеличения со случайными лево-правыми коэффициентами поворота и масштабирования.
function augmentFcn = augmentTransform(sz) % Randomly flip and scale the image. tform = randomAffine2d('XReflection',true,'Scale',[1 1.1]); rout = affineOutputView(sz,tform,'BoundsStyle','CenterOutput'); augmentFcn = @(data)augmentData(data,tform,rout); function data = augmentData(data,tform,rout) data = imwarp(data,tform,'OutputView',rout); end end
modelGradients
The modelGradients
функция принимает за вход мини-пакет данных RGB dlRGB
соответствующие данные оптического потока dlFlow
, и соответствующий целевой dlY
и возвращает соответствующие потери, градиенты потерь относительно настраиваемых параметров и точность обучения. Чтобы вычислить градиенты, вычислите modelGradients
функция, использующая dlfeval
функция в цикле обучения.
function [gradientsRGB,gradientsFlow,loss,acc,accRGB,accFlow,stateRGB,stateFlow] = modelGradients(dlnetRGB,dlnetFlow,dlRGB,dlFlow,Y) % Pass video input as RGB and optical flow data through the two-stream % network. [dlYPredRGB,stateRGB] = forward(dlnetRGB,dlRGB); [dlYPredFlow,stateFlow] = forward(dlnetFlow,dlFlow); % Calculate fused loss, gradients, and accuracy for the two-stream % predictions. rgbLoss = crossentropy(dlYPredRGB,Y); flowLoss = crossentropy(dlYPredFlow,Y); % Fuse the losses. loss = mean([rgbLoss,flowLoss]); gradientsRGB = dlgradient(loss,dlnetRGB.Learnables); gradientsFlow = dlgradient(loss,dlnetFlow.Learnables); % Fuse the predictions by calculating the average of the predictions. dlYPred = (dlYPredRGB + dlYPredFlow)/2; % Calculate the accuracy of the predictions. [~,YTest] = max(Y,[],1); [~,YPred] = max(dlYPred,[],1); acc = gather(extractdata(sum(YTest == YPred)./numel(YTest))); % Calculate the accuracy of the RGB and flow predictions. [~,YTest] = max(Y,[],1); [~,YPredRGB] = max(dlYPredRGB,[],1); [~,YPredFlow] = max(dlYPredFlow,[],1); accRGB = gather(extractdata(sum(YTest == YPredRGB)./numel(YTest))); accFlow = gather(extractdata(sum(YTest == YPredFlow)./numel(YTest))); end
doValidation
The doValidation
Функция проверяет сеть с помощью данных валидации.
function [validationTime, cmat, lossValidation, accValidation, accValidationRGB, accValidationFlow] = doValidation(params, dlnetRGB, dlnetFlow) validationTime = tic; numOutputs = 3; mbq = createMiniBatchQueue(params.ValidationData, numOutputs, params); lossValidation = []; numClasses = numel(params.Classes); cmat = sparse(numClasses,numClasses); cmatRGB = sparse(numClasses,numClasses); cmatFlow = sparse(numClasses,numClasses); while hasdata(mbq) [dlX1,dlX2,dlY] = next(mbq); [loss,YTest,YPred,YPredRGB,YPredFlow] = predictValidation(dlnetRGB,dlnetFlow,dlX1,dlX2,dlY); lossValidation = [lossValidation,loss]; cmat = aggregateConfusionMetric(cmat,YTest,YPred); cmatRGB = aggregateConfusionMetric(cmatRGB,YTest,YPredRGB); cmatFlow = aggregateConfusionMetric(cmatFlow,YTest,YPredFlow); end lossValidation = mean(lossValidation); accValidation = sum(diag(cmat))./sum(cmat,"all"); accValidationRGB = sum(diag(cmatRGB))./sum(cmatRGB,"all"); accValidationFlow = sum(diag(cmatFlow))./sum(cmatFlow,"all"); validationTime = toc(validationTime); end
predictValidation
The predictValidation
функция вычисляет значения потерь и предсказаний с помощью предоставленной dlnetwork
объекты для RGB и данных оптического потока.
function [loss,YTest,YPred,YPredRGB,YPredFlow] = predictValidation(dlnetRGB,dlnetFlow,dlRGB,dlFlow,Y) % Pass the video input through the two-stream % network. dlYPredRGB = predict(dlnetRGB,dlRGB); dlYPredFlow = predict(dlnetFlow,dlFlow); % Calculate the cross-entropy separately for the two-stream % outputs. rgbLoss = crossentropy(dlYPredRGB,Y); flowLoss = crossentropy(dlYPredFlow,Y); % Fuse the losses. loss = mean([rgbLoss,flowLoss]); % Fuse the predictions by calculating the average of the predictions. dlYPred = (dlYPredRGB + dlYPredFlow)/2; % Calculate the accuracy of the predictions. [~,YTest] = max(Y,[],1); [~,YPred] = max(dlYPred,[],1); [~,YPredRGB] = max(dlYPredRGB,[],1); [~,YPredFlow] = max(dlYPredFlow,[],1); end
updateDlnetwork
The updateDlnetwork
функция обновляет предоставленные dlnetwork
объект с градиентами и другими параметрами, использующими оптимизационную функцию SGDM sgdmupdate
.
function [dlnet,gradients,velocity,learnRate] = updateDlNetwork(dlnet,gradients,params,velocity,iteration) % Determine the learning rate using the cosine-annealing learning rate schedule. learnRate = cosineAnnealingLearnRate(iteration, params); % Apply L2 regularization to the weights. idx = dlnet.Learnables.Parameter == "Weights"; gradients(idx,:) = dlupdate(@(g,w) g + params.L2Regularization*w, gradients(idx,:), dlnet.Learnables(idx,:)); % Update the network parameters using the SGDM optimizer. [dlnet, velocity] = sgdmupdate(dlnet, gradients, velocity, learnRate, params.Momentum); end
cosineAnnealingLearnRate
The cosineAnnealingLearnRate
функция вычисляет скорость обучения на основе текущего числа итераций, минимальной скорости обучения, максимальной скорости обучения и количества итераций для отжига [3].
function lr = cosineAnnealingLearnRate(iteration, params) if iteration == params.NumIterations lr = params.MinLearningRate; return; end cosineNumIter = [0, params.CosineNumIterations]; csum = cumsum(cosineNumIter); block = find(csum >= iteration, 1,'first'); cosineIter = iteration - csum(block - 1); annealingIteration = mod(cosineIter, cosineNumIter(block)); cosineIteration = cosineNumIter(block); minR = params.MinLearningRate; maxR = params.MaxLearningRate; cosMult = 1 + cos(pi * annealingIteration / cosineIteration); lr = minR + ((maxR - minR) * cosMult / 2); end
aggregateConfusionMetric
The aggregateConfusionMetric
функция пошагово заполняет матрицу неточностей на основе предсказанных результатов YPred
и ожидаемые результаты YTest
.
function cmat = aggregateConfusionMetric(cmat,YTest,YPred) YTest = gather(extractdata(YTest)); YPred = gather(extractdata(YPred)); [m,n] = size(cmat); cmat = cmat + full(sparse(YTest,YPred,1,m,n)); end
createMiniBatchQueue
The createMiniBatchQueue
функция создает minibatchqueue
объект, который обеспечивает miniBatchSize
объем данных из данного datastore. Это также создает DispatchInBackgroundDatastore
если параллельный пул открыт.
function mbq = createMiniBatchQueue(datastore, numOutputs, params) if params.DispatchInBackground && isempty(gcp('nocreate')) % Start a parallel pool, if DispatchInBackground is true, to dispatch % data in the background using the parallel pool. c = parcluster('local'); c.NumWorkers = params.NumWorkers; parpool('local',params.NumWorkers); end p = gcp('nocreate'); if ~isempty(p) datastore = DispatchInBackgroundDatastore(datastore, p.NumWorkers); end inputFormat(1:numOutputs-1) = "SSSCB"; outputFormat = "CB"; mbq = minibatchqueue(datastore, numOutputs, ... "MiniBatchSize", params.MiniBatchSize, ... "MiniBatchFcn", @batchRGBAndFlow, ... "MiniBatchFormat", [inputFormat,outputFormat]); end
batchRGBAndFlow
The batchRGBAndFlow
функция пакетирует изображение, поток и данные о метках в соответствующие dlarray
значения в форматах данных "SSSCB"
, "SSSCB"
, и "CB"
, соответственно.
function [dlX1,dlX2,dlY] = batchRGBAndFlow(images, flows, labels) % Batch dimension: 5 X1 = cat(5,images{:}); X2 = cat(5,flows{:}); % Batch dimension: 2 labels = cat(2,labels{:}); % Feature dimension: 1 Y = onehotencode(labels,1); % Cast data to single for processing. X1 = single(X1); X2 = single(X2); Y = single(Y); % Move data to the GPU if possible. if canUseGPU X1 = gpuArray(X1); X2 = gpuArray(X2); Y = gpuArray(Y); end % Return X and Y as dlarray objects. dlX1 = dlarray(X1,"SSSCB"); dlX2 = dlarray(X2,"SSSCB"); dlY = dlarray(Y,"CB"); end
shuffleTrainDs
The shuffleTrainDs
функция тасует файлы, существующие в обучающем datastore dsTrain
.
function shuffled = shuffleTrainDs(dsTrain) shuffled = copy(dsTrain); n = numel(shuffled.Files); shuffledIndices = randperm(n); shuffled.Files = shuffled.Files(shuffledIndices); reset(shuffled); end
saveData
The saveData
функция сохраняет заданную dlnetwork
объекты и значения точности в MAT файла.
function saveData(modelFilename, dlnetRGB, dlnetFlow, cmat, accValidation) dlnetRGB = gatherFromGPUToSave(dlnetRGB); dlnetFlow = gatherFromGPUToSave(dlnetFlow); data.ValidationAccuracy = accValidation; data.cmat = cmat; data.dlnetRGB = dlnetRGB; data.dlnetFlow = dlnetFlow; save(modelFilename, 'data'); end
gatherFromGPUToSave
The gatherFromGPUToSave
функция собирает данные из графического процессора в порядок, чтобы сохранить модель на диск.
function dlnet = gatherFromGPUToSave(dlnet) if ~canUseGPU return; end dlnet.Learnables = gatherValues(dlnet.Learnables); dlnet.State = gatherValues(dlnet.State); function tbl = gatherValues(tbl) for ii = 1:height(tbl) tbl.Value{ii} = gather(tbl.Value{ii}); end end end
checkForHMDB51Folder
The checkForHMDB51Folder
функция проверяет загруженные данные в папке загрузки.
function classes = checkForHMDB51Folder(dataLoc) hmdbFolder = fullfile(dataLoc, "hmdb51_org"); if ~exist(hmdbFolder, "dir") error("Download 'hmdb51_org.rar' file using the supporting function 'downloadHMDB51' before running the example and extract the RAR file."); end classes = ["brush_hair","cartwheel","catch","chew","clap","climb","climb_stairs",... "dive","draw_sword","dribble","drink","eat","fall_floor","fencing",... "flic_flac","golf","handstand","hit","hug","jump","kick","kick_ball",... "kiss","laugh","pick","pour","pullup","punch","push","pushup","ride_bike",... "ride_horse","run","shake_hands","shoot_ball","shoot_bow","shoot_gun",... "sit","situp","smile","smoke","somersault","stand","swing_baseball","sword",... "sword_exercise","talk","throw","turn","walk","wave"]; expectFolders = fullfile(hmdbFolder, classes); if ~all(arrayfun(@(x)exist(x,'dir'),expectFolders)) error("Download hmdb51_org.rar using the supporting function 'downloadHMDB51' before running the example and extract the RAR file."); end end
downloadHMDB51
The downloadHMDB51
функция загружает набор данных и сохраняет его в директорию.
function downloadHMDB51(dataLoc) if nargin == 0 dataLoc = pwd; end dataLoc = string(dataLoc); if ~exist(dataLoc,"dir") mkdir(dataLoc); end dataUrl = "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar"; options = weboptions('Timeout', Inf); rarFileName = fullfile(dataLoc, 'hmdb51_org.rar'); fileExists = exist(rarFileName, 'file'); % Download the RAR file and save it to the download folder. if ~fileExists disp("Downloading hmdb51_org.rar (2 GB) to the folder:") disp(dataLoc) disp("This download can take a few minutes...") websave(rarFileName, dataUrl, options); disp("Download complete.") disp("Extract the hmdb51_org.rar file contents to the folder: ") disp(dataLoc) end end
initializeTrainingProgressPlot
The initializeTrainingProgressPlot
функция конфигурирует два графика для отображения потерь обучения, точности обучения и точности валидации.
function plotters = initializeTrainingProgressPlot(params) if params.ProgressPlot % Plot the loss, training accuracy, and validation accuracy. figure % Loss plot subplot(2,1,1) plotters.LossPlotter = animatedline; xlabel("Iteration") ylabel("Loss") % Accuracy plot subplot(2,1,2) plotters.TrainAccPlotter = animatedline('Color','b'); plotters.ValAccPlotter = animatedline('Color','g'); legend('Training Accuracy','Validation Accuracy','Location','northwest'); xlabel("Iteration") ylabel("Accuracy") else plotters = []; end end
initializeVerboseOutput
The initializeVerboseOutput
Функция отображает заголовки столбца для таблицы обучающих значений, которая показывает эпоху, точность мини-пакетов и другие значения обучения.
function initializeVerboseOutput(params) if params.Verbose disp(" ") if canUseGPU disp("Training on GPU.") else disp("Training on CPU.") end p = gcp('nocreate'); if ~isempty(p) disp("Training on parallel cluster '" + p.Cluster.Profile + "'. ") end disp("NumIterations:" + string(params.NumIterations)); disp("MiniBatchSize:" + string(params.MiniBatchSize)); disp("Classes:" + join(string(params.Classes), ",")); disp("|=======================================================================================================================================================================|") disp("| Epoch | Iteration | Time Elapsed | Mini-Batch Accuracy | Validation Accuracy | Mini-Batch | Validation | Base Learning | Train Time | Validation Time |") disp("| | | (hh:mm:ss) | (Avg:RGB:Flow) | (Avg:RGB:Flow) | Loss | Loss | Rate | (hh:mm:ss) | (hh:mm:ss) |") disp("|=======================================================================================================================================================================|") end end
displayVerboseOutputEveryEpoch
The displayVerboseOutputEveryEpoch
функция отображает подробные выходы обучающих значений, таких как эпоха, точность мини-пакета, точность валидации и потеря мини-пакета.
function displayVerboseOutputEveryEpoch(params,start,learnRate,epoch,iteration,... accTrain,accTrainRGB,accTrainFlow,accValidation,accValidationRGB,accValidationFlow,lossTrain,lossValidation,trainTime,validationTime) if params.Verbose D = duration(0,0,toc(start),'Format','hh:mm:ss'); trainTime = duration(0,0,trainTime,'Format','hh:mm:ss'); validationTime = duration(0,0,validationTime,'Format','hh:mm:ss'); lossValidation = gather(extractdata(lossValidation)); lossValidation = compose('%.4f',lossValidation); accValidation = composePadAccuracy(accValidation); accValidationRGB = composePadAccuracy(accValidationRGB); accValidationFlow = composePadAccuracy(accValidationFlow); accVal = join([accValidation,accValidationRGB,accValidationFlow], " : "); lossTrain = gather(extractdata(lossTrain)); lossTrain = compose('%.4f',lossTrain); accTrain = composePadAccuracy(accTrain); accTrainRGB = composePadAccuracy(accTrainRGB); accTrainFlow = composePadAccuracy(accTrainFlow); accTrain = join([accTrain,accTrainRGB,accTrainFlow], " : "); learnRate = compose('%.13f',learnRate); disp("| " + ... pad(string(epoch),5,'both') + " | " + ... pad(string(iteration),9,'both') + " | " + ... pad(string(D),12,'both') + " | " + ... pad(string(accTrain),26,'both') + " | " + ... pad(string(accVal),26,'both') + " | " + ... pad(string(lossTrain),10,'both') + " | " + ... pad(string(lossValidation),10,'both') + " | " + ... pad(string(learnRate),13,'both') + " | " + ... pad(string(trainTime),10,'both') + " | " + ... pad(string(validationTime),15,'both') + " |") end end function acc = composePadAccuracy(acc) acc = compose('%.2f',acc*100) + "%"; acc = pad(string(acc),6,'left'); end
endVerboseOutput
The endVerboseOutput
функция отображает конец подробного выхода во время обучения.
function endVerboseOutput(params) if params.Verbose disp("|=======================================================================================================================================================================|") end end
updateProgressPlot
The updateProgressPlot
функция обновляет график прогресса с информацией о потерях и точности во время обучения.
function updateProgressPlot(params,plotters,epoch,iteration,start,lossTrain,accuracyTrain,accuracyValidation) if params.ProgressPlot % Update the training progress. D = duration(0,0,toc(start),"Format","hh:mm:ss"); title(plotters.LossPlotter.Parent,"Epoch: " + epoch + ", Elapsed: " + string(D)); addpoints(plotters.LossPlotter,iteration,double(gather(extractdata(lossTrain)))); addpoints(plotters.TrainAccPlotter,iteration,accuracyTrain); addpoints(plotters.ValAccPlotter,iteration,accuracyValidation); drawnow end end
[1] Каррейра, Жоао и Эндрю Зиссерман. "Кво Вадис, признание действий? Новая модель и набор данных кинетики ". Материалы Конференции IEEE по компьютерному зрению и распознаванию шаблонов (CVPR): 6299?? 6308. Гонолулу, HI: IEEE, 2017.
[2] Симоньян, Карен и Эндрю Зиссерман. Двухпоточные сверточные сети для распознавания действий в видео. Усовершенствования в системах обработки нейронной информации 27, Лонг-Бич, Калифорния: NIPS, 2017.
[3] Лощилов, Илья и Фрэнк Хаттер. SGDR: Стохастический градиентный спуск с теплыми перезапусками. Международная конференция по учебным представлениям 2017. Тулон, Франция: ICLR, 2017.