Обучите агента DQN балансировать систему тележки с шестом

В этом примере показано, как обучить агента глубокой Q-образовательной-сети (DQN) балансировать систему тележки с шестом, смоделированную в MATLAB®.

Для получения дополнительной информации об агентах DQN смотрите Глубоких Агентов Q-сети. Для примера, который обучает агента DQN в Simulink®, смотрите, Обучают Агента DQN к Swing и Маятнику Баланса.

Среда MATLAB тележки с шестом

Среда обучения с подкреплением для этого примера является полюсом, присоединенным к неприводимому в движение соединению на тележке, которая проходит лишенная трения дорожка. Цель обучения должна заставить полюс стоять вертикально без падения.

Для этой среды:

  • Восходящим сбалансированным выигрышным положением является 0 радианами и нисходящим положением зависания является pi радианы.

  • Полюс запускается вертикально с начального угла между –0.05 и 0,05 радианами.

  • Сигнал действия силы от агента до среды от –10 до 10 Н.

  • Наблюдения средой являются положением и скоростью тележки, угла полюса и угловой производной полюса.

  • Эпизод завершает работу, если полюс является больше чем 12 градусами вертикали или если тележка перемещает больше чем 2,4 м от исходного положения.

  • Вознаграждение +1 предоставлено для каждого временного шага, что полюс остается вертикальным. Штраф –5 применяется, когда полюс падает.

Для получения дополнительной информации об этой модели смотрите Загрузку Предопределенные Среды Системы управления.

Создайте интерфейс среды

Создайте предопределенный интерфейс среды для системы.

env = rlPredefinedEnv("CartPole-Discrete")
env = 
  CartPoleDiscreteAction with properties:

                  Gravity: 9.8000
                 MassCart: 1
                 MassPole: 0.1000
                   Length: 0.5000
                 MaxForce: 10
                       Ts: 0.0200
    ThetaThresholdRadians: 0.2094
               XThreshold: 2.4000
      RewardForNotFalling: 1
        PenaltyForFalling: -5
                    State: [4x1 double]

Интерфейс имеет дискретное пространство действий, где агент может применить одно из двух возможных значений силы к тележке, –10 или 10 Н.

Получите информация о спецификации действия и спецификация наблюдений.

obsInfo = getObservationInfo(env)
obsInfo = 
  rlNumericSpec with properties:

     LowerLimit: -Inf
     UpperLimit: Inf
           Name: "CartPole States"
    Description: "x, dx, theta, dtheta"
      Dimension: [4 1]
       DataType: "double"

actInfo = getActionInfo(env)
actInfo = 
  rlFiniteSetSpec with properties:

       Elements: [-10 10]
           Name: "CartPole Action"
    Description: [0x0 string]
      Dimension: [1 1]
       DataType: "double"

Для повторяемости результатов зафиксируйте начальное значение генератора случайных чисел.

rng(0)

Создайте агента DQN

Агент DQN аппроксимирует долгосрочное вознаграждение, заданные наблюдения и действия, с помощью критика функции ценности.

Агенты DQN могут использовать мультивыходные аппроксимации критика Q-значения, которые обычно более эффективны. Мультивыходная аппроксимация имеет наблюдения как входные параметры и значения состояния активности как выходные параметры. Каждый выходной элемент представляет ожидаемое совокупное долгосрочное вознаграждение за принятие соответствующих дискретных мер от состояния, обозначенного входными параметрами наблюдения.

Чтобы создать критика, сначала создайте глубокую нейронную сеть с одним входом (4-мерное наблюдаемое состояние) и один выходной вектор с двумя элементами (один для действия на 10 Н, другого для действия на-10 Н). Для получения дополнительной информации о создании представлений функции ценности на основе нейронной сети смотрите, Создают Представления Функции ценности и политика.

