Обучите агента DDPG к Swing и маятнику баланса с наблюдением изображений

В этом примере показано, как обучить агента глубоко детерминированного градиента политики (DDPG) качаться и балансировать маятник с наблюдения изображений, смоделированного в MATLAB®.

Для получения дополнительной информации об агентах DDPG смотрите Глубоко Детерминированных Агентов Градиента политики (Reinforcement Learning Toolbox).

Математический маятник с изображением среда MATLAB

Среда обучения с подкреплением для этого примера является простым лишенным трения маятником, который первоначально висит в нисходящем положении. Цель обучения должна заставить маятник стоять вертикально, не падая и используя минимальные усилия по управлению.

Для этой среды:

  • Восходящим сбалансированным положением маятника является 0 радианами и нисходящим положением зависания является pi радианы.

  • Сигнал действия крутящего момента от агента до среды от –2 до 2 Н · m.

  • Наблюдения средой являются изображением, указывающим на местоположение массы маятника и скорости вращения маятника.

  • Вознаграждение rt, если на каждом временном шаге,

rt=-(θt2+0.1θt˙2+0.001ut-12)

Здесь:

  • θt угол смещения от вертикального положения.

  • θt˙ производная угла рассогласования.

  • ut-1 усилие по управлению от предыдущего временного шага.

Для получения дополнительной информации об этой модели смотрите Загрузку Предопределенные Среды Системы управления (Reinforcement Learning Toolbox).

Создайте интерфейс среды

Создайте предопределенный интерфейс среды для маятника.

env = rlPredefinedEnv('SimplePendulumWithImage-Continuous')
env = 
  SimplePendlumWithImageContinuousAction with properties:

             Mass: 1
        RodLength: 1
       RodInertia: 0
          Gravity: 9.8100
     DampingRatio: 0
    MaximumTorque: 2
               Ts: 0.0500
            State: [2x1 double]
                Q: [2x2 double]
                R: 1.0000e-03

Интерфейс имеет непрерывное пространство действий, где агент может применить крутящий момент между –2 к 2 Н · m.

Получите спецификацию наблюдений и спецификацию действия от интерфейса среды.

obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);

Для повторяемости результатов зафиксируйте начальное значение генератора случайных чисел.

rng(0)

Создайте агента DDPG

Агент DDPG аппроксимирует долгосрочное вознаграждение, заданные наблюдения и действия, с помощью представления функции ценности критика. Чтобы создать критика, сначала создайте глубокую сверточную нейронную сеть (CNN) с тремя входными параметрами (изображение, скорость вращения и действие) и один выход. Для получения дополнительной информации о создании представлений смотрите, Создают политику и Представления Функции ценности (Reinforcement Learning Toolbox).

hiddenLayerSize1 = 400;
hiddenLayerSize2 = 300;

imgPath = [
    imageInputLayer(obsInfo(1).Dimension,'Normalization','none','Name',obsInfo(1).Name)
    convolution2dLayer(10,2,'Name','conv1','Stride',5,'Padding',0)
    reluLayer('Name','relu1')
    fullyConnectedLayer(2,'Name','fc1')
    concatenationLayer(3,2,'Name','cat1')
    fullyConnectedLayer(hiddenLayerSize1,'Name','fc2')
    reluLayer('Name','relu2')
    fullyConnectedLayer(hiddenLayerSize2,'Name','fc3')
    additionLayer(2,'Name','add')
    reluLayer('Name','relu3')
    fullyConnectedLayer(1,'Name','fc4')
    ];
dthetaPath = [
    imageInputLayer(obsInfo(2).Dimension,'Normalization','none','Name',obsInfo(2).Name)
    fullyConnectedLayer(1,'Name','fc5','BiasLearnRateFactor',0,'Bias',0)
    ];
actPath =[
    imageInputLayer(actInfo(1).Dimension,'Normalization','none','Name','action')
    fullyConnectedLayer(hiddenLayerSize2,'Name','fc6','BiasLearnRateFactor',0,'Bias',zeros(hiddenLayerSize2,1))
    ];

criticNetwork = layerGraph(imgPath);
criticNetwork = addLayers(criticNetwork,dthetaPath);
criticNetwork = addLayers(criticNetwork,actPath);
criticNetwork = connectLayers(criticNetwork,'fc5','cat1/in2');
criticNetwork = connectLayers(criticNetwork,'fc6','add/in2');

Просмотрите конфигурацию сети критика.

figure
plot(criticNetwork)

Figure contains an axes object. The axes object contains an object of type graphplot.

Задайте опции для представления критика с помощью rlRepresentationOptions (Reinforcement Learning Toolbox).

criticOptions = rlRepresentationOptions('LearnRate',1e-03,'GradientThreshold',1);

Не прокомментируйте следующую линию, чтобы использовать графический процессор, чтобы ускорить обучение CNN критика. Для получения дополнительной информации о поддерживаемых графических процессорах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).

% criticOptions.UseDevice = 'gpu';

Создайте представление критика с помощью заданной нейронной сети и опций. Необходимо также задать информацию о действии и наблюдении для критика, которого вы получаете из интерфейса среды. Для получения дополнительной информации смотрите rlQValueRepresentation (Reinforcement Learning Toolbox).

critic = rlQValueRepresentation(criticNetwork,obsInfo,actInfo,...
    'Observation',{'pendImage','angularRate'},'Action',{'action'},criticOptions);

