exponenta event banner

Создание общего кода C/C + + для регрессии последовательности, использующей глубокое обучение

В этом примере показано, как создать простой код C/C + +, который не зависит от каких-либо сторонних библиотек глубокого обучения для сети долговременной краткосрочной памяти (LSTM). Создается функция MEX, которая принимает данные временных рядов, представляющие различные датчики в механизме. Затем функция MEX делает прогнозы для каждого шага входных временных рядов для прогнозирования оставшегося срока службы (RUL) двигателя, измеренного в циклах.

В этом примере используется набор данных моделирования деградации турбовентиляторного двигателя, как описано в [1], и предварительно обученная сеть LSTM для прогнозирования оставшегося срока службы двигателя. Сеть была обучена моделируемым данным последовательности временных рядов для 100 двигателей и соответствующим значениям оставшегося срока службы в конце каждой последовательности. Каждая последовательность в этих учебных данных имеет различную длину и соответствует экземпляру полного прогона до отказа (RTF). Дополнительные сведения об обучении сети см. в примере Регрессия последовательности с использованием глубокого обучения (панель инструментов глубокого обучения).

Определение функции точки входа rulPredict

rulPredict функция точки входа принимает входную последовательность и передает ее в обученную сеть LSTM последовательности к последовательности для прогнозирования. Функция загружает сетевой объект из rulNetwork.mat в постоянную переменную и повторно использует постоянный объект при последующих вызовах прогнозирования. Сеть LSTM делает прогнозы по частичной последовательности один шаг за один раз. На каждом временном шаге сеть прогнозирует использование значения на этом временном шаге и состояние сети, вычисленное только из предыдущих временных шагов. Сеть обновляет свое состояние между каждым предсказанием. predict функция возвращает последовательность этих прогнозов. Последний элемент предсказания соответствует предсказанному RUL для частичной последовательности.

Для отображения интерактивной визуализации сетевой архитектуры и информации о сетевых уровнях используйте analyzeNetwork (Deep Learning Toolbox).

type rulPredict.m
function out = rulPredict(in)
%#codegen

% Copyright 2020 The MathWorks, Inc. 

persistent mynet;

if isempty(mynet)
    mynet = coder.loadDeepLearningNetwork('rulNetwork.mat');
end

% pass in input to predict method
% To prevent the function from adding padding to the data, specify the mini-batch size 1. 
out = predict(mynet,in,'MiniBatchSize',1);

Управляемый rulPredict на тестовых данных

Загрузить TurboFanRULValidate MAT-файл. В этом MAT-файле хранится переменная XValidate содержит образцы данных временных интервалов для показаний датчиков, которые используются для тестирования функции точки входа в MATLAB. Выполните прогнозы в отношении тестовых данных путем вызова rulPredict способ.

load TurboFanRULValidate.mat
YPred = rulPredict(XValidate);

Визуализация некоторых предсказаний на графике.

idx = randperm(numel(YPred),4);
figure
for i = 1:numel(idx)
    subplot(2,2,i)
    
    plot(YValidate{idx(i)},'--')
    hold on
    plot(YPred{idx(i)},'.-')
    hold off
    
    ylim([0 175])
    title("Test Observation " + idx(i))
    xlabel("Time Step")
    ylabel("RUL")
end
legend(["Test Data" "Predicted"],'Location','southeast')

Figure contains 4 axes. Axes 1 with title Test Observation 82 contains 2 objects of type line. Axes 2 with title Test Observation 90 contains 2 objects of type line. Axes 3 with title Test Observation 13 contains 2 objects of type line. Axes 4 with title Test Observation 89 contains 2 objects of type line. These objects represent Test Data, Predicted.

Для данной частичной последовательности предсказанный текущий RUL является последним элементом предсказанных последовательностей. Вычислите среднеквадратическую ошибку (RMSE) предсказаний и визуализируйте ошибку предсказания в гистограмме.

YValidateLast = zeros(1, numel(YValidate));
YPredLast = zeros(1, numel(YValidate));
for i = 1:numel(YValidate)
    YValidateLast(i) = YValidate{i}(end);
    YPredLast(i) = YPred{i}(end);
end
figure
rmse = sqrt(mean((YPredLast - YValidateLast).^2))
rmse = 19.0286
histogram(YPredLast - YValidateLast)
title("RMSE = " + rmse)
ylabel("Frequency")
xlabel("Error")

