Reinforcement Learning Notes (3) ----- Value Iteration Algorithm

Reinforcement learning has two common iterative training algorithms: policy iteration and value iteration. The previous post, <<Reinforcement Learning Notes (2)>>, described policy iteration in detail. Value iteration shares the same basic idea; the key difference is that policy iteration improves the policy only after policy evaluation has converged, so every improvement step works with a stable value function for the current policy. Value iteration, by contrast, interleaves evaluation and improvement: it does not wait for the value function to stabilize before improving the policy, which makes the process more dynamic.
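In other words, each sweep of value iteration backs up every state with the best one-step return directly, instead of running a full evaluation phase first. A minimal sketch, assuming a deterministic transition function P(s, a) and reward function R(s, a) (these names are chosen here for illustration and are not part of the gym example below):

def value_iteration_sweep(V, states, actions, P, R, gamma):
    """One sweep of value iteration over all states (deterministic dynamics assumed)."""
    delta = 0.0
    for s in states:
        # Back up with the greedy action immediately -- evaluation and improvement in one step.
        best = max(R(s, a) + gamma * V[P(s, a)] for a in actions(s))
        delta = max(delta, abs(best - V[s]))
        V[s] = best
    return delta  # sweep repeatedly until delta drops below a small threshold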

Sample code with gym

Programming the problem below is a good way to deepen the understanding of value iteration. The problem setup is the same as in <<Reinforcement Learning Notes (2)>>.

(Figure: the 6x6 grid world with traps and the gold, as in the previous post)

The states in the figure are numbered 1 – 36, from bottom to top and from left to right.
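With this numbering, a state id and its (row, col) position convert to each other with simple integer arithmetic, which is exactly what translateStateToRowCol / translateRowColToState implement in the class below. For example:

state = 34                                     # the gold cell
row, col = (state - 1) // 6, (state - 1) % 6   # -> (5, 3)
assert row * 6 + col + 1 == state              # the inverse mapping recovers the state id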

import numpy as np
import random
from gym import spaces
import gym
from gym.envs.classic_control import rendering

# Simulated environment class
class GridWorldEnv(gym.Env):
    # global rendering configuration
    metadata = {
        'render.modes':['human', 'rgb_array'],
        'video.frames_per_second': 2
    }

    def __init__(self):
        self.states = [i for i in range(1, 37)] # state space
        self.terminate_states = [3, 7, 11, 15, 19, 20, 23, 30, 33, 34] # terminal states (traps and the gold)
        self.actions = ['up', 'down', 'left', 'right'] # action space

        self.v_states = dict() # value function for each state
        for state in self.states:
            self.v_states[state] = 0.0

        for state in self.terminate_states: # initialize the value of every trap (and, for now, the gold) to -1.0
            self.v_states[state] = -1.0

        self.v_states[34] = 1.0  # the gold's value is initialized to 1.0

        self.initStateAction() # initialize the feasible action set of each state
        self.initStatePolicyAction() # randomly initialize the current policy

        self.gamma = 0.8 # discount factor used when computing the value function
        self.viewer = None # viewer object
        self.current_state = None # current state
        return

    def translateStateToRowCol(self, state):
        """
        將狀態轉化爲行列座標返回
        """
        row = (state - 1) // 6
        col = (state - 1) %  6
        return row, col

    def translateRowColToState(self, row, col):
        """
        將行列座標轉化爲狀態值
        """
        return row * 6 + col + 1

    def actionRowCol(self, row, col, action):
        """
        對行列座標執行動作action並返回座標
        """
        if action == "up":
            row = row - 1
        if action == "down":
            row = row + 1
        if action == "left":
            col = col - 1
        if action == "right":
            col = col + 1
        return row, col

    def canUp(self, row, col):
        row = row - 1
        return 0 <= row <= 5

    def canDown(self, row, col):
        row = row + 1
        return 0 <= row <= 5

    def canLeft(self, row, col):
        col = col - 1
        return 0 <= col <= 5

    def canRight(self, row, col):
        col = col + 1
        return 0 <= col <= 5

    def initStateAction(self):
        """
        初始化每個狀態可行動作空間
        """
        self.states_actions = dict()
        for state in self.states:
            self.states_actions[state] = []
            if state in self.terminate_states:
                continue
            row, col = self.translateStateToRowCol(state)
            if self.canUp(row, col):
                self.states_actions[state].append("up")
            if self.canDown(row, col):
                self.states_actions[state].append("down")
            if self.canLeft(row, col):
                self.states_actions[state].append('left')
            if self.canRight(row, col):
                self.states_actions[state].append('right')
        return


    def initStatePolicyAction(self):
        """
        初始化每個狀態的當前策略動作
        """
        self.states_policy_action = dict()
        for state in self.states:
            if state in self.terminate_states:
                self.states_policy_action[state] = None
            else:
                self.states_policy_action[state] = random.sample(self.states_actions[state], 1)[0]
        return


    def seed(self, seed = None):
        random.seed(seed)
        return [seed]

    def reset(self):
        """
        重置原始狀態
        """
        self.current_state = random.sample(self.states, 1)[0]

    def step(self, action):
        """
        動作迭代函數
        """
        cur_state = self.current_state
        if cur_state in self.terminate_states:
            return cur_state, 0, True, {}
        row, col = self.translateStateToRowCol(cur_state)
        n_row, n_col = self.actionRowCol(row, col, action)
        next_state = self.translateRowColToState(n_row, n_col)
        self.current_state = next_state
        if next_state in self.terminate_states:
            return next_state, 0, True, {}
        else:
            return next_state, 0, False, {}

    def policy_evaluate_and_improve(self):
        """
        遍歷狀態空間,對策略進行評估和改善
        """
        error = 0.0 #迭代的值函數誤差
        for state in self.states:
            if state in self.terminate_states:
                continue
            action = self.states_policy_action[state]
            self.current_state = state
            next_state, reward, isTerminate, info = self.step(action)
            new_value = reward + self.gamma * self.v_states[next_state]
            new_action = action

            # try every feasible action and keep the greedy one-step backup
            for _action in self.states_actions[state]:
                self.current_state = state
                next_state, reward, isTerminate, info = self.step(_action)
                if new_value < reward + self.gamma * self.v_states[next_state]:
                    new_value = reward + self.gamma * self.v_states[next_state]
                    new_action = _action
            error = max(error, abs(new_value - self.v_states[state]))
            self.v_states[state] = new_value
            self.states_policy_action[state] = new_action
        return error


    def createGrids(self):
        """
        創建網格
        """
        start_x = 40
        start_y = 40
        line_length = 40
        for state in self.states:
            row, col = self.translateStateToRowCol(state)
            x = start_x + col * line_length
            y = start_y + row * line_length
            line = rendering.Line((x, y), (x + line_length, y))
            line.set_color(0, 0, 0)
            self.viewer.add_onetime(line)
            line = rendering.Line((x, y), (x, y  + line_length))
            line.set_color(0, 0, 0)
            self.viewer.add_onetime(line)
            line = rendering.Line((x + line_length, y), (x + line_length, y  + line_length))
            line.set_color(0, 0, 0)
            self.viewer.add_onetime(line)
            line = rendering.Line((x, y + line_length), (x + line_length, y  + line_length))
            line.set_color(0, 0, 0)
            self.viewer.add_onetime(line)

    def createTraps(self):
        """
        創建陷阱,將黃金的位置也先繪製成陷阱,後面覆蓋畫成黃金
        """
        start_x = 40 
        start_y = 40
        line_length = 40
        for state in self.terminate_states:
            row, col = self.translateStateToRowCol(state)
            trap = rendering.make_circle(20)
            trans = rendering.Transform()
            trap.add_attr(trans)
            trap.set_color(0, 0, 0)
            trans.set_translation(start_x + line_length * col + 20, start_y + line_length * row + 20)
            self.viewer.add_onetime(trap)

    def createGold(self):
        """
        創建黃金
        """
        start_x = 40 
        start_y = 40
        line_length = 40
        state = 34
        row, col = self.translateStateToRowCol(state)
        gold = rendering.make_circle(20)
        trans = rendering.Transform()
        gold.add_attr(trans)
        gold.set_color(1, 0.9, 0)
        trans.set_translation(start_x + line_length * col + 20, start_y + line_length * row + 20)
        self.viewer.add_onetime(gold)

    def createRobot(self):
        """
        創建機器人
        """
        start_x = 40 
        start_y = 40
        line_length = 40
        row, col = self.translateStateToRowCol(self.current_state)
        robot = rendering.make_circle(15)
        trans = rendering.Transform()
        robot.add_attr(trans)
        robot.set_color(1, 0, 1)
        trans.set_translation(start_x + line_length * col + 20, start_y + line_length * row + 20)
        self.viewer.add_onetime(robot)

    def render(self, mode="human", close=False):
        """
        渲染整個場景
        """
        #關閉視圖
        if close:
            if self.viewer is not None:
                self.viewer.close()
                self.viewer = None

        # viewer size
        screen_width = 320
        screen_height = 320


        if self.viewer is None:
            self.viewer = rendering.Viewer(screen_width, screen_height)

        # draw the grid
        self.createGrids()
        # draw the traps
        self.createTraps()
        # draw the gold
        self.createGold()
        # draw the robot
        self.createRobot()
        return self.viewer.render(return_rgb_array= mode == 'rgb_array')
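Before registering the class with gym, it can also be instantiated and stepped directly as a quick sanity check of the dynamics; a small sketch (not part of the original post):

env_check = GridWorldEnv()
env_check.seed(0)
env_check.reset()
print("start state:", env_check.current_state)
action = env_check.states_policy_action[env_check.current_state]
if action is not None:                         # terminal states have no policy action
    next_state, reward, done, info = env_check.step(action)
    print("action:", action, "-> next state:", next_state, "done:", done)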
Register the environment class with gym
from gym.envs.registration import register
try:
    register(id = "GridWorld-v4", entry_point=GridWorldEnv, max_episode_steps = 200, reward_threshold=100.0)
except:
    pass
Code for running the value iteration and the simulation animation
from time import sleep
env = gym.make('GridWorld-v4')
env.reset()

# interleaved policy evaluation and improvement (value iteration)
for _ in range(10000):
    error = env.env.policy_evaluate_and_improve()
    if error < 0.00001:
        break
# print the number of sweeps performed
print("iter count:" + str(_))


# Print some information to see what env actually is (gym wraps our GridWorldEnv).
print(isinstance(env, GridWorldEnv))
print(type(env))
print(env.__dict__)
print(isinstance(env.env, GridWorldEnv))

env.reset()

for _ in range(1000):
    env.render()
    if env.env.states_policy_action[env.env.current_state] is not None:
        observation,reward,done,info = env.step(env.env.states_policy_action[env.env.current_state])
    else:
        done = True
    print(_)
    if done:
        sleep(0.5)
        env.render()
        env.reset()
        print("reset")
    sleep(0.5)
env.close()
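To inspect the converged result without watching the animation, the value table and the greedy policy stored on the environment can be printed row by row; a small helper built only on the attributes defined above (not in the original post):

def print_grid(env):
    """Print the converged values and policy, top grid row first."""
    inner = env.env                      # unwrap to the GridWorldEnv instance
    for row in range(5, -1, -1):         # row 5 is the top row of the rendered grid
        values, policies = [], []
        for col in range(6):
            state = inner.translateRowColToState(row, col)
            values.append("%6.2f" % inner.v_states[state])
            policies.append("%5s" % (inner.states_policy_action[state] or "-"))
        print(" ".join(values), "  |  ", " ".join(policies))

print_grid(env)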
Animation result

(Animation: the robot follows the learned policy through the grid.)
