Reinforcement Learning Notes (4): Sarsa, a Temporal-Difference Reinforcement Learning Method

There are many ways to train a reinforcement learning agent. In this note we use the temporal-difference method Sarsa to solve a simple maze problem.

The maze map is briefly described below.

[Figure: the maze map]
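
Reconstructing the map from the environment code below (a 5×5 grid of states 1 to 25, traps at states 3, 4, 5, 11, 12, 19 and 24, and the gold/exit at state 15), the layout is roughly as follows, with T marking a trap and G the gold:

     1     2     3(T)  4(T)  5(T)
     6     7     8     9    10
    11(T) 12(T) 13    14    15(G)
    16    17    18    19(T) 20
    21    22    23    24(T) 25
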
The on-policy Sarsa update rule for the action-value function is:

Q(s, a) ← Q(s, a) + α [ r + γ Q(s', a') − Q(s, a) ]
A brief explanation: while simulating, at state s we choose and execute action a and arrive at state s', then use Q(s', a') at state s' to update Q(s, a). Because this is a sampled simulation, we cannot simply set Q(s, a) = r + γQ(s', a'). Instead, r + γQ(s', a') − Q(s, a) measures the gap between the current estimate Q(s, a) and the freshly sampled target r + γQ(s', a'), and a fixed fraction of that gap is added back onto Q(s, a). You can think of that fraction, α, as the usual learning rate.
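
As a quick worked example, here is a minimal sketch of a single Sarsa update in isolation. The table q, the alpha/gamma values and the transition used here are made up purely for illustration and are not part of the environment class below.

from collections import defaultdict

alpha = 0.1   # learning rate
gamma = 0.8   # discount factor

q = defaultdict(float)   # Q(s, a) table, every entry starts at 0.0
q[(2, 'right')] = 0.5    # pretend the next state-action pair already has some value

def sarsa_update(s, a, r, s_next, a_next):
    # TD error: sampled target minus the current estimate
    td_error = r + gamma * q[(s_next, a_next)] - q[(s, a)]
    # move the current estimate a fraction alpha toward the target
    q[(s, a)] += alpha * td_error

# one simulated transition: take 'right' in state 1, get reward 0,
# land in state 2 where the next chosen action is again 'right'
sarsa_update(1, 'right', 0, 2, 'right')
print(q[(1, 'right')])   # 0.1 * (0 + 0.8 * 0.5 - 0.0) ≈ 0.04
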

Code

import numpy as np
import random
from gym import spaces
import gym
from gym.envs.classic_control import rendering

# Simulated grid-world environment
class GridWorldEnv(gym.Env):
    # global rendering configuration
    metadata = {
        'render.modes':['human', 'rgb_array'],
        'video.frames_per_second': 2
    }

    def __init__(self):
        self.states = [i for i in range(1, 26)] # state space: cells 1..25
        self.terminate_states = [3, 4, 5, 11, 12, 19, 24, 15] # terminal states (traps and the gold)
        self.actions = ['up', 'down', 'left', 'right'] # action space

        self.value_of_state = dict() # state-value table
        for state in self.states:
            self.value_of_state[state] = 0.0

        for state in self.terminate_states: # initialize every trap's value to -1.0
            self.value_of_state[state] = -1.0

        self.value_of_state[15] = 1.0  # the gold's cell gets value 1.0

        self.initStateAction() # initialize each state's feasible action set
        self.initStatePolicyAction() # randomly initialize the current policy
        self.initQ_s_a() # initialize the action-value function

        self.gamma = 0.8 # discount factor for the value updates
        self.alpha = 0.1 # learning rate
        self.viewer = None # rendering viewer
        self.current_state = None # current state
        return

    def translateStateToRowCol(self, state):
        """
        Convert a state id into (row, col) coordinates.
        """
        row = (state - 1) // 5
        col = (state - 1) %  5
        return row, col

    def translateRowColToState(self, row, col):
        """
        Convert (row, col) coordinates back into a state id.
        """
        return row * 5 + col + 1

    def actionRowCol(self, row, col, action):
        """
        Apply action to (row, col) and return the resulting coordinates.
        """
        if action == "up":
            row = row - 1
        if action == "down":
            row = row + 1
        if action == "left":
            col = col - 1
        if action == "right":
            col = col + 1
        return row, col

    def canUp(self, row, col):
        row = row - 1
        return 0 <= row <= 4

    def canDown(self, row, col):
        row = row + 1
        return 0 <= row <= 4

    def canLeft(self, row, col):
        col = col - 1
        return 0 <= col <= 4

    def canRight(self, row, col):
        col = col + 1
        return 0 <= col <= 4

    def initStateAction(self):
        """
        Initialize the set of feasible actions for each non-terminal state.
        """
        self.states_actions = dict()
        for state in self.states:
            self.states_actions[state] = []
            if state in self.terminate_states:
                continue
            row, col = self.translateStateToRowCol(state)
            if self.canUp(row, col):
                self.states_actions[state].append("up")
            if self.canDown(row, col):
                self.states_actions[state].append("down")
            if self.canLeft(row, col):
                self.states_actions[state].append('left')
            if self.canRight(row, col):
                self.states_actions[state].append('right')
        return

    def initQ_s_a(self):
        """
        Initialize the action-value function Q(s, a).
        """
        self.Q_s_a = dict()
        for state in self.states:
            if state in self.terminate_states:
                continue
            for action in self.states_actions[state]:
                self.Q_s_a["%d_%s" % (state, action)] = 0.0 # initialize every action value to 0.0


    def epsilon_greedy(self, state, epsilon):
        """
        At state s, sample the next action according to an ε-greedy distribution over Q(s, ·).
        """
        action_size = len(self.states_actions[state])
        max_value_action = self.states_actions[state][0]
        for action in self.states_actions[state]:
            if self.Q_s_a["%d_%s" % (state, action)] > self.Q_s_a["%d_%s" % (state, max_value_action)]:
                max_value_action = action
        prob_list = [0.0 for _ in range(0, action_size)]
        for i in range(0, action_size):
            if self.states_actions[state][i] == max_value_action:
                prob_list[i] = 1 - epsilon + epsilon / action_size
            else:
                prob_list[i] = epsilon / action_size
        r = random.random()
        s = 0.0
        for i in range(0, action_size):
            s += prob_list[i]
            if s >= r:
                return self.states_actions[state][i]
        return self.states_actions[state][-1]

    def greedy_action(self, state):
        """
        Return the greedy (highest-valued) action at the given state.
        """
        action_size = len(self.states_actions[state])
        max_value_action = self.states_actions[state][0]
        for action in self.states_actions[state]:
            if self.Q_s_a["%d_%s" % (state, action)] > self.Q_s_a["%d_%s" % (state, max_value_action)]:
                max_value_action = action
        return max_value_action


    def initStatePolicyAction(self):
        """
        Randomly initialize the current policy's action for each state.
        """
        self.states_policy_action = dict()
        for state in self.states:
            if state in self.terminate_states:
                self.states_policy_action[state] = None
            else:
                self.states_policy_action[state] = random.sample(self.states_actions[state], 1)[0]
        return


    def seed(self, seed = None):
        random.seed(seed)
        return [seed]

    def reset(self):
        """
        Reset the environment to a random initial state.
        """
        self.current_state = random.sample(self.states, 1)[0]

    def step(self, action):
        """
        Execute one environment step with the given action.
        """
        cur_state = self.current_state
        if cur_state in self.terminate_states:
            return cur_state, 0, True, {}
        row, col = self.translateStateToRowCol(cur_state)
        n_row, n_col = self.actionRowCol(row, col, action)
        next_state = self.translateRowColToState(n_row, n_col)
        self.current_state = next_state
        if next_state in self.terminate_states:
            return next_state, 0, True, {}
        else:
            return next_state, 0, False, {}

    def policy_evaluate_sarsa(self):
        """
        Sweep the state-action space and update Q(s, a) with the Sarsa rule (policy evaluation).
        """
        for state in self.states:
            if state in self.terminate_states:
                continue
            for action in self.states_actions[state]:
                self.current_state = state
                next_state, reward, isTerminate, info = self.step(action)
                if isTerminate is True:
                    s_a = "%d_%s" % (state, action)
                    self.Q_s_a[s_a] = self.Q_s_a[s_a] + self.alpha * (reward + self.gamma * self.value_of_state[next_state] - self.Q_s_a[s_a])
                else:
                    s_a = "%d_%s" % (state, action)
                    n_action = self.epsilon_greedy(next_state, 0.3)
                    n_s_a = "%d_%s" % (next_state, n_action)
                    self.Q_s_a[s_a] = self.Q_s_a[s_a] + self.alpha * (reward + self.gamma * self.Q_s_a[n_s_a] - self.Q_s_a[s_a])
        return

    def policy_improve_sarsa(self):
        """
        Policy improvement: make the policy greedy with respect to Q.
        """
        for state in self.states:
            if state in self.terminate_states:
                continue
            self.states_policy_action[state] =  self.greedy_action(state)
        return    

    def createGrids(self):
        """
        Draw the grid lines.
        """
        start_x = 40
        start_y = 40
        line_length = 40
        for state in self.states:
            row, col = self.translateStateToRowCol(state)
            x = start_x + col * line_length
            y = start_y + row * line_length
            line = rendering.Line((x, y), (x + line_length, y))
            line.set_color(0, 0, 0)
            self.viewer.add_onetime(line)
            line = rendering.Line((x, y), (x, y  + line_length))
            line.set_color(0, 0, 0)
            self.viewer.add_onetime(line)
            line = rendering.Line((x + line_length, y), (x + line_length, y  + line_length))
            line.set_color(0, 0, 0)
            self.viewer.add_onetime(line)
            line = rendering.Line((x, y + line_length), (x + line_length, y  + line_length))
            line.set_color(0, 0, 0)
            self.viewer.add_onetime(line)

    def createTraps(self):
        """
        Draw the traps; the gold's cell is drawn as a trap first and painted over as gold later.
        """
        start_x = 40 
        start_y = 40
        line_length = 40
        for state in self.terminate_states:
            row, col = self.translateStateToRowCol(state)
            trap = rendering.make_circle(20)
            trans = rendering.Transform()
            trap.add_attr(trans)
            trap.set_color(0, 0, 0)
            trans.set_translation(start_x + line_length * col + 20, start_y + line_length * row + 20)
            self.viewer.add_onetime(trap)

    def createGold(self):
        """
        Draw the gold, which in this problem is the exit.
        """
        start_x = 40 
        start_y = 40
        line_length = 40
        state = 15
        row, col = self.translateStateToRowCol(state)
        gold = rendering.make_circle(20)
        trans = rendering.Transform()
        gold.add_attr(trans)
        gold.set_color(1, 0.9, 0)
        trans.set_translation(start_x + line_length * col + 20, start_y + line_length * row + 20)
        self.viewer.add_onetime(gold)

    def createRobot(self):
        """
        Draw the robot.
        """
        start_x = 40 
        start_y = 40
        line_length = 40
        row, col = self.translateStateToRowCol(self.current_state)
        robot = rendering.make_circle(15)
        trans = rendering.Transform()
        robot.add_attr(trans)
        robot.set_color(1, 0, 1)
        trans.set_translation(start_x + line_length * col + 20, start_y + line_length * row + 20)
        self.viewer.add_onetime(robot)

    def render(self, mode="human", close=False):
        """
        Render the whole scene.
        """
        # close the viewer
        if close:
            if self.viewer is not None:
                self.viewer.close()
                self.viewer = None
            return None

        # viewer dimensions
        screen_width = 280
        screen_height = 280


        if self.viewer is None:
            self.viewer = rendering.Viewer(screen_width, screen_height)

        # draw the grid
        self.createGrids()
        # draw the traps
        self.createTraps()
        # draw the gold
        self.createGold()
        # draw the robot
        self.createRobot()
        return self.viewer.render(return_rgb_array= mode == 'rgb_array')

Register the environment class with gym

from gym.envs.registration import register
try:
    register(id = "GridWorld-v5", entry_point=GridWorldEnv, max_episode_steps = 200, reward_threshold=100.0)
except:
    pass

Animated simulation

from time import sleep
env = gym.make('GridWorld-v5')
env.reset()

# policy evaluation followed by policy improvement
for _ in range(100000):
    env.env.policy_evaluate_sarsa()
env.env.policy_improve_sarsa()



# print some information to inspect what env actually is
print(isinstance(env, GridWorldEnv))
print(type(env))
print(env.__dict__)
print(isinstance(env.env, GridWorldEnv))

env.reset()

for _ in range(1000):
    env.render()
    if env.env.states_policy_action[env.env.current_state] is not None:
        observation,reward,done,info = env.step(env.env.states_policy_action[env.env.current_state])
    else:
        done = True
    print(_)
    if done:
        sleep(0.5)
        env.render()
        env.reset()
        print("reset")
    sleep(0.5)
env.close()