Прогноз глубокого обучения с Intel MKL-DNN

Этот пример показывает, как использовать codegen, чтобы сгенерировать код для приложения классификации изображений, которое использует глубокое обучение на процессорах Intel®. Сгенерированный код использует в своих интересах Math Kernel Library Intel для Глубоких нейронных сетей (MKL-DNN). Во-первых, пример генерирует MEX-функцию, которая запускает прогноз при помощи сети классификации изображений ResNet-50. Затем пример создает статическую библиотеку и компилирует ее с основным файлом, который запускает прогноз с помощью сети классификации изображений ResNet-50.

Предпосылки

  • Процессор Xeon с поддержкой Intel Усовершенствованные Векторные Расширения 2 (Intel AVX2) инструкции

  • Intel Math Kernel Library для глубоких нейронных сетей (MKL-DNN)

  • Библиотека Компьютерного зрения С открытым исходным кодом (OpenCV) v3.1

  • Переменные окружения для Intel MKL-DNN и OpenCV

  • MATLAB® Coder™, для генерации Кода С++.

  • Интерфейс MATLAB Coder пакета поддержки для Глубокого обучения.

  • Deep Learning Toolbox™, для использования объекта DAGNetwork

  • Пакет Поддержки Модель Deep Learning Toolbox для пакета Сетевой поддержки ResNet-50, для использования предварительно обученной сети ResNet.

Для получения дополнительной информации смотрите Предпосылки для Глубокого обучения для MATLAB Coder.

Этот пример поддерживается на платформах Windows® и Linux®.

Функция resnet_predict

Этот пример использует сеть DAG ResNet-50, чтобы показать классификацию изображений с MKL-DNN. Предварительно обученная модель ResNet-50 для MATLAB доступна в пакете поддержки Модель Deep Learning Toolbox для Сети ResNet-50. Чтобы загрузить и установить пакет поддержки, используйте Add-On Explorer. Смотрите Получают Дополнения (MATLAB).

Функция resnet_predict загружает сеть ResNet-50 в персистентный сетевой объект. На последующих вызовах функции снова используется постоянный объект.

type resnet_predict
% Copyright 2018 The MathWorks, Inc.

function out = resnet_predict(in) 
%#codegen

% A persistent object mynet is used to load the series network object.
% At the first call to this function, the persistent object is constructed and
% setup. When the function is called subsequent times, the same object is reused 
% to call predict on inputs, avoiding reconstructing and reloading the
% network object.

persistent mynet;

if isempty(mynet)
    % Call the function resnet50 that returns a DAG network
    % for ResNet-50 model.
    mynet = coder.loadDeepLearningNetwork('resnet50','resnet');
end

% pass in input   
out = mynet.predict(in);

Сгенерируйте код MEX для функции resnet_predict

Чтобы сгенерировать MEX-функцию от функции resnet_predict.m, используйте codegen с объектом настройки глубокого обучения, созданным для библиотеки MKL-DNN. Присоедините объект настройки глубокого обучения к объекту настройки генерации кода MEX, что вы передаете codegen.

 cfg = coder.config('mex');
 cfg.TargetLang = 'C++';
 cfg.DeepLearningConfig = coder.DeepLearningConfig('mkldnn');
 codegen -config cfg resnet_predict -args {ones(224,224,3,'single')} -report
Code generation successful: To view the report, open('codegen\mex\resnet_predict\html\report.mldatx').

Вызовите predict на тестовом изображении

im = imread('peppers.png');
im = imresize(im, [224,224]);
imshow(im);
predict_scores = resnet_predict_mex(single(im));

Сопоставьте лучшие пять очков прогноза со словами в synset словаре.

fid = fopen('synsetWords.txt');
synsetOut = textscan(fid,'%s', 'delimiter', '\n');
synsetOut = synsetOut{1};
fclose(fid);
[val,indx] = sort(predict_scores, 'descend');
scores = val(1:5)*100;
top5labels = synsetOut(indx(1:5));

Отобразите лучшие пять меток классификации на изображении.

