Создайте агента Используя Deep Network Designer и обучайтесь Используя наблюдения изображений

В этом примере показано, как создать агента глубокой Q-образовательной-сети (DQN), который может качаться и сбалансировать маятник, смоделированный в MATLAB®. В этом примере вы создаете агента DQN с помощью Deep Network Designer. Для получения дополнительной информации об агентах DQN смотрите Глубоких Агентов Q-сети (Reinforcement Learning Toolbox).

Маятник Swing с изображением среда MATLAB

Среда обучения с подкреплением для этого примера является простым лишенным трения маятником, который первоначально висит в нисходящем положении. Цель обучения должна заставить маятник стоять вертикально, не падая и используя минимальные усилия по управлению.

Для этой среды:

  • Восходящим сбалансированным положением маятника является 0 радианами и нисходящим положением зависания является pi радианы.

  • Сигнал действия крутящего момента от агента до среды от –2 до 2 Н · m.

  • Наблюдения средой являются упрощенным полутоновым изображением маятника и угловой производной маятника.

  • Вознаграждение rt, если на каждом временном шаге,

rt=-(θt2+0.1θt˙2+0.001ut-12)

Здесь:

  • θt угол смещения от вертикального положения.

  • θt˙ производная угла рассогласования.

  • ut-1 усилие по управлению от предыдущего временного шага.

Для получения дополнительной информации об этой модели смотрите, Обучают Агента DDPG к Swing и Маятнику Баланса с Наблюдением Изображений (Reinforcement Learning Toolbox).

Создайте интерфейс среды

Создайте предопределенный интерфейс среды для маятника.

env = rlPredefinedEnv('SimplePendulumWithImage-Discrete');

Интерфейс имеет два наблюдения. Первое наблюдение, названное "pendImage", 50 50 полутоновое изображение.

obsInfo = getObservationInfo(env);
obsInfo(1)
ans = 
  rlNumericSpec with properties:

     LowerLimit: 0
     UpperLimit: 1
           Name: "pendImage"
    Description: [0x0 string]
      Dimension: [50 50]
       DataType: "double"

Второе наблюдение, названное "angularRate", скорость вращения маятника.

obsInfo(2)
ans = 
  rlNumericSpec with properties:

     LowerLimit: -Inf
     UpperLimit: Inf
           Name: "angularRate"
    Description: [0x0 string]
      Dimension: [1 1]
       DataType: "double"

Интерфейс имеет дискретное пространство действий, где агент может применить одно из пяти возможных значений крутящего момента к маятнику: -2, -1, 0, 1, или 2 N· m.

actInfo = getActionInfo(env)
actInfo = 
  rlFiniteSetSpec with properties:

       Elements: [-2 -1 0 1 2]
           Name: "torque"
    Description: [0x0 string]
      Dimension: [1 1]
       DataType: "double"

Для повторяемости результатов зафиксируйте начальное значение генератора случайных чисел.

rng(0)

Создайте сеть критика Используя Deep Network Designer

Агент DQN аппроксимирует долгосрочное вознаграждение, заданные наблюдения и действия, с помощью представления функции ценности критика. Для этой среды критик является глубокой нейронной сетью с тремя входными параметрами (два наблюдения и одно действие), и один выход. Для получения дополнительной информации о создании представления функции ценности глубокой нейронной сети смотрите, Создают политику и Представления Функции ценности (Reinforcement Learning Toolbox).

Можно создать сеть критика в интерактивном режиме при помощи приложения Deep Network Designer. Для этого вы сначала создаете отдельные входные пути для каждого наблюдения и действия. Эти пути узнают о более низких функциях уровня из своих соответствующих входных параметров. Вы затем создаете общий выход path, который комбинирует выходные параметры от входных путей.

Создайте путь к наблюдению изображений

Чтобы создать путь к наблюдению изображений, сначала перетащите ImageInputLayer от Слоя Библиотека разделяют на области к холсту. Установите слой InputSize на 50,50,1 для наблюдения изображений и Нормализации набора к none.

Во-вторых, перетащите Convolution2DLayer к холсту и подключению вход этого слоя к выходу ImageInputLayer. Создайте слой свертки с 2 фильтры (свойство NumFilters), которые имеют высоту и ширину 10 (Свойство FilterSize), и использование шаг 5 в горизонтальных и вертикальных направлениях (Свойство Stride).

Наконец, завершите сеть канала передачи изображения с двумя наборами ReLULayer и FullyConnectedLayer слои. Выходные размеры первого и второго FullyConnectedLayer слои 400 и 300, соответственно.

Создайте все входные пути и Выход Path

Создайте другие входные пути и выход path подобным образом. В данном примере используйте следующие опции.

Путь к скорости вращения (скалярный вход):

  • ImageInputLayer — Установите InputSize на 1,1 и нормализация к none.

  • FullyConnectedLayer — Установите OutputSize на 400.

  • ReLULayer

  • FullyConnectedLayer — Установите OutputSize на 300.

Путь к действию (скалярный вход):

  • ImageInputLayer — Установите InputSize на 1,1 и нормализация к none.

  • FullyConnectedLayer — Установите OutputSize на 300.

Выход path:

  • AdditionLayer — Соедините выход всех входных путей к входу этого слоя.

  • ReLULayer

  • FullyConnectedLayer — Установите OutputSize на 1 для функции скалярного значения.

Сеть экспорта от Deep Network Designer

Чтобы экспортировать сеть в рабочее пространство MATLAB, в Deep Network Designer, нажимают Export. Deep Network Designer экспортирует сеть как новую переменную, содержащую слоя сети. Можно создать представление критика с помощью этой переменной сети слоя.

