Основанное на модели обучение с подкреплением Используя пользовательский учебный цикл

В этом примере показано, как задать пользовательский учебный цикл для алгоритма основанного на модели обучения с подкреплением (MBRL). Можно использовать этот рабочий процесс, чтобы обучить политику MBRL с пользовательским алгоритмом настройки с помощью политики и представлений функции ценности от пакета Reinforcement Learning Toolbox™.

В этом примере вы используете модели перехода, чтобы сгенерировать больше событий в то время как обучение пользовательский агент DQN [2] в среде тележки с шестом. Алгоритм, используемый в этом примере, основан на основанном на модели алгоритме оптимизации политики (MBPO) [1]. Исходный алгоритм MBPO обучает ансамбль стохастических моделей и агента SAC в задачах с непрерывными действиями. В отличие от этого этот пример обучает ансамбль детерминированных моделей и агента DQN в задаче с дискретными действиями. Следующая фигура обобщает алгоритм, используемый в этом примере.

Агент генерирует действительные события путем взаимодействия со средой. Эти события используются, чтобы обучить набор моделей перехода, которые используются, чтобы сгенерировать дополнительные события. Алгоритм настройки затем использует обоих действительные и сгенерированные события обновить политику агента.

Создайте среду

В данном примере политика обучения с подкреплением обучена в дискретной среде тележки с шестом. Цель в этой среде состоит в том, чтобы сбалансировать полюс, прикладывая силы (действия) с тележкой. Создайте среду с помощью rlPredefinedEnv функция. Для повторяемости результатов зафиксируйте начальное значение генератора случайных чисел. Для получения дополнительной информации об этой среде смотрите Загрузку Предопределенные Среды Системы управления.

clear
clc
rngSeed = 1;
rng(rngSeed);
env = rlPredefinedEnv('CartPole-Discrete');

Извлеките спецификации наблюдений и спецификации действия от среды.

obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);

Получите количество наблюдений (numObservations) и действия (numActions).

numObservations = obsInfo.Dimension(1);
numActions = numel(actInfo.Elements); % number of discrete actions, -10 or 10
numContinuousActions = 1; % force

Конструкция критика

DQN является основанным на значении алгоритмом обучения с подкреплением, который оценивает обесцененное совокупное вознаграждение с помощью критика. В этом примере сеть критика содержит fullyConnectedLayer, и reluLayer слои.

qNetwork = [
    featureInputLayer(obsInfo.Dimension(1),'Normalization','none','Name','state')
    fullyConnectedLayer(24,'Name','CriticStateFC1')
    reluLayer('Name','CriticRelu1')
    fullyConnectedLayer(24, 'Name','CriticStateFC2')
    reluLayer('Name','CriticCommonRelu')
    fullyConnectedLayer(length(actInfo.Elements),'Name','output')];
qNetwork = dlnetwork(qNetwork);

Создайте представление критика с помощью заданной нейронной сети и опций. Для получения дополнительной информации смотрите rlQValueRepresentation.

criticOpts = rlRepresentationOptions('LearnRate',0.001,'GradientThreshold',1);
critic = rlQValueRepresentation(qNetwork,obsInfo,actInfo,'Observation',{'state'},criticOpts);

Создайте модели перехода

Основанное на модели обучение с подкреплением использует модели перехода среды. Модель обычно состоит из функции перехода, премиальной функции и функции конечного состояния.

  • Функция перехода предсказывает следующее наблюдение, учитывая текущее наблюдение и действие.

  • Премиальная функция предсказывает вознаграждение, учитывая текущее наблюдение, действие и следующее наблюдение.

  • Функция конечного состояния предсказывает конечное состояние, учитывая наблюдение.

Как показано в следующем рисунке этот пример использует три функции перехода в качестве ансамбля моделей перехода, чтобы сгенерировать выборки, не взаимодействуя со средой. Истинная премиальная функция и истинная функция конечного состояния даны в этом примере.