outputImage = zeros(224,400,3, 'uint8');
for k = 1:3
    outputImage(:,177:end,k) = im(:,:,k);
end

scol = 1;
srow = 1;
outputImage = insertText(outputImage, [scol, srow], 'Classification with ResNet-50', 'TextColor', 'w','FontSize',20, 'BoxColor', 'black');
srow = srow + 30;
for k = 1:5
    outputImage = insertText(outputImage, [scol, srow], [top5labels{k},' ',num2str(scores(k), '%2.2f'),'%'], 'TextColor', 'w','FontSize',15, 'BoxColor', 'black');
    srow = srow + 25;
end

imshow(outputImage);

Очистите статический сетевой объект из памяти.

clear mex;

Сгенерируйте статическую библиотеку для функции resnet_predict

Чтобы сгенерировать статическую библиотеку от функции resnet_predict.m, используйте codegen с объектом настройки глубокого обучения, созданным для библиотеки MKL-DNN. Присоедините объект настройки глубокого обучения к объекту настройки генерации кода, что вы передаете codegen.

cfg = coder.config('lib');
cfg.TargetLang = 'C++';
cfg.DeepLearningConfig = coder.DeepLearningConfig('mkldnn');
codegen -config cfg resnet_predict -args {ones(224,224,3,'single')} -report
%
codegendir = fullfile(pwd, 'codegen', 'lib', 'resnet_predict');
Code generation successful: To view the report, open('codegen\lib\resnet_predict\html\report.mldatx').

Файл main_resnet.cpp

Основной файл используется, чтобы сгенерировать исполняемый файл от статической библиотеки, созданной командой codegen. Основной файл читает входное изображение, прогноз выполнений на изображении, и отображает метки классификации на изображении.

type main_resnet.cpp
/* Copyright 2018 The MathWorks, Inc. */

#include "resnet_predict.h"

#include <stdio.h>
#include <string.h>
#include <math.h>
#include <iostream>
#include "opencv2/opencv.hpp"
using namespace cv;

int readData(void* inputBuffer, char* inputImage) {

    Mat inpImage, intermImage;
    inpImage = imread(inputImage, 1);
    Size size(224, 224);
    resize(inpImage, intermImage, size);
    if (!intermImage.data) {
        printf(" No image data \n ");
        exit(1);
    }
    float* input = (float*)inputBuffer;

    for (int j = 0; j < 224 * 224; j++) {
        // BGR to RGB
        input[2 * 224 * 224 + j] = (float)(intermImage.data[j * 3 + 0]);
        input[1 * 224 * 224 + j] = (float)(intermImage.data[j * 3 + 1]);
        input[0 * 224 * 224 + j] = (float)(intermImage.data[j * 3 + 2]);
    }
    return 1;
}

#if defined(WIN32) || defined(_WIN32) || defined(__WIN32) || defined(_WIN64)

int cmpfunc(void* r, const void* a, const void* b) {
    float x = ((float*)r)[*(int*)b] - ((float*)r)[*(int*)a];
    return (x > 0 ? ceil(x) : floor(x));
}
#else

int cmpfunc(const void* a, const void* b, void* r) {
    float x = ((float*)r)[*(int*)b] - ((float*)r)[*(int*)a];
    return (x > 0 ? ceil(x) : floor(x));
}

#endif

void top(float* r, int* top5) {
    int t[1000];
    for (int i = 0; i < 1000; i++) {
        t[i] = i;
    }
#if defined(WIN32) || defined(_WIN32) || defined(__WIN32) || defined(_WIN64)
    qsort_s(t, 1000, sizeof(int), cmpfunc, r);
#else
    qsort_r(t, 1000, sizeof(int), cmpfunc, r);
#endif
    top5[0] = t[0];
    top5[1] = t[1];
    top5[2] = t[2];
    top5[3] = t[3];
    top5[4] = t[4];
    return;
}