Figure contains an axes. The axes with title RMSE = 19.0286 contains an object of type histogram.

Создание функции MEX для rulPredict

Создание функции MEX для rulPredict функция точки входа, создание объекта конфигурации генерации кода cfg для генерации кода MEX. Создайте объект конфигурации глубокого обучения, который указывает, что целевая библиотека не требуется, и присоедините этот объект конфигурации глубокого обучения к cfg.

cfg = coder.config('mex');
cfg.DeepLearningConfig = coder.DeepLearningConfig('TargetLibrary','none');

По умолчанию для целевого языка установлено значение C. Если требуется создать код C++, явно задайте для целевого языка значение C++.

Используйте coder.typeof для создания типа ввода для функции точки входа rulPredict которые вы используете с -args опции в codegen команда.

Данные XValidate содержит 100 наблюдений, где каждое наблюдение имеет двойной тип данных со значением размера элемента 17 и переменной длиной последовательности. Для выполнения прогнозирования нескольких таких наблюдений в одном вызове функции можно сгруппировать наблюдения в массив ячеек и передать массив ячеек для прогнозирования. Массив ячеек должен быть массивом ячеек столбцов, и каждая ячейка должна содержать одно наблюдение. Каждое наблюдение должно иметь один и тот же размер элемента, но длина последовательности может изменяться, как в случае XValidate. Определение длины последовательности как переменной величины позволяет выполнять прогнозирование для входной последовательности любой длины.

matrixInput = coder.typeof(0, [17 Inf],[false true]); % input type for a single observation
cellInput = coder.typeof({matrixInput}, [100 1]); % input type for multiple observations 

Выполните команду codegen. Укажите тип ввода для cellInput.

codegen -config cfg rulPredict -args {cellInput} -report
Code generation successful: To view the report, open('codegen/mex/rulPredict/html/report.mldatx').

По умолчанию для генерации кода MEX сгенерированный код вызывает библиотеку BLAS для матричных операций и использует библиотеку OpenMP (если компилятор поддерживает OpenMP), так что любой параллелизуемый для циклов в MEX может выполняться на нескольких потоках, что приводит к лучшей производительности выполнения. Хотя OpenMP включен по умолчанию для создания автономного кода, необходимо предоставить пользовательский обратный вызов BLAS, чтобы указать ™ кодера MATLAB, что требуется генерировать вызовы BLAS для операций матрицы, следуя шагам, упомянутым в разделе Ускорение операций матрицы в сгенерированном автономном коде с помощью вызовов BLAS.

Выполнение сгенерированной функции MEX для тестовых данных

Прогнозирование тестовых данных путем вызова сгенерированной функции MEX rulPredict_mex.

YPredMex = rulPredict_mex(XValidate);

Можно визуализировать те же прогнозы, что и ранее на графике.

figure
for i = 1:numel(idx)
    subplot(2,2,i)
    
    plot(YValidate{idx(i)},'--')
    hold on
    plot(YPredMex{idx(i)},'.-')
    hold off
    
    ylim([0 175])
    title("Test Observation " + idx(i))
    xlabel("Time Step")
    ylabel("RUL")
end
legend(["Test Data" "Predicted MEX"],'Location','southeast')

Figure contains 4 axes. Axes 1 with title Test Observation 82 contains 2 objects of type line. Axes 2 with title Test Observation 90 contains 2 objects of type line. Axes 3 with title Test Observation 13 contains 2 objects of type line. Axes 4 with title Test Observation 89 contains 2 objects of type line. These objects represent Test Data, Predicted MEX.

Вычислите среднеквадратическую ошибку (RMSE) предсказаний и визуализируйте ошибку предсказания в гистограмме.

YPredLastMex = zeros(1, numel(YValidate));
for i = 1:numel(YValidate)
    YPredLastMex(i) = YPredMex{i}(end);
end
figure
rmse = sqrt(mean((YPredLastMex - YValidateLast).^2))
rmse = 19.0286
histogram(YPredLastMex - YValidateLast)
title("RMSE = " + rmse)
ylabel("Frequency")
xlabel("Error")

Figure contains an axes. The axes with title RMSE = 19.0286 contains an object of type histogram.

Создание функции MEX с помощью LSTM с учетом состояния