Задайте три нейронных сети для моделей перехода. Нейронная сеть предсказывает различие между следующим наблюдением и текущим наблюдением.

numModels = 3;
transitionNetwork1 = createTransitionNetwork(numObservations, numContinuousActions);
transitionNetwork2 = createTransitionNetwork(numObservations, numContinuousActions);
transitionNetwork3 = createTransitionNetwork(numObservations, numContinuousActions);
transitionNetworkVector = [transitionNetwork1, transitionNetwork2, transitionNetwork3];

Создайте буферы опыта

Создайте буфер опыта для хранения событий агента (наблюдение, действие, следующее наблюдение, вознаграждение и isDone).

myBuffer.bufferSize = 1e5;
myBuffer.bufferIndex = 0;
myBuffer.currentBufferLength = 0;
myBuffer.observation = zeros(numObservations,myBuffer.bufferSize);
myBuffer.nextObservation = zeros(numObservations,myBuffer.bufferSize);
myBuffer.action = zeros(numContinuousActions,1,myBuffer.bufferSize);
myBuffer.reward = zeros(1,myBuffer.bufferSize);
myBuffer.isDone = zeros(1,myBuffer.bufferSize);

Создайте буфер опыта модели для хранения событий, сгенерированных моделями.

myModelBuffer.bufferSize = 1e5;
myModelBuffer.bufferIndex = 0;
myModelBuffer.currentBufferLength = 0;
myModelBuffer.observation = zeros(numObservations,myModelBuffer.bufferSize);
myModelBuffer.nextObservation = zeros(numObservations,myModelBuffer.bufferSize);
myModelBuffer.action = zeros(numContinuousActions,myModelBuffer.bufferSize);
myModelBuffer.reward = zeros(1,myModelBuffer.bufferSize);
myModelBuffer.isDone = zeros(1,myModelBuffer.bufferSize);

Сконфигурируйте обучение

Сконфигурируйте обучение использовать следующие опции.

  • Максимальное количество эпизодов тренировки — 250

  • Максимальные шаги на эпизод тренировки — 500

  • Коэффициент дисконтирования — 0.99

  • Учебное условие завершения — Среднее вознаграждение через 10 эпизодов достигает значения 480

numEpisodes = 250;
maxStepsPerEpisode = 500;
discountFactor = 0.99;
aveWindowSize = 10;
trainingTerminationValue = 480;

Сконфигурируйте опции модели.

  • Обучите модели перехода только после того, как 2 000 выборок будут собраны.

  • Обучите модели с помощью всех событий в действительном буфере опыта в каждом эпизоде. Используйте мини-пакетный размер 256.

  • Модели генерируют траектории с длиной 4 в начале каждого эпизода.

  • Количеством сгенерированных траекторий является numGenerateSampleIteration x numModels x miniBatchSize = 20 x 3 x 256 = 15360.

  • Используйте те же эпсилон-жадные параметры в качестве агента DQN, за исключением минимального значения эпсилона.

  • Используйте минимальное значение эпсилона 0,1, который выше, чем значение, используемое для взаимодействия со средой. Выполнение так позволяет модели генерировать более разнообразные данные.

warmStartSamples = 2000;
numEpochs = 1;
miniBatchSize = 256;
horizonLength = 4;
epsilonMinModel = 0.1;
numGenerateSampleIteration = 20;
sampleGenerationOptions.horizonLength = horizonLength;
sampleGenerationOptions.numGenerateSampleIteration = numGenerateSampleIteration;
sampleGenerationOptions.miniBatchSize = miniBatchSize;
sampleGenerationOptions.numObservations = numObservations;
sampleGenerationOptions.epsilonMinModel = epsilonMinModel;

% optimizer options
velocity1 = [];
velocity2 = [];
velocity3 = [];
decay = 0.01;
momentum = 0.9;
learnRate = 0.0005;