int prepareSynset(char synsets[1000][100]) {
    FILE* fp1 = fopen("synsetWords.txt", "r");
    if (fp1 == 0) {
        return -1;
    }

    for (int i = 0; i < 1000; i++) {
        if (fgets(synsets[i], 100, fp1) != NULL)
            ;
        strtok(synsets[i], "\n");
    }
    fclose(fp1);
    return 0;
}

void writeData(float* output, char synsetWords[1000][100], Mat &frame) {
    int top5[5], j;
    
    top(output, top5);
    
    copyMakeBorder(frame, frame, 0, 0, 400, 0, BORDER_CONSTANT, CV_RGB(0,0,0));
    char strbuf[50];
    sprintf(strbuf, "%4.1f%% %s", output[top5[0]]*100, synsetWords[top5[0]]);
    putText(frame, strbuf, cvPoint(30,80), CV_FONT_HERSHEY_DUPLEX, 1.0, CV_RGB(220,220,220), 1);
    sprintf(strbuf, "%4.1f%% %s", output[top5[1]]*100, synsetWords[top5[1]]);
    putText(frame, strbuf, cvPoint(30,130), CV_FONT_HERSHEY_DUPLEX, 1.0, CV_RGB(220,220,220), 1);
    sprintf(strbuf, "%4.1f%% %s", output[top5[2]]*100, synsetWords[top5[2]]);
    putText(frame, strbuf, cvPoint(30,180), CV_FONT_HERSHEY_DUPLEX, 1.0, CV_RGB(220,220,220), 1);
    sprintf(strbuf, "%4.1f%% %s", output[top5[3]]*100, synsetWords[top5[3]]);
    putText(frame, strbuf, cvPoint(30,230), CV_FONT_HERSHEY_DUPLEX, 1.0, CV_RGB(220,220,220), 1);
    sprintf(strbuf, "%4.1f%% %s", output[top5[4]]*100, synsetWords[top5[4]]);
    putText(frame, strbuf, cvPoint(30,280), CV_FONT_HERSHEY_DUPLEX, 1.0, CV_RGB(220,220,220), 1);

}

// Main function
int main(int argc, char* argv[]) {
    int n = 1;
    char synsetWords[1000][100];
    
    namedWindow("Classification with ResNet-50",CV_WINDOW_NORMAL);
    resizeWindow("Classification with ResNet-50",440,224);

    Mat im;
    im = imread(argv[1], 1);
    
    float* ipfBuffer = (float*)calloc(sizeof(float), 224*224*3);
    
    float* opBuffer = (float*)calloc(sizeof(float), 1000);
    if (argc != 2) {
        printf("Input image missing \nSample Usage-./resnet_exe image.png\n");
        exit(1);
    }
    if (prepareSynset(synsetWords) == -1) {
        printf("ERROR: Unable to find synsetWords.txt\n");
        return -1;
    }

    //read input imaget to the ipfBuffer
    readData(ipfBuffer, argv[1]);
    
    //run prediction on image stored in ipfBuffer
    resnet_predict(ipfBuffer, opBuffer);
    
    //write predictions on input image
    writeData(opBuffer, synsetWords, im);

    //show predictions on input image
    imshow("Classification with ResNet-50", im);
    waitKey(5000);
    destroyWindow("Classification with ResNet-50");
    return 0;
}

Создание и запуск исполняемого файла

Создайте исполняемый файл на основе целевой платформы. На платформе Windows этот пример использует Microsoft® Visual Studio® 2017 для C++.

if ispc
    setenv('MATLAB_ROOT', matlabroot);
    system('make_mkldnn_win17.bat');
    system('resnet.exe peppers.png');
else
    setenv('MATLAB_ROOT', matlabroot);
    system('make -f Makefile_mkldnn_linux.mk');
    system('./resnet_exe peppers.png');
end

Результаты MEX-функции не могут совпадать с результатами сгенерированной статической библиотечной функции из-за различий в версии библиотеки, которой пользуются, чтобы считать входной файл изображения. Изображение, которое передается MEX-функции, читается с помощью версии, которую поставляет MATLAB. Изображение, которое передается статической библиотечной функции, читается с помощью версии, которую использует OpenCV.

Смотрите также

| | |

Похожие темы