Решите обыкновенное дифференциальное уравнение Используя нейронную сеть

В этом примере показано, как решить обыкновенное дифференциальное уравнение (ODE) с помощью нейронной сети.

Не все дифференциальные уравнения имеют решение закрытой формы. Чтобы найти приближенные решения этих типов уравнений, много традиционных числовых алгоритмов доступны. Однако можно также решить ОДУ при помощи нейронной сети. Этот подход идет с несколькими преимуществами, включая которые он предоставляет дифференцируемые приближенные решения в закрытой аналитической форме [1].

Этот пример показывает вам как:

  1. Сгенерируйте обучающие данные в области значений x[0,2].

  2. Задайте нейронную сеть, которая берет x как введено и возвращает приближенное решение ОДУ y˙=-2xy, оцененный в x, как выведено.

  3. Обучите сеть с пользовательской функцией потерь.

  4. Сравните сетевые предсказания с аналитическим решением.

ОДУ и функция потерь

В этом примере вы решаете ОДУ

y˙=-2xy,

с начальным условием y(0)=1. Это ОДУ имеет аналитическое решение

y(x)=e-x2.

Задайте пользовательскую функцию потерь, которая штрафует отклонения от удовлетворения ОДУ и начальному условию. В этом примере функция потерь является взвешенной суммой потери ОДУ и начальной потери условия:

Lθ(x)=y˙θ+2xyθ2+kyθ(0)-12

θсетевые параметры, k постоянный коэффициент, yθ решение, предсказанное сетью, и yθ˙ производная предсказанного решения, вычисленного с помощью автоматического дифференцирования. Термин yθ˙+2xyθ2 потеря ОДУ, и она определяет количество, сколько предсказанное решение отклоняет от удовлетворения определению ОДУ. Термин yθ(0)-12 начальная потеря условия, и она определяет количество, сколько предсказанное решение отклоняет от удовлетворения начальному условию.

Сгенерируйте входные данные и сеть Define

Сгенерируйте 10 000 точек обучающих данных в области значений x[0,2].

x = linspace(0,2,10000)';

Задайте сеть для того, чтобы решить ОДУ. Когда вход является вещественным числом xR, задайте входной размер 1.

inputSize = 1;
layers = [
    featureInputLayer(inputSize,Normalization="none")
    fullyConnectedLayer(10)
    sigmoidLayer
    fullyConnectedLayer(1)
    sigmoidLayer];

Создайте dlnetwork объект от массива слоя.

dlnet = dlnetwork(layers)
dlnet = 
  dlnetwork with properties:

         Layers: [5×1 nnet.cnn.layer.Layer]
    Connections: [4×2 table]
     Learnables: [4×3 table]
          State: [0×3 table]
     InputNames: {'input'}
    OutputNames: {'layer_2'}
    Initialized: 1

Функция градиентов модели Define

Создайте функциональный modelGradients, перечисленный в конце примера, который берет в качестве входных параметров dlnetwork объект dlnet, мини-пакет входных данных dlX, и коэффициент сопоставил с начальной потерей условия icCoeff. Эта функция возвращает градиенты потери относительно настраиваемых параметров в dlnet и соответствующая потеря.

Задайте опции обучения

Обучайтесь в течение 15 эпох с мини-пакетным размером 100.

numEpochs = 15;
miniBatchSize = 100;

Задайте опции для оптимизации SGDM. Задайте скорость обучения 0,5, фактор отбрасывания скорости обучения 0,5, период отбрасывания скорости обучения 5 и импульс 0,9.

initialLearnRate = 0.5;
learnRateDropFactor = 0.5;
learnRateDropPeriod = 5;
momentum = 0.9;

Задайте коэффициент начального термина условия в функции потерь как 7. Этот коэффициент задает относительный вклад начального условия к потере. Тонкая настройка этого параметра может помочь обучению сходиться быстрее.

icCoeff = 7;

Обучите модель

Использовать мини-пакеты данных во время обучения:

  1. Создайте arrayDatastore объект от обучающих данных.

  2. Создайте minibatchqueue объект, который берет arrayDatastore возразите, как введено, задайте мини-пакетный размер и отформатируйте обучающие данные с размерностью, маркирует 'BC' (пакет, канал).

ads = arrayDatastore(x,IterationDimension=1);
mbq = minibatchqueue(ads,MiniBatchSize=miniBatchSize,MiniBatchFormat="BC");

По умолчанию, minibatchqueue объект преобразует данные в dlarray объекты с базовым типом single.

Обучайтесь на графическом процессоре, если вы доступны. По умолчанию, minibatchqueue объект преобразует каждый выход в gpuArray если графический процессор доступен. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).

Инициализируйте график процесса обучения.

figure
set(gca,YScale="log")
lineLossTrain = animatedline(Color=[0.85 0.325 0.098]);
ylim([0 inf])
xlabel("Iteration")
ylabel("Loss (log scale)")
grid on

Инициализируйте скоростной параметр для решателя SGDM.

velocity = [];

Обучите сеть с помощью пользовательского учебного цикла. В течение каждой эпохи переставьте данные и цикл по мини-пакетам данных. Для каждого мини-пакета:

  • Оцените градиенты модели и потерю с помощью dlfeval и modelGradients функции.

  • Обновите сетевые параметры с помощью sgdmupdate функция.

  • Отобразите прогресс обучения.

Каждый learnRateDropPeriod эпохи, умножьте скорость обучения на learnRateDropFactor.