Сконфигурируйте опции обучения DQN.

  • EUse эпсилон жадный алгоритм с начальным значением эпсилона равняется 1, минимальному значению 0,01 и уровню затухания 0,005.

  • Обновитесь цель объединяют каждые 4 шага в сеть.

  • Установите отношение действительных событий к сгенерированным событиям к 0.2:0.8 установкой RealRatio к 0,2. Установка RealRatio к 1,0 совпадает с DQN без моделей.

  • Сделайте 5 шагов градиента на каждом шаге среды.

epsilon = 1;
epsilonMin = 0.01;
epsilonDecay = 0.005;
targetUpdateFrequency = 4;
realRatio = 0.2; % Set to 1 to run a standard DQN
numGradientSteps = 5;

Создайте вектор для хранения совокупного вознаграждения за каждый эпизод тренировки.

episodeCumulativeRewardVector = [];

Создайте фигуру для учебной визуализации модели с помощью hBuildFigureModel функция помощника.

[trainingPlotModel, lineLossTrain1, lineLossTrain2, lineLossTrain3, axModel] = hBuildFigureModel();

Создайте фигуру для визуализации проверки допустимости модели с помощью hBuildFigureModelTest функция помощника.

[testPlotModel, lineLossTest1, axModelTest] = hBuildFigureModelTest();

Создайте фигуру для визуализации обучения агента DQN с помощью hBuildFigure функция помощника.

[trainingPlot,lineReward,lineAveReward, ax] = hBuildFigure;

Обучите агента

Обучите агента с помощью пользовательского учебного цикла. Учебный цикл использует следующий алгоритм. Для каждого эпизода:

  1. Обучите модели перехода.

  2. Сгенерируйте события с помощью моделей перехода и сохраните выборки в буфере опыта модели.

  3. Сгенерируйте действительный опыт. Для этого сгенерируйте действие с помощью политики, примените действие к среде и получите получившееся наблюдение, вознаграждение, и - сделанные значения.

  4. Создайте мини-пакет путем выборки событий и от буфера опыта и от буфера опыта модели.

  5. Вычислите цель Q значение.

  6. Вычислите градиент функции потерь относительно параметров представления критика.

  7. Обновите представление критика с помощью вычисленных градиентов.

  8. Обновите учебную визуализацию.

  9. Отключите обучение, если критик достаточно обучен.

targetCritic = critic;
modelTrainedAtleastOnce = false;
totalStepCt = 0;
start = tic;

set(trainingPlotModel,'Visible','on');
set(testPlotModel,'Visible','on');
set(trainingPlot,'Visible','on');

