Обучите агента DQN балансировать систему тележки с шестом

Этот пример показывает, как обучить агента глубокой сети Q-обучения (DQN) для балансировки системы тележки с шестом, смоделированной в MATLAB ®.

Дополнительные сведения об агентах DQN см. в разделе Агенты глубоких Q-сетей. Для примера, который обучает агента DQN в Simulink ®, смотрите Train DQN Agent to Swing Up и Balance Mendulum.

Тележка с шестом MATLAB Окружения

Окружение обучения с подкреплением для этого примера является шестом, прикрепленным к неактуированному соединению на тележке, которое перемещается по безфрикционной дорожке. Цель обучения состоит в том, чтобы сделать шест вертикальным, не опускаясь.

Для этого окружения:

  • Положение сбалансированного полюса вверх 0 радианы, и положение свисания вниз pi радианы.

  • Шест запускается вертикально с начального угла между -0,05 и 0.05 радианами.

  • Сигнал действия силы от агента к окружению от -10 до 10 Н.

  • Наблюдения от окружения являются положением и скоростью тележки, углом шеста и производной угла полюса.

  • Эпизод прекращается, если полюс находится более чем в 12 степенях от вертикали или если тележка перемещается более чем на 2,4 м от исходного положения.

  • Вознаграждение + 1 предоставляется за каждый временной шаг, когда шест остается вертикальным. Штраф -5 применяется при падении шеста.

Для получения дополнительной информации об этой модели смотрите Загрузка предопределённых окружений системы управления.

Создайте интерфейс окружения

Создайте предопределенный интерфейс окружения для системы.

env = rlPredefinedEnv("CartPole-Discrete")
env = 
  CartPoleDiscreteAction with properties:

                  Gravity: 9.8000
                 MassCart: 1
                 MassPole: 0.1000
                   Length: 0.5000
                 MaxForce: 10
                       Ts: 0.0200
    ThetaThresholdRadians: 0.2094
               XThreshold: 2.4000
      RewardForNotFalling: 1
        PenaltyForFalling: -5
                    State: [4x1 double]

Интерфейс имеет дискретное пространство действий, где агент может применить одно из двух возможных значений силы к тележке, -10 или 10 Н.

Получите информацию о наблюдении и спецификации действия.

obsInfo = getObservationInfo(env)
obsInfo = 
  rlNumericSpec with properties:

     LowerLimit: -Inf
     UpperLimit: Inf
           Name: "CartPole States"
    Description: "x, dx, theta, dtheta"
      Dimension: [4 1]
       DataType: "double"

actInfo = getActionInfo(env)
actInfo = 
  rlFiniteSetSpec with properties:

       Elements: [-10 10]
           Name: "CartPole Action"
    Description: [0x0 string]
      Dimension: [1 1]
       DataType: "double"

Исправьте начальное значение генератора для повторяемости.

rng(0)

Создание агента DQN

Агент DQN аппроксимирует долгосрочное вознаграждение, заданные наблюдения и действия, используя критика функции ценности.

Агенты DQN могут использовать аппроксимации критика Q-значений с мультивыходами, которые, как правило, более эффективны. Аппроксимация мультивыхода имеет наблюдения в качестве входов и состояния активности значения в качестве выходов. Каждый выходной элемент представляет ожидаемое совокупное долгосрочное вознаграждение для принятия соответствующего дискретного действия от состояния, обозначенного входами наблюдения.

Чтобы создать критика, сначала создайте глубокую нейронную сеть с одним входом (4-мерное наблюдаемое состояние) и одним выходным вектором с двумя элементами (один для действия 10 N, другой для действия -10 N). Для получения дополнительной информации о создании представлений функции ценности на основе нейронной сети, смотрите, Создают Политику и Представления Функции Ценности.

dnn = [
    featureInputLayer(obsInfo.Dimension(1),'Normalization','none','Name','state')
    fullyConnectedLayer(24,'Name','CriticStateFC1')
    reluLayer('Name','CriticRelu1')
    fullyConnectedLayer(24, 'Name','CriticStateFC2')
    reluLayer('Name','CriticCommonRelu')
    fullyConnectedLayer(length(actInfo.Elements),'Name','output')];

Просмотрите сетевое строение.

figure
plot(layerGraph(dnn))

Figure contains an axes. The axes contains an object of type graphplot.

Задайте некоторые опции обучения для представления критика, используя rlRepresentationOptions.

criticOpts = rlRepresentationOptions('LearnRate',0.001,'GradientThreshold',1);

Создайте представление критика с помощью заданной нейронной сети и опций. Для получения дополнительной информации смотрите rlQValueRepresentation.

critic = rlQValueRepresentation(dnn,obsInfo,actInfo,'Observation',{'state'},criticOpts);

Чтобы создать агента DQN, сначала задайте опции агента DQN с помощью rlDQNAgentOptions.

agentOpts = rlDQNAgentOptions(...
    'UseDoubleDQN',false, ...    
    'TargetSmoothFactor',1, ...
    'TargetUpdateFrequency',4, ...   
    'ExperienceBufferLength',100000, ...
    'DiscountFactor',0.99, ...
    'MiniBatchSize',256);

Затем создайте агента DQN с помощью заданного представления критика и опций агента. Для получения дополнительной информации смотрите rlDQNAgent.

agent = rlDQNAgent(critic,agentOpts);

Обучите агента

Чтобы обучить агента, сначала укажите опции обучения. В данном примере используйте следующие опции:

  • Запустите одну обучающий сеанс, содержащую самое большее 1000 эпизодов с каждым эпизодом, длящимся самое большее 500 временных шагов.

  • Отображение процесса обучения в диалоговом окне Диспетчер эпизодов (установите Plots опция) и отключить отображение командной строки (установите Verbose опция для false).

  • Остановите обучение, когда агент получит совокупное вознаграждение скользящего среднего значения, больше 480. На данной точке агент может сбалансировать систему тележки с шестом в вертикальном положении.

Для получения дополнительной информации смотрите rlTrainingOptions.

trainOpts = rlTrainingOptions(...
    'MaxEpisodes',1000, ...
    'MaxStepsPerEpisode',500, ...
    'Verbose',false, ...
    'Plots','training-progress',...
    'StopTrainingCriteria','AverageReward',...
    'StopTrainingValue',480); 

Можно визуализировать систему тележки с шестом, которую можно визуализировать при помощи plot функция во время обучения или симуляции.

plot(env)

Figure Cart Pole Visualizer contains an axes. The axes contains 6 objects of type line, polygon.

Обучите агента с помощью train функция. Обучение этого агента является интенсивным в вычислительном отношении процессом, который занимает несколько минут. Чтобы сэкономить время при запуске этого примера, загрузите предварительно обученного агента путем установки doTraining на false. Чтобы обучить агента самостоятельно, установите doTraining на true.

doTraining = false;
if doTraining    
    % Train the agent.
    trainingStats = train(agent,env,trainOpts);
else
    % Load the pretrained agent for the example.
    load('MATLABCartpoleDQNMulti.mat','agent')
end

Симулируйте агент DQN

Чтобы подтвердить производительность обученного агента, симулируйте его в среде тележки с шестом. Для получения дополнительной информации о симуляции агента смотрите rlSimulationOptions и sim. Агент может сбалансировать тележку с шестом, даже когда время симуляции увеличивается до 500 шагов.

simOptions = rlSimulationOptions('MaxSteps',500);
experience = sim(env,agent,simOptions);

Figure Cart Pole Visualizer contains an axes. The axes contains 6 objects of type line, polygon.

totalReward = sum(experience.Reward)
totalReward = 500

См. также

Похожие темы