Вместо передачи всех временных рядов для прогнозирования в один шаг, можно сделать прогнозы один шаг за один раз с помощью predictAndUpdateState (инструментарий глубокого обучения). Это полезно при наличии значений временных шагов, поступающих в поток. predictAndUpdateState функция принимает входные данные, производит прогнозирование выходных данных и обновляет внутреннее состояние сети так, чтобы будущие прогнозы учитывали эти исходные данные. Обычно быстрее делать прогнозы по полным последовательностям по сравнению с тем, чтобы делать прогнозы по одному шагу за раз.

Функция точки входа rulPredictAndUpdate принимает ввод в один временной интервал и обрабатывает ввод с помощью predictAndUpdateState функция. predictAndUpdateState выводит прогноз для входного временного интервала и обновляет сеть так, чтобы последующие входные сигналы рассматривались как последующие временные интервалы той же выборки. После прохождения всех временных интервалов по одному, результирующий выходной сигнал будет таким же, как если бы все временные интервалы были переданы как один вход.

type rulPredictAndUpdate.m
function out = rulPredictAndUpdate(in)
%#codegen

% Copyright 2020 The MathWorks, Inc. 

persistent mynet;

if isempty(mynet)
    mynet = coder.loadDeepLearningNetwork('rulNetwork.mat');
end

% pass in input to predictAndUpdateState method
[mynet, out] = predictAndUpdateState(mynet, in);

Запустите кодеген для этой новой функции точки входа. Поскольку мы принимаем один временной интервал для каждого вызова, мы указываем matrixInput чтобы иметь фиксированную размерность последовательности 1 вместо переменной длины последовательности.

matrixInput = coder.typeof(double(0),[17 1]);
codegen -config cfg rulPredictAndUpdate -args {matrixInput} -report
Code generation successful: To view the report, open('codegen/mex/rulPredictAndUpdate/html/report.mldatx').

Выполните прогнозы в отношении тестовых данных путем вызова rulPredictAndUpdate функция в MATLAB and the сгенерированная функция MEX rulPredictAndUpdate_mex.

YPredStatefulMex = cell(numel(idx), 1);
for iSample = 1:numel(idx)
    sample = XValidate{idx(iSample)};
    numTimeStepsTest = size(sample, 2);
    for iStep = 1:numTimeStepsTest
        YPredStatefulMex{iSample}(1, iStep) = rulPredictAndUpdate_mex(sample(:, iStep));
    end
end

Вы снова можете визуализировать прогнозы для MEX, как раньше на графике.

figure
for i = 1:numel(idx)
    subplot(2,2,i)
    
    plot(YValidate{idx(i)},'--')
    hold on
    plot(YPredStatefulMex{i},'.-')
    hold off
    
    ylim([0 175])
    title("Test Observation " + idx(i))
    xlabel("Time Step")
    ylabel("RUL")
end
legend(["Test Data" "Predicted MEX Stateful LSTM"],'Location','southeast')

Figure contains 4 axes. Axes 1 with title Test Observation 82 contains 2 objects of type line. Axes 2 with title Test Observation 90 contains 2 objects of type line. Axes 3 with title Test Observation 13 contains 2 objects of type line. Axes 4 with title Test Observation 89 contains 2 objects of type line. These objects represent Test Data, Predicted MEX Stateful LSTM.

Наконец, можно также визуализировать результаты для двух различных функций MEX вместе с прогнозом MATLAB на графике для любой конкретной выборки.

figure()
sampleIdx = idx(1);
plot(YValidate{sampleIdx},'--')
hold on
plot(YPred{sampleIdx},'o-')
plot(YPredMex{sampleIdx},'^-')
plot(YPredStatefulMex{1},'x-')
hold off

ylim([0 175])
title("Test Observation " + idx(i))
xlabel("Time Step")
ylabel("RUL")
legend(["Test Data" "Predicted in MATLAB" "Predicted MEX" "Predicted MEX with Stateful LSTM"],'Location','southeast')

Figure contains an axes. The axes with title Test Observation 89 contains 4 objects of type line. These objects represent Test Data, Predicted in MATLAB, Predicted MEX, Predicted MEX with Stateful LSTM.

Ссылки

  1. Саксена, Абхинав, Кай Гебель, Дон Симон и Нил Эклунд. «Моделирование распространения повреждений для имитации обкатки двигателя самолета». В Prognostics and Health Management, 2008. PHM 2008. Международная конференция, стр. 1-9. IEEE, 2008.

См. также

| |

Связанные темы