Этот пример показывает, как обучить агент глубокой Q-образовательной-сети (DQN) балансировать полюсную корзиной систему, смоделированную в MATLAB®.
Для получения дополнительной информации об агентах DQN смотрите Глубокие Агенты Q-сети. Для примера, который обучает агент DQN в Simulink®, смотрите Train Агент DQN к Swing и Маятнику Баланса.
Среда обучения укрепления для этого примера является полюсом, присоединенным к неприводимому в действие соединению на корзине, которая проходит лишенная трения дорожка. Учебная цель состоит в том, чтобы заставить маятник стоять вертикально без падения.
Для этой среды:
Восходящее сбалансированное положение маятника является радианами 0
, и нисходящее положение зависания является радианами pi
Маятник запускается вертикально с начального угла + радианы/-0.05
Сигнал действия силы от агента до среды от-10 до 10 Н
Наблюдения от среды являются положением и скоростью корзины, угла маятника и его производной
Эпизод останавливается, если полюс является больше чем 12 градусами вертикали, или корзина перемещает больше чем 2,4 м от исходного положения
Вознаграждение +1 предоставлено для каждого временного шага, что полюс остается вертикальным. Штраф-5 применяется, когда маятник падает.
Для получения дополнительной информации об этой модели смотрите Загрузку Предопределенные Среды Системы управления.
Создайте предопределенный интерфейс среды для маятника.
env = rlPredefinedEnv("CartPole-Discrete")
env = CartPoleDiscreteAction with properties: Gravity: 9.8000 MassCart: 1 MassPole: 0.1000 Length: 0.5000 MaxForce: 10 Ts: 0.0200 ThetaThresholdRadians: 0.2094 XThreshold: 2.4000 RewardForNotFalling: 1 PenaltyForFalling: -5 State: [4×1 double]
Интерфейс имеет дискретный пробел действия, где агент может применить одно из двух возможных значений силы к корзине,-10 или 10 Н.
Зафиксируйте случайный seed генератора для воспроизводимости.
rng(0);
Агент DQN аппроксимирует долгосрочное вознаграждение, данное наблюдения и действия с помощью представления функции значения критика. Чтобы создать критика, сначала создайте глубокую нейронную сеть с двумя входными параметрами, состоянием и действием и одним выводом. Для получения дополнительной информации о создании представления функции значения нейронной сети смотрите, Создают политику и Представления Функции Значения.
statePath = [ imageInputLayer([4 1 1], 'Normalization', 'none', 'Name', 'state') fullyConnectedLayer(24, 'Name', 'CriticStateFC1') reluLayer('Name', 'CriticRelu1') fullyConnectedLayer(24, 'Name', 'CriticStateFC2')]; actionPath = [ imageInputLayer([1 1 1], 'Normalization', 'none', 'Name', 'action') fullyConnectedLayer(24, 'Name', 'CriticActionFC1')]; commonPath = [ additionLayer(2,'Name', 'add') reluLayer('Name','CriticCommonRelu') fullyConnectedLayer(1, 'Name', 'output')]; criticNetwork = layerGraph(statePath); criticNetwork = addLayers(criticNetwork, actionPath); criticNetwork = addLayers(criticNetwork, commonPath); criticNetwork = connectLayers(criticNetwork,'CriticStateFC2','add/in1'); criticNetwork = connectLayers(criticNetwork,'CriticActionFC1','add/in2');
Просмотрите конфигурацию сети критика.
figure plot(criticNetwork)
Задайте опции для представления критика с помощью rlRepresentationOptions
.
criticOpts = rlRepresentationOptions('LearnRate',0.01,'GradientThreshold',1);
Создайте представление критика с помощью заданной нейронной сети и опций. Необходимо также задать информацию о действии и наблюдении для критика, которого вы получаете из интерфейса среды. Для получения дополнительной информации смотрите rlRepresentation
.
obsInfo = getObservationInfo(env); actInfo = getActionInfo(env); critic = rlRepresentation(criticNetwork,obsInfo,actInfo,'Observation',{'state'},'Action',{'action'},criticOpts);
Чтобы создать агент DQN, сначала задайте опции агента DQN с помощью rlDQNAgentOptions
.
agentOpts = rlDQNAgentOptions(... 'UseDoubleDQN',false, ... 'TargetUpdateMethod',"periodic", ... 'TargetUpdateFrequency',4, ... 'ExperienceBufferLength',100000, ... 'DiscountFactor',0.99, ... 'MiniBatchSize',256);
Затем создайте агент DQN с помощью заданного представления критика и опций агента. Для получения дополнительной информации смотрите rlDQNAgent
.
agent = rlDQNAgent(critic,agentOpts);
Чтобы обучить агент, сначала задайте опции обучения. В данном примере используйте следующие опции:
Запустите каждый учебный эпизод для самое большее 1 000 эпизодов с каждым эпизодом, длящимся самое большее 200 временных шагов.
Отобразитесь учебный прогресс диалогового окна Episode Manager (установите опцию Plots
), и отключите отображение командной строки (установите опцию Verbose
).
Остановите обучение, когда агент получит среднее совокупное вознаграждение, больше, чем 195 более чем 10 последовательных эпизодов. На данном этапе агент может сбалансировать маятник в вертикальном положении.
Для получения дополнительной информации смотрите rlTrainingOptions
.
trainOpts = rlTrainingOptions(... 'MaxEpisodes', 1000, ... 'MaxStepsPerEpisode', 500, ... 'Verbose', false, ... 'Plots','training-progress',... 'StopTrainingCriteria','AverageReward',... 'StopTrainingValue',480);
Полюсная корзиной система может визуализироваться с использованием функции plot
во время обучения или симуляции.
plot(env);
Обучите агент с помощью функции train
. Это - в вычислительном отношении интенсивный процесс, который занимает несколько минут, чтобы завершиться. Чтобы сэкономить время при выполнении этого примера, загрузите предварительно обученный агент установкой doTraining
к false
. Чтобы обучить агент самостоятельно, установите doTraining
на true
.
doTraining = false; if doTraining % Train the agent. trainingStats = train(agent,env,trainOpts); else % Load pretrained agent for the example. load('MATLABCartpoleDQN.mat','agent'); end
Чтобы подтвердить производительность обученного агента, моделируйте его в полюсной корзиной среде. Для получения дополнительной информации о симуляции агента смотрите rlSimulationOptions
и sim
. Агент может сбалансировать полюсное корзиной, даже когда время симуляции увеличивается до 500.
simOptions = rlSimulationOptions('MaxSteps',500);
experience = sim(env,agent,simOptions);
totalReward = sum(experience.Reward)
totalReward = 500
MATLAB и Simulink являются зарегистрированными торговыми марками MathWorks, Inc. См. www.mathworks.com/trademarks для списка других товарных знаков, принадлежавших MathWorks, Inc. Другим продуктом или фирменными знаками являются товарные знаки или зарегистрированные торговые марки их соответствующих владельцев.