for episodeCt = 1:numEpisodes
    if myBuffer.currentBufferLength > miniBatchSize && ...
            totalStepCt > warmStartSamples
        if realRatio < 1.0
            %----------------------------------------------
            % 1. Train transition models.
            %----------------------------------------------
            % Training three transition models
            [transitionNetworkVector(1),loss1,velocity1] = ...
                trainTransitionModel(transitionNetworkVector(1),...
                myBuffer,velocity1,miniBatchSize,...
                numEpochs,momentum,learnRate);
            [transitionNetworkVector(2),loss2,velocity2] = ...
                trainTransitionModel(transitionNetworkVector(2),...
                myBuffer,velocity2,miniBatchSize,...
                numEpochs,momentum,learnRate);
            [transitionNetworkVector(3),loss3,velocity3] = ...
                trainTransitionModel(transitionNetworkVector(3),...
                myBuffer,velocity3,miniBatchSize,...
                numEpochs,momentum,learnRate);
            modelTrainedAtleastOnce = true;

            % Display the training progress
            d = duration(0,0,toc(start),'Format','hh:mm:ss');
            addpoints(lineLossTrain1,episodeCt,loss1)
            addpoints(lineLossTrain2,episodeCt,loss2)
            addpoints(lineLossTrain3,episodeCt,loss3)
            legend(axModel,'Model1','Model2','Model3');
            title(axModel,"Model Training Progress - Episode: "...
                + episodeCt + ", Elapsed: " + string(d))
            drawnow

            %----------------------------------------------
            % 2. Generate experience using models.
            %----------------------------------------------
            % Create numGenerateSampleIteration x horizonLength x numModels x miniBatchSize
            % ex) 20 x 4 x 3 x 256 = 61440 samples            
            myModelBuffer = generateSamples(myBuffer,myModelBuffer,...
                transitionNetworkVector,critic,actInfo,...
                epsilon,sampleGenerationOptions);
        end
    end

    %----------------------------------------------
    % Interact with environment and train agent.
    %----------------------------------------------
    % Reset the environment at the start of the episode
    observation = reset(env);
    episodeReward = zeros(maxStepsPerEpisode,1);
    errorPreddiction = zeros(maxStepsPerEpisode,1);

    for stepCt = 1:maxStepsPerEpisode
        %----------------------------------------------
        % 3. Generate an experience.
        %----------------------------------------------
        totalStepCt = totalStepCt + 1;

        % Compute an action using the policy based on the current observation.
        if rand() < epsilon
            action = actInfo.usample;
            action = action{1};
        else
            action = getAction(critic,{observation});
        end
        % Udpate epsilon
        if totalStepCt > warmStartSamples
            epsilon = max(epsilon*(1-epsilonDecay),epsilonMin);
        end

        % Apply the action to the environment and obtain the resulting
        % observation and reward.
        [nextObservation,reward,isDone] = step(env,action);

        % Check prediction
        dx = predict(transitionNetworkVector(1),...
            dlarray(observation,'CB'),dlarray(action,'CB'));
        predictedNextObservation = observation + dx;
        errorPreddiction(stepCt) = ...
            sqrt(sum((nextObservation - predictedNextObservation).^2));

        % Store the action, observation, reward and is-done experience  
        myBuffer = storeExperience(myBuffer,observation,action,...
            nextObservation,reward,isDone);

        episodeReward(stepCt) = reward;
        observation = nextObservation;

        % Train DQN agent
        for gradientCt = 1:numGradientSteps
            if myBuffer.currentBufferLength >= miniBatchSize && ...
                    totalStepCt>warmStartSamples
               %-----------------------------------------------------
               % 4. Sample minibatch from experience buffers.
               %-----------------------------------------------------
               [sampledObservation,sampledAction,sampledNextObservation,sampledReward,sampledIsdone] = ...
                   sampleMinibatch(modelTrainedAtleastOnce,realRatio,miniBatchSize,myBuffer,myModelBuffer); 

               %-----------------------------------------------------
               % 5. Compute target Q value.
               %-----------------------------------------------------
                % Compute target Q value
                [targetQValues, MaxActionIndices] = getMaxQValue(targetCritic, ...
                    {reshape(sampledNextObservation,[numObservations,1,miniBatchSize])});

                % Compute target for nonterminal states
                targetQValues(~logical(sampledIsdone)) = sampledReward(~logical(sampledIsdone)) + ...
                    discountFactor.*targetQValues(~logical(sampledIsdone));
                % Compute target for terminal states
                targetQValues(logical(sampledIsdone)) = sampledReward(logical(sampledIsdone));

                lossData.batchSize = miniBatchSize;
                lossData.actInfo = actInfo;
                lossData.actionBatch = sampledAction;
                lossData.targetQValues = targetQValues;

               %-----------------------------------------------------
               % 6. Compute gradients.
               %-----------------------------------------------------
                criticGradient = gradient(critic,@criticLossFunction, ...
                    {reshape(sampledObservation,[numObservations,1,miniBatchSize])},lossData);

                %----------------------------------------------------
                % 7. Update the critic network using gradients.
                %----------------------------------------------------
                critic = optimize(critic,criticGradient);
            end
        end
        % Update target critic periodically
        if mod(totalStepCt, targetUpdateFrequency)==0
            targetCritic = critic;
        end

        % Stop if a terminal condition is reached.
        if isDone
            break;
        end
    end % End of episode

    %----------------------------------------------------------------
    % 8. Update the training visualization.
    %----------------------------------------------------------------
    episodeCumulativeReward = sum(episodeReward);
    episodeCumulativeRewardVector = cat(2,...
        episodeCumulativeRewardVector,episodeCumulativeReward);
    movingAveReward = movmean(episodeCumulativeRewardVector,...
        aveWindowSize,2);
    addpoints(lineReward,episodeCt,episodeCumulativeReward);
    addpoints(lineAveReward,episodeCt,movingAveReward(end));
    title(ax, "Training Progress - Episode: " + episodeCt + ...
        ", Total Step: " + string(totalStepCt) + ...
        ", epsilon:" + string(epsilon))
    drawnow;

    errorPreddiction = errorPreddiction(1:stepCt);

    % Display one step prediction error.
    addpoints(lineLossTest1,episodeCt,mean(errorPreddiction))
    legend(axModelTest,'Model1');
    title(axModelTest, ...
        "Model one-step prediction error - Episode: " + episodeCt + ...
        ", Error: " + string(mean(errorPreddiction)))
    drawnow

    % Display training progress every 10th episode
    if (mod(episodeCt,10) == 0)    
        fprintf("EP:%d, Reward:%.4f, AveReward:%.4f, Steps:%d, TotalSteps:%d, epsilon:%f, error model:%f\n",...
            episodeCt,episodeCumulativeReward,movingAveReward(end),...
            stepCt,totalStepCt,epsilon,mean(errorPreddiction))
    end

    %-----------------------------------------------------------------------------
    % 9. Terminate training if the network is sufficiently trained.
    %-----------------------------------------------------------------------------
    if max(movingAveReward) > trainingTerminationValue
        break
    end