Агент DDPG решает который действие взять заданные наблюдения с помощью представления актера. Чтобы создать агента, сначала создайте глубокую сверточную нейронную сеть (CNN) с двумя входными параметрами (изображение и скорость вращения) и один выход (действие).

Создайте агента подобным образом критику.

imgPath = [
    imageInputLayer(obsInfo(1).Dimension,'Normalization','none','Name',obsInfo(1).Name)
    convolution2dLayer(10,2,'Name','conv1','Stride',5,'Padding',0)
    reluLayer('Name','relu1')
    fullyConnectedLayer(2,'Name','fc1')
    concatenationLayer(3,2,'Name','cat1')
    fullyConnectedLayer(hiddenLayerSize1,'Name','fc2')
    reluLayer('Name','relu2')
    fullyConnectedLayer(hiddenLayerSize2,'Name','fc3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(1,'Name','fc4')
    tanhLayer('Name','tanh1')
    scalingLayer('Name','scale1','Scale',max(actInfo.UpperLimit))
    ];
dthetaPath = [
    imageInputLayer(obsInfo(2).Dimension,'Normalization','none','Name',obsInfo(2).Name)
    fullyConnectedLayer(1,'Name','fc5','BiasLearnRateFactor',0,'Bias',0)
    ];

actorNetwork = layerGraph(imgPath);
actorNetwork = addLayers(actorNetwork,dthetaPath);
actorNetwork = connectLayers(actorNetwork,'fc5','cat1/in2');

actorOptions = rlRepresentationOptions('LearnRate',1e-04,'GradientThreshold',1);

Не прокомментируйте следующую линию, чтобы использовать графический процессор, чтобы ускорить обучение CNN агента.

% actorOptions.UseDevice = 'gpu';

Создайте представление актера с помощью заданной нейронной сети и опций. Для получения дополнительной информации смотрите rlDeterministicActorRepresentation (Reinforcement Learning Toolbox).

actor = rlDeterministicActorRepresentation(actorNetwork,obsInfo,actInfo,'Observation',{'pendImage','angularRate'},'Action',{'scale1'},actorOptions);

Просмотрите конфигурацию сети агента.

figure
plot(actorNetwork)

Figure contains an axes object. The axes object contains an object of type graphplot.

Чтобы создать агента DDPG, сначала задайте опции агента DDPG с помощью rlDDPGAgentOptions (Reinforcement Learning Toolbox).

agentOptions = rlDDPGAgentOptions(...
    'SampleTime',env.Ts,...
    'TargetSmoothFactor',1e-3,...
    'ExperienceBufferLength',1e6,...
    'DiscountFactor',0.99,...
    'MiniBatchSize',128);
agentOptions.NoiseOptions.Variance = 0.6;
agentOptions.NoiseOptions.VarianceDecayRate = 1e-6;

Затем создайте агента с помощью заданного представления актера, представления критика и опций агента. Для получения дополнительной информации смотрите rlDDPGAgent (Reinforcement Learning Toolbox).

agent = rlDDPGAgent(actor,critic,agentOptions);

Обучите агента

Чтобы обучить агента, сначала задайте опции обучения. В данном примере используйте следующие опции.

  • Запустите каждое обучение самое большее 5 000 эпизодов с каждым эпизодом, длящимся самое большее 400 временных шагов.

  • Отобразите прогресс обучения в диалоговом окне Episode Manager (установите Plots опция).

  • Остановите обучение, когда агент получит скользящее среднее значение совокупное вознаграждение, больше, чем-740 более чем десять последовательных эпизодов. На данном этапе агент может быстро сбалансировать маятник в вертикальном положении с помощью минимального усилия по управлению.

Для получения дополнительной информации смотрите rlTrainingOptions (Reinforcement Learning Toolbox).

maxepisodes = 5000;
maxsteps = 400;
trainingOptions = rlTrainingOptions(...
    'MaxEpisodes',maxepisodes,...
    'MaxStepsPerEpisode',maxsteps,...
    'Plots','training-progress',...
    'StopTrainingCriteria','AverageReward',...
    'StopTrainingValue',-740);

Можно визуализировать маятник при помощи plot функция во время обучения или симуляции.

plot(env)

Figure Simple Pendulum Visualizer contains 2 axes objects. Axes object 1 contains 2 objects of type line, rectangle. Axes object 2 contains an object of type image.

Обучите агента с помощью train (Reinforcement Learning Toolbox) функция. Обучение этот агент является в вычислительном отношении интенсивным процессом, который занимает несколько часов, чтобы завершиться. Чтобы сэкономить время при выполнении этого примера, загрузите предварительно обученного агента установкой doTraining к false. Чтобы обучить агента самостоятельно, установите doTraining к true.

doTraining = false;
if doTraining    
    % Train the agent.
    trainingStats = train(agent,env,trainingOptions);
else
    % Load pretrained agent for the example.
    load('SimplePendulumWithImageDDPG.mat','agent')       
end

Симулируйте агента DDPG

Чтобы подтвердить производительность обученного агента, симулируйте его в среде маятника. Для получения дополнительной информации о симуляции агента смотрите rlSimulationOptions (Reinforcement Learning Toolbox) и sim (Reinforcement Learning Toolbox).

simOptions = rlSimulationOptions('MaxSteps',500);
experience = sim(env,agent,simOptions);

Figure Simple Pendulum Visualizer contains 2 axes objects. Axes object 1 contains 2 objects of type line, rectangle. Axes object 2 contains an object of type image.

Смотрите также

(Reinforcement Learning Toolbox)

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте