Сгенерируйте генеративный код C/C + + для регрессии «последовательность-последовательность», которая использует глубокое обучение

Этот пример демонстрирует, как сгенерировать простой код C/C + +, который не зависит ни от каких сторонних библиотек глубокого обучения для сети долгой краткосрочной памяти (LSTM). Вы генерируете MEX-функцию, которая принимает данные временных рядов, представляющие различные датчики в двигателе. Затем MEX-функция делает предсказания для каждого шага входа timeseries, чтобы предсказать оставшийся срок службы (RUL) двигателя, измеренный в циклах.

Этот пример использует набор данных моделирования деградации Engine Turbofan, как описано в [1], и предварительно обученную сеть LSTM, чтобы предсказать оставшийся срок полезного использования двигателя. Сеть обучалась на моделируемых данных временных рядов для 100 двигателей и соответствующих значений оставшегося срока службы в конце каждой последовательности. Каждая последовательность в этих обучающих данных имеет разную длину и соответствует полному прогону в отказ (RTF) образца. Для получения дополнительной информации о обучении сети смотрите пример Регрессия последовательность-в-последовательности с использованием глубокого обучения.

Задайте функцию точки входа rulPredict

The rulPredict функция точки входа принимает входную последовательность и передает ее в обученную сеть LSTM последовательности в последовательности для предсказания. Функция загружает сетевой объект из rulNetwork.mat файл в постоянную переменную и повторно использует постоянный объект при последующих вызовах предсказания. Сеть LSTM делает предсказания о частичной последовательности один временной шаг за раз. На каждом временном шаге сеть прогнозирует использование значения на этом временном шаге и состояние сети, рассчитанное только с предыдущих временных шагов. Сеть обновляет свое состояние между каждым предсказанием. The predict функция возвращает последовательность этих предсказаний. Последний элемент предсказания соответствует предсказанному RUL для частичной последовательности.

Чтобы отобразить интерактивную визуализацию сетевой архитектуры и информацию о слоях сети, используйте analyzeNetwork функция.

type rulPredict.m
function out = rulPredict(in)
%#codegen

% Copyright 2020 The MathWorks, Inc. 

persistent mynet;

if isempty(mynet)
    mynet = coder.loadDeepLearningNetwork('rulNetwork.mat');
end

% pass in input to predict method
% To prevent the function from adding padding to the data, specify the mini-batch size 1. 
out = predict(mynet,in,'MiniBatchSize',1);

Выполняйте rulPredict по тестовым данным

Загрузите TurboFanRULValidate MAT-файл. Этот MAT-файл хранит переменную XValidate который содержит выборку данных timeseries для показаний датчика, которые вы используете для тестирования функции точки входа в MATLAB. Делайте предсказания на тестовых данных, вызывая rulPredict способ.

load TurboFanRULValidate.mat
YPred = rulPredict(XValidate);

Визуализируйте некоторые предсказания на графике.

idx = randperm(numel(YPred),4);
figure
for i = 1:numel(idx)
    subplot(2,2,i)
    
    plot(YValidate{idx(i)},'--')
    hold on
    plot(YPred{idx(i)},'.-')
    hold off
    
    ylim([0 175])
    title("Test Observation " + idx(i))
    xlabel("Time Step")
    ylabel("RUL")
end
legend(["Test Data" "Predicted"],'Location','southeast')

Figure contains 4 axes. Axes 1 with title Test Observation 82 contains 2 objects of type line. Axes 2 with title Test Observation 90 contains 2 objects of type line. Axes 3 with title Test Observation 13 contains 2 objects of type line. Axes 4 with title Test Observation 89 contains 2 objects of type line. These objects represent Test Data, Predicted.

Для заданной частичной последовательности предсказанный текущий RUL является последним элементом предсказанных последовательностей. Вычислите среднеквадратичную ошибку (RMSE) предсказаний и визуализируйте ошибку предсказания в гистограмме.

YValidateLast = zeros(1, numel(YValidate));
YPredLast = zeros(1, numel(YValidate));
for i = 1:numel(YValidate)
    YValidateLast(i) = YValidate{i}(end);
    YPredLast(i) = YPred{i}(end);
end
figure
rmse = sqrt(mean((YPredLast - YValidateLast).^2))
rmse = 19.0286
histogram(YPredLast - YValidateLast)
title("RMSE = " + rmse)
ylabel("Frequency")
xlabel("Error")