end
EP:10, Reward:12.0000, AveReward:13.3333, Steps:18, TotalSteps:261, epsilon:1.000000, error model:3.786379
EP:20, Reward:11.0000, AveReward:20.1667, Steps:17, TotalSteps:493, epsilon:1.000000, error model:3.768267
EP:30, Reward:34.0000, AveReward:19.3333, Steps:40, TotalSteps:769, epsilon:1.000000, error model:3.763075
EP:40, Reward:20.0000, AveReward:13.8333, Steps:26, TotalSteps:960, epsilon:1.000000, error model:3.797021
EP:50, Reward:13.0000, AveReward:22.5000, Steps:19, TotalSteps:1192, epsilon:1.000000, error model:3.813097
EP:60, Reward:32.0000, AveReward:14.8333, Steps:38, TotalSteps:1399, epsilon:1.000000, error model:3.821042
EP:70, Reward:12.0000, AveReward:17.6667, Steps:18, TotalSteps:1630, epsilon:1.000000, error model:3.741603
EP:80, Reward:17.0000, AveReward:16.5000, Steps:23, TotalSteps:1873, epsilon:1.000000, error model:3.780144
EP:90, Reward:13.0000, AveReward:11.8333, Steps:19, TotalSteps:2113, epsilon:0.567555, error model:0.222689
EP:100, Reward:168.0000, AveReward:209.3333, Steps:174, TotalSteps:3546, epsilon:0.010000, error model:0.266022
EP:110, Reward:200.0000, AveReward:350.1667, Steps:206, TotalSteps:6892, epsilon:0.010000, error model:0.074918
EP:120, Reward:208.0000, AveReward:317.1667, Steps:214, TotalSteps:10309, epsilon:0.010000, error model:0.057781
EP:130, Reward:500.0000, AveReward:275.8333, Steps:500, TotalSteps:13517, epsilon:0.010000, error model:0.032713

Figure Cart Pole Custom Training (Models) contains an axes object. The axes object with title Model Training Progress - Episode: 134, Elapsed: 00:35:58 contains 3 objects of type animatedline. These objects represent Model1, Model2, Model3.