iteration = 0;
learnRate = initialLearnRate;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs

    % Shuffle data.
    mbq.shuffle

    % Loop over mini-batches.
    while hasdata(mbq)

        iteration = iteration + 1;

        % Read mini-batch of data.
        dlX = next(mbq);

        % Evaluate the model gradients and loss using dlfeval and the modelGradients function.
        [gradients,loss] = dlfeval(@modelGradients, dlnet, dlX, icCoeff);

        % Update network parameters using the SGDM optimizer.
        [dlnet,velocity] = sgdmupdate(dlnet,gradients,velocity,learnRate,momentum);

        % To plot, convert the loss to double.
        loss = double(gather(extractdata(loss)));
        
        % Display the training progress.
        D = duration(0,0,toc(start),Format="mm:ss.SS");
        addpoints(lineLossTrain,iteration,loss)
        title("Epoch: " + epoch + " of " + numEpochs + ", Elapsed: " + string(D))
        drawnow

    end
    % Reduce the learning rate.
    if mod(epoch,learnRateDropPeriod)==0
        learnRate = learnRate*learnRateDropFactor;
    end
end

Тестовая модель

Протестируйте точность сети путем сравнения ее предсказаний с аналитическим решением.

Сгенерируйте тестовые данные в области значений x[0,4] видеть, может ли сеть экстраполировать вне учебной области значений x[0,2].

xTest = linspace(0,4,1000)';

Использовать мини-пакеты данных во время тестирования:

  1. Создайте arrayDatastore объект из данных о тестировании.

  2. Создайте minibatchqueue объект, который берет arrayDatastore возразите, как введено, задайте мини-пакетный размер 100 и отформатируйте обучающие данные с размерностью, маркирует 'BC' (пакет, канал).

adsTest = arrayDatastore(xTest,IterationDimension=1);
mbqTest = minibatchqueue(adsTest,MiniBatchSize=100,MiniBatchFormat="BC");

Цикл по мини-пакетам и делает предсказания с помощью modelPredictions функция, перечисленная в конце примера.

yModel = modelPredictions(dlnet,mbqTest);

Оцените аналитическое решение.

yAnalytic = exp(-xTest.^2);

Сравните аналитическое решение и предсказание модели путем графического вывода их на логарифмическом масштабе.

figure;
plot(xTest,yAnalytic,"-")
hold on
plot(xTest,yModel,"--")
legend("Analytic","Model")
xlabel("x")
ylabel("y (log scale)")
set(gca,YScale="log")

Модель аппроксимирует аналитическое решение точно в учебной области значений x[0,2] и это экстраполирует в области значений x(2,4] с более низкой точностью.

Вычислите среднеквадратическую относительную погрешность учебной области значений x[0,2].

yModelTrain = yModel(1:500);
yAnalyticTrain = yAnalytic(1:500);

errorTrain = mean(((yModelTrain-yAnalyticTrain)./yAnalyticTrain).^2) 
errorTrain = single
    4.3454e-04

Вычислите среднеквадратическую относительную погрешность экстраполируемой области значений x(2,4].

yModelExtra = yModel(501:1000);
yAnalyticExtra = yAnalytic(501:1000);

errorExtra = mean(((yModelExtra-yAnalyticExtra)./yAnalyticExtra).^2) 
errorExtra = single
    17576612

Заметьте, что среднеквадратическая относительная погрешность выше в экстраполируемой области значений, чем это находится в учебной области значений.

Функция градиентов модели

modelGradients функционируйте берет в качестве входных параметров dlnetwork объект dlnet, мини-пакет входных данных dlX, и коэффициент сопоставил с начальной потерей условия icCoeff. Эта функция возвращает градиенты потери относительно настраиваемых параметров в dlnet и соответствующая потеря. Потеря задана как взвешенная сумма потери ОДУ и начальной потери условия. Оценка этой потери требует производных второго порядка. Чтобы включить второму порядку автоматическое дифференцирование, используйте функциональный dlgradient и набор EnableHigherDerivatives аргумент значения имени к true.

function [gradients,loss] = modelGradients(dlnet, dlX, icCoeff)
y = forward(dlnet,dlX);

% Evaluate the gradient of y with respect to x. 
% Since another derivative will be taken, set EnableHigherDerivatives to true.
dy = dlgradient(sum(y,"all"),dlX,EnableHigherDerivatives=true);

% Define ODE loss.
eq = dy + 2*y.*dlX;

% Define initial condition loss.
ic = forward(dlnet,dlarray(0,"CB")) - 1;

% Specify the loss as a weighted sum of the ODE loss and the initial condition loss.
loss = mean(eq.^2,"all") + icCoeff * ic.^2;

% Evaluate model gradients.
gradients = dlgradient(loss, dlnet.Learnables);

end

Функция предсказаний модели

modelPredictions функционируйте берет dlnetwork объект dlnet и minibatchqueue из входных данных mbq и вычисляет предсказания модели y путем итерации по всем данным в minibatchqueue объект.

function Y = modelPredictions(dlnet,mbq)

Y = [];

while hasdata(mbq)

    % Read mini-batch of data.
    dlXTest = next(mbq);
    
    % Predict output using trained network.
    dlY = predict(dlnet,dlXTest);
    YPred = gather(extractdata(dlY));
    Y = [Y; YPred'];

end

end

Ссылки

  1. Lagaris, Т.е. А. Ликас и Д. Ай. Фотиэдис. “Искусственные Нейронные сети для Решения Обыкновенных дифференциальных уравнений и Дифференциальных уравнений с частными производными”. Транзакции IEEE на Нейронных сетях 9, № 5 (сентябрь 1998): 987–1000. https://doi.org/10.1109/72.712178.

Смотрите также

| |

Похожие темы