dnn = [
    featureInputLayer(obsInfo.Dimension(1),'Normalization','none','Name','state')
    fullyConnectedLayer(24,'Name','CriticStateFC1')
    reluLayer('Name','CriticRelu1')
    fullyConnectedLayer(24, 'Name','CriticStateFC2')
    reluLayer('Name','CriticCommonRelu')
    fullyConnectedLayer(length(actInfo.Elements),'Name','output')];

Просмотрите конфигурацию сети.

figure
plot(layerGraph(dnn))

Figure contains an axes. The axes contains an object of type graphplot.

Задайте некоторые опции обучения для представления критика с помощью rlRepresentationOptions.

criticOpts = rlRepresentationOptions('LearnRate',0.001,'GradientThreshold',1);

Создайте представление критика с помощью заданной нейронной сети и опций. Для получения дополнительной информации смотрите rlQValueRepresentation.

critic = rlQValueRepresentation(dnn,obsInfo,actInfo,'Observation',{'state'},criticOpts);

Чтобы создать агента DQN, сначала задайте опции агента DQN с помощью rlDQNAgentOptions.

agentOpts = rlDQNAgentOptions(...
    'UseDoubleDQN',false, ...    
    'TargetSmoothFactor',1, ...
    'TargetUpdateFrequency',4, ...   
    'ExperienceBufferLength',100000, ...
    'DiscountFactor',0.99, ...
    'MiniBatchSize',256);

Затем создайте агента DQN с помощью заданного представления критика и опций агента. Для получения дополнительной информации смотрите rlDQNAgent.

agent = rlDQNAgent(critic,agentOpts);

Обучите агента

Чтобы обучить агента, сначала задайте опции обучения. В данном примере используйте следующие опции:

  • Запустите один сеанс обучения, содержащий самое большее 1 000 эпизодов с каждым эпизодом, длящимся самое большее 500 временных шагов.

  • Отобразите прогресс обучения в диалоговом окне Episode Manager (установите Plots опция), и отключают отображение командной строки (установите Verbose опция к false).

  • Остановите обучение, когда агент получит скользящее среднее значение совокупное вознаграждение, больше, чем 480. На данном этапе агент может сбалансировать систему тележки с шестом в вертикальном положении.

Для получения дополнительной информации смотрите rlTrainingOptions.

trainOpts = rlTrainingOptions(...
    'MaxEpisodes',1000, ...
    'MaxStepsPerEpisode',500, ...
    'Verbose',false, ...
    'Plots','training-progress',...
    'StopTrainingCriteria','AverageReward',...
    'StopTrainingValue',480); 

Можно визуализировать систему тележки с шестом, может визуализироваться при помощи plot функция во время обучения или симуляции.

plot(env)

Figure Cart Pole Visualizer contains an axes. The axes contains 6 objects of type line, polygon.

Обучите агента с помощью train функция. Обучение этот агент является в вычислительном отношении интенсивным процессом, который занимает несколько минут, чтобы завершиться. Чтобы сэкономить время при выполнении этого примера, загрузите предварительно обученного агента установкой doTraining к false. Чтобы обучить агента самостоятельно, установите doTraining к true.

doTraining = false;
if doTraining    
    % Train the agent.
    trainingStats = train(agent,env,trainOpts);
else
    % Load the pretrained agent for the example.
    load('MATLABCartpoleDQNMulti.mat','agent')
end

Симулируйте агента DQN

Чтобы подтвердить производительность обученного агента, симулируйте его в среде тележки с шестом. Для получения дополнительной информации о симуляции агента смотрите rlSimulationOptions и sim. Агент может сбалансировать тележку с шестом, даже когда время симуляции увеличивается до 500 шагов.

simOptions = rlSimulationOptions('MaxSteps',500);
experience = sim(env,agent,simOptions);

Figure Cart Pole Visualizer contains an axes. The axes contains 6 objects of type line, polygon.

totalReward = sum(experience.Reward)
totalReward = 500

Смотрите также

Похожие темы