Figure Cart Pole Custom Training (DQN agent) contains an axes object. The axes object with title Training Progress - Episode: 134, Total Step: 15516, epsilon:0.01 contains 2 objects of type animatedline. These objects represent Cumulative Reward, Average Reward.

Figure Cart Pole Custom Training (Models) contains an axes object. The axes object with title Model one-step prediction error - Episode: 134, Error: 0.030627 contains an object of type animatedline. This object represents Model1.

Симулируйте агента

Чтобы симулировать обученного агента, сначала сбросьте среду.

obs0 = reset(env);
obs = obs0;

Включите визуализацию среды, которая обновляется каждый раз, когда ступенчатая функция среды называется.

plot(env)

Для каждого шага симуляции выполните следующие действия.

  1. Получите действие путем выборки от политики с помощью getAction функция.

  2. Шаг среда с помощью полученного значения действия.

  3. Оконечный, если терминальное условие достигнуто.

actionVector = zeros(1,maxStepsPerEpisode);
obsVector = zeros(numObservations,maxStepsPerEpisode+1);
obsVector(:,1) = obs0;
for stepCt = 1:maxStepsPerEpisode
    
    % Select action according to trained policy.
    action = getAction(critic,{obs});
        
    % Step the environment.
    [nextObs,reward,isDone] = step(env,action);    

    obsVector(:,stepCt+1) = nextObs;
    actionVector(1,stepCt) = action;

    % Check for terminal condition.
    if isDone
        break
    end
    
    obs = nextObs;    
end

Figure Cart Pole Visualizer contains an axes object. The axes object contains 6 objects of type line, polygon.

lastStepCt = stepCt;

Тестовая модель

Протестируйте одну из моделей путем предсказания следующего наблюдения, учитывая текущее наблюдение и действие.

modelID = 3;
predictedObsVector = zeros(numObservations,lastStepCt);
obs = dlarray(obsVector(:,1),'CB');
predictedObsVector(:,1) = obs;
for stepCt = 1:lastStepCt
    obs = dlarray(obsVector(:,stepCt),'CB');
    action = dlarray(actionVector(1,stepCt),'CB');
    
    dx = predict(transitionNetworkVector(modelID),obs, action);
    predictedObs = obs + dx;
    predictedObsVector(:,stepCt+1) = predictedObs;    
end
predictedObsVector = predictedObsVector(:, 1:lastStepCt);
figure(5)
layOut = tiledlayout(4,1, 'TileSpacing', 'compact');
for i = 1:4
    nexttile;
    errorPrediction = abs(predictedObsVector(i,1:lastStepCt) - obsVector(i,1:lastStepCt));
    line1 = plot(errorPrediction,'DisplayName', 'Absolute Error');
    title("observation "+num2str(i));
end
title(layOut,"Prediction Absolute Error")

Figure contains 4 axes objects. Axes object 1 with title observation 1 contains an object of type line. This object represents Absolute Error. Axes object 2 with title observation 2 contains an object of type line. This object represents Absolute Error. Axes object 3 with title observation 3 contains an object of type line. This object represents Absolute Error. Axes object 4 with title observation 4 contains an object of type line. This object represents Absolute Error.

Небольшая абсолютная ошибка предсказания показывает, что модель успешно обучена предсказать следующее наблюдение.

Ссылки

[1] Владимир Мин, Koray Kavukcuoglu, Дэвид Сильвер, Алекс Грэйвс, Яннис Антоноглоу, Даан Вирстра и Мартин Ридмиллер. “Проигрывая Atari с Глубоким Обучением с подкреплением”. ArXiv:1312.5602 [Cs]. 19 декабря 2013. https://arxiv.org/abs/1312.5602.

[2] Janner, Майкл, Джастин Фу, Марвин Чжан и Сергей Левин. "Когда доверять вашей модели: основанная на модели оптимизация политики". ArXiv:1907.08253 [Cs, Статистика], 5 ноября 2019. https://arxiv.org/abs/1906.08253.

Похожие темы