Figure contains an axes. The axes with title RMSE = 19.0286 contains an object of type histogram.

Сгенерируйте MEX-функцию для rulPredict

Чтобы сгенерировать MEX-функцию для rulPredict функция точки входа, создайте объект строения генерации кода cfg для генерации кода MEX. Создайте объект строения глубокого обучения, который задает, что никакая целевая библиотека не требуется, и присоедините этот объект строения глубокого обучения к cfg.

cfg = coder.config('mex');
cfg.DeepLearningConfig = coder.DeepLearningConfig('TargetLibrary','none');

По умолчанию для целевого языка задано значение C. Если вы хотите сгенерировать код С++, явно установите целевой язык на C++.

Используйте coder.typeof функция для создания входного типа для функции точки входа rulPredict которые вы используете с -args опция в codegen команда.

Область данных XValidate содержит 100 наблюдений, где каждое наблюдение имеет двойной тип данных со значением размерности признаков 17 и переменную длину последовательности. В порядок выполнения предсказания нескольких таких наблюдений в одном вызове функции можно сгруппировать наблюдения вместе в массиве ячеек и передать массив ячеек для предсказания. Массив ячеек должен быть столбцом массивом ячеек, и каждая камера должен содержать одно наблюдение. Каждое наблюдение должно иметь одну и ту же размерность признаков, но длины последовательности могут варьироваться, как в случае XValidate. Определение длины последовательности как переменного размера позволяет нам выполнить предсказание на вход последовательности любой длины.

matrixInput = coder.typeof(0, [17 Inf],[false true]); % input type for a single observation
cellInput = coder.typeof({matrixInput}, [100 1]); % input type for multiple observations 

Выполните команду codegen. Задайте входной тип, который будет cellInput.

codegen -config cfg rulPredict -args {cellInput} -report
Code generation successful: To view the report, open('codegen/mex/rulPredict/html/report.mldatx').

По умолчанию для генерации кода MEX сгенерированный код вызывает библиотеку BLAS для матричных операций и использует библиотеку OpenMP (если компилятор поддерживает OpenMP), так что любой параллелизируемый для циклов в MEX может запускаться в нескольких потоках, что приводит к лучшей эффективности выполнения. В то время как OpenMP включен по умолчанию для генерации автономного кода, вы должны будете предоставить пользовательский коллбэк BLAS, чтобы указать ™ MATLAB Coder, что вы хотите сгенерировать вызовы BLAS для матричных операций, следуя шагам, упомянутым в Операции Ускорения Матрицы в Сгенерированном автономном коде при помощи вызовов BLAS AS (MATLAT)

Запуск сгенерированной MEX-функции на тестовых данных

Делайте предсказания на тестовых данных, вызывая сгенерированную MEX-функцию rulPredict_mex.

YPredMex = rulPredict_mex(XValidate);

Можно визуализировать на графике те же предсказания, что и прежде.

figure
for i = 1:numel(idx)
    subplot(2,2,i)
    
    plot(YValidate{idx(i)},'--')
    hold on
    plot(YPredMex{idx(i)},'.-')
    hold off
    
    ylim([0 175])
    title("Test Observation " + idx(i))
    xlabel("Time Step")
    ylabel("RUL")
end
legend(["Test Data" "Predicted MEX"],'Location','southeast')

Figure contains 4 axes. Axes 1 with title Test Observation 82 contains 2 objects of type line. Axes 2 with title Test Observation 90 contains 2 objects of type line. Axes 3 with title Test Observation 13 contains 2 objects of type line. Axes 4 with title Test Observation 89 contains 2 objects of type line. These objects represent Test Data, Predicted MEX.

Вычислите среднеквадратичную ошибку (RMSE) предсказаний и визуализируйте ошибку предсказания в гистограмме.

YPredLastMex = zeros(1, numel(YValidate));
for i = 1:numel(YValidate)
    YPredLastMex(i) = YPredMex{i}(end);
end
figure
rmse = sqrt(mean((YPredLastMex - YValidateLast).^2))
rmse = 19.0286
histogram(YPredLastMex - YValidateLast)
title("RMSE = " + rmse)
ylabel("Frequency")
xlabel("Error")

Figure contains an axes. The axes with title RMSE = 19.0286 contains an object of type histogram.

Сгенерируйте MEX-функцию с Stateful LSTM