В качестве альтернативы, чтобы сгенерировать эквивалентный код MATLAB для сети, нажмите Export> Generate Code.

Сгенерированный код следующие.

lgraph = layerGraph();
layers = [
    imageInputLayer([1 1 1],"Name","torque","Normalization","none")
    fullyConnectedLayer(300,"Name","torque_fc1")];
lgraph = addLayers(lgraph,layers);
layers = [
    imageInputLayer([1 1 1],"Name","angularRate","Normalization","none")
    fullyConnectedLayer(400,"Name","dtheta_fc1")
    reluLayer("Name","dtheta_relu1")
    fullyConnectedLayer(300,"Name","dtheta_fc2")];
lgraph = addLayers(lgraph,layers);
layers = [
    imageInputLayer([50 50 1],"Name","pendImage","Normalization","none")
    convolution2dLayer([10 10],2,"Name","img_conv1","Stride",[5 5])
    reluLayer("Name","img_relu")
    fullyConnectedLayer(400,"Name","theta_fc1")
    reluLayer("Name","theta_relu1")
    fullyConnectedLayer(300,"Name","theta_fc2")];
lgraph = addLayers(lgraph,layers);
layers = [
    additionLayer(3,"Name","addition")
    reluLayer("Name","relu")
    fullyConnectedLayer(1,"Name","stateValue")];
lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,"torque_fc1","addition/in3");
lgraph = connectLayers(lgraph,"theta_fc2","addition/in1");
lgraph = connectLayers(lgraph,"dtheta_fc2","addition/in2");

Просмотрите конфигурацию сети критика.

figure
plot(lgraph)

Задайте опции для представления критика с помощью rlRepresentationOptions (Reinforcement Learning Toolbox).

criticOpts = rlRepresentationOptions('LearnRate',1e-03,'GradientThreshold',1);

Создайте представление критика с помощью заданной глубокой нейронной сети lgraph и опции. Необходимо также задать информацию о действии и наблюдении для критика, которого вы получаете из интерфейса среды. Для получения дополнительной информации смотрите rlQValueRepresentation (Reinforcement Learning Toolbox).

critic = rlQValueRepresentation(lgraph,obsInfo,actInfo,...
    'Observation',{'pendImage','angularRate'},'Action',{'torque'},criticOpts);

Чтобы создать агента DQN, сначала задайте опции агента DQN с помощью rlDQNAgentOptions (Reinforcement Learning Toolbox).

agentOpts = rlDQNAgentOptions(...
    'UseDoubleDQN',false,...    
    'TargetUpdateMethod',"smoothing",...
    'TargetSmoothFactor',1e-3,... 
    'ExperienceBufferLength',1e6,... 
    'DiscountFactor',0.99,...
    'SampleTime',env.Ts,...
    'MiniBatchSize',64);
agentOpts.EpsilonGreedyExploration.EpsilonDecay = 1e-5;

Затем создайте агента DQN с помощью заданного представления критика и опций агента. Для получения дополнительной информации смотрите rlDQNAgent (Reinforcement Learning Toolbox).

agent = rlDQNAgent(critic,agentOpts);

Обучите агента

Чтобы обучить агента, сначала задайте опции обучения. В данном примере используйте следующие опции.

  • Запустите каждое обучение самое большее 5 000 эпизодов с каждым эпизодом, длящимся самое большее 500 временных шагов.

  • Отобразите прогресс обучения в диалоговом окне Episode Manager (установите Plots опция), и отключают отображение командной строки (установите Verbose опция к false).

  • Остановите обучение, когда агент получит среднее совокупное вознаграждение, больше, чем –1000 по длине окна по умолчанию пяти последовательных эпизодов. На данном этапе агент может быстро сбалансировать маятник в вертикальном положении с помощью минимального усилия по управлению.

Для получения дополнительной информации смотрите rlTrainingOptions (Reinforcement Learning Toolbox).

trainOpts = rlTrainingOptions(...
    'MaxEpisodes',5000,...
    'MaxStepsPerEpisode',500,...
    'Verbose',false,...
    'Plots','training-progress',...
    'StopTrainingCriteria','AverageReward',...
    'StopTrainingValue',-1000);

Можно визуализировать систему маятника во время обучения или симуляции при помощи plot функция.

plot(env)

Обучите агента с помощью train (Reinforcement Learning Toolbox) функция. Это - в вычислительном отношении интенсивный процесс, который занимает несколько часов, чтобы завершиться. Чтобы сэкономить время при выполнении этого примера, загрузите предварительно обученного агента установкой doTraining к false. Чтобы обучить агента самостоятельно, установите doTraining к true.

doTraining = false;

if doTraining
    % Train the agent.
    trainingStats = train(agent,env,trainOpts);
else
    % Load pretrained agent for the example.
    load('MATLABPendImageDQN.mat','agent');
end

Симулируйте агента DQN

Чтобы подтвердить производительность обученного агента, симулируйте его в среде маятника. Для получения дополнительной информации о симуляции агента смотрите rlSimulationOptions (Reinforcement Learning Toolbox) и sim (Reinforcement Learning Toolbox).

simOptions = rlSimulationOptions('MaxSteps',500);
experience = sim(env,agent,simOptions);

totalReward = sum(experience.Reward)
totalReward = -888.9802

Смотрите также

| (Reinforcement Learning Toolbox)

Похожие темы

Для просмотра документации необходимо авторизоваться на сайте