Вместо того, чтобы передавать все timeseries, чтобы предсказать за один шаг, можно сделать предсказания по одному временному шагу за раз при помощи predictAndUpdateState. Это полезно при наличии значений временных шагов, поступающих в поток. The predictAndUpdateState функция принимает вход, создает выход предсказание и обновляет внутреннее состояние сети, так что будущие предсказания учитывают этот начальный вход. Обычно быстрее делать предсказания на полных последовательностях по сравнению с тем, чтобы делать предсказания по одному временному шагу за раз.

Функция точки входа rulPredictAndUpdate принимает входной вход в одно время и обрабатывает вход с помощью predictAndUpdateState функция. predictAndUpdateState выводит предсказание для входа временного интервала и обновляет сеть так, чтобы последующие входы рассматривались как последующие временные интервалы той же выборки. После прохождения во всех timesteps по одному за раз, результат выхода такой же, как если бы все timesteps были переданы как один вход.

type rulPredictAndUpdate.m
function out = rulPredictAndUpdate(in)
%#codegen

% Copyright 2020 The MathWorks, Inc. 

persistent mynet;

if isempty(mynet)
    mynet = coder.loadDeepLearningNetwork('rulNetwork.mat');
end

% pass in input to predictAndUpdateState method
[mynet, out] = predictAndUpdateState(mynet, in);

Запустите codegen для этой новой функции точки входа. Поскольку мы принимаем в одно время каждый вызов, мы задаем matrixInput иметь фиксированную размерность последовательности 1 вместо переменной длины последовательности.

matrixInput = coder.typeof(double(0),[17 1]);
codegen -config cfg rulPredictAndUpdate -args {matrixInput} -report
Code generation successful: To view the report, open('codegen/mex/rulPredictAndUpdate/html/report.mldatx').

Делайте предсказания на тестовых данных, вызывая rulPredictAndUpdate функция в MATLAB an d theсгенерированные MEX-функции rulPredictAndUpdate_mex.

YPredStatefulMex = cell(numel(idx), 1);
for iSample = 1:numel(idx)
    sample = XValidate{idx(iSample)};
    numTimeStepsTest = size(sample, 2);
    for iStep = 1:numTimeStepsTest
        YPredStatefulMex{iSample}(1, iStep) = rulPredictAndUpdate_mex(sample(:, iStep));
    end
end

Еще раз можно визуализировать предсказания для статического MEX как раньше на графике.

figure
for i = 1:numel(idx)
    subplot(2,2,i)
    
    plot(YValidate{idx(i)},'--')
    hold on
    plot(YPredStatefulMex{i},'.-')
    hold off
    
    ylim([0 175])
    title("Test Observation " + idx(i))
    xlabel("Time Step")
    ylabel("RUL")
end
legend(["Test Data" "Predicted MEX Stateful LSTM"],'Location','southeast')

Figure contains 4 axes. Axes 1 with title Test Observation 82 contains 2 objects of type line. Axes 2 with title Test Observation 90 contains 2 objects of type line. Axes 3 with title Test Observation 13 contains 2 objects of type line. Axes 4 with title Test Observation 89 contains 2 objects of type line. These objects represent Test Data, Predicted MEX Stateful LSTM.

Наконец, можно также визуализировать результаты для двух различных MEX-функций наряду с предсказанием MATLAB на графике для любой конкретной выборки.

figure()
sampleIdx = idx(1);
plot(YValidate{sampleIdx},'--')
hold on
plot(YPred{sampleIdx},'o-')
plot(YPredMex{sampleIdx},'^-')
plot(YPredStatefulMex{1},'x-')
hold off

ylim([0 175])
title("Test Observation " + idx(i))
xlabel("Time Step")
ylabel("RUL")
legend(["Test Data" "Predicted in MATLAB" "Predicted MEX" "Predicted MEX with Stateful LSTM"],'Location','southeast')

Figure contains an axes. The axes with title Test Observation 89 contains 4 objects of type line. These objects represent Test Data, Predicted in MATLAB, Predicted MEX, Predicted MEX with Stateful LSTM.

Ссылки

  1. Саксена, Абхинав, Кай Гебель, Дон Симон и Нил Эклунд. Моделирование распространения повреждений для симуляции пробега двигателя самолета до отказа. В прогнозах и управлении здоровьем, 2008. PHM 2008. Международная конференция, стр. 1-9. IEEE, 2008.

См. также

(MATLAB Coder) | (MATLAB Coder) | (MATLAB Coder)

Похожие темы