There are many ways to train a reinforcement-learning agent; in this section we use Sarsa, a temporal-difference method, to solve a simple maze problem.
The maze map can be described briefly as follows: a 5 x 5 grid whose cells are numbered 1 to 25 row by row (the layout below is reconstructed from the state definitions in the code):

 1   2   3   4   5
 6   7   8   9  10
11  12  13  14  15
16  17  18  19  20
21  22  23  24  25

Cells 3, 4, 5, 11, 12, 19 and 24 are traps, and cell 15 holds the gold (the exit); all of these are terminal states.
The on-policy Sarsa method updates the action-value function with the following rule:

Q(s, a) ← Q(s, a) + α [ r + γ Q(s', a') − Q(s, a) ]
Briefly: while simulating from state s, the agent executes action a and arrives at state s', then uses Q(s', a') at state s' to update Q(s, a). Because the transition is sampled rather than computed exactly, we cannot simply set Q(s, a) = r + γ Q(s', a'). Instead, r + γ Q(s', a') − Q(s, a) measures the deviation (the TD error) between the current estimate Q(s, a) and the freshly sampled target r + γ Q(s', a'), and only a fixed fraction of that deviation is added back onto Q(s, a). That fraction, α, plays the same role as a conventional learning rate.
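As a quick numeric illustration of one such update (the numbers are made up for the example, with γ and α set to the values used in the code below):

# One Sarsa update with illustrative numbers
gamma, alpha = 0.8, 0.1  # discount factor and learning rate
q_sa, q_next = 0.0, 0.5  # current estimates of Q(s, a) and Q(s', a')
r = 0.0                  # reward observed on the transition
td_error = r + gamma * q_next - q_sa  # deviation from the sampled target: 0.4
q_sa = q_sa + alpha * td_error        # move 10% of the way toward it
print(q_sa)              # 0.04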
The code (note that it targets the older gym API, which still ships the gym.envs.classic_control.rendering module and uses the four-tuple step() return):
import numpy as np
import random
from gym import spaces
import gym
from gym.envs.classic_control import rendering
# Simulated environment class
class GridWorldEnv(gym.Env):
    # Global rendering configuration
    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 2
    }
    def __init__(self):
        self.states = [i for i in range(1, 26)]  # state space: cells 1..25
        self.terminate_states = [3, 4, 5, 11, 12, 19, 24, 15]  # terminal states (traps and the gold)
        self.actions = ['up', 'down', 'left', 'right']  # action space
        self.value_of_state = dict()  # state-value table
        for state in self.states:
            self.value_of_state[state] = 0.0
        for state in self.terminate_states:  # initialize every trap's value to -1.0 first
            self.value_of_state[state] = -1.0
        self.value_of_state[15] = 1.0  # the gold's cell is initialized to 1.0
        self.initStateAction()  # initialize each state's feasible action set
        self.initStatePolicyAction()  # randomly initialize the current policy
        self.initQ_s_a()
        self.gamma = 0.8  # discount factor used in the value updates
        self.alpha = 0.1  # learning rate
        self.viewer = None  # viewer object
        self.current_state = None  # current state
        return
    def translateStateToRowCol(self, state):
        """
        Convert a state number to (row, col) coordinates.
        """
        row = (state - 1) // 5
        col = (state - 1) % 5
        return row, col

    def translateRowColToState(self, row, col):
        """
        Convert (row, col) coordinates back to a state number.
        """
        return row * 5 + col + 1
    def actionRowCol(self, row, col, action):
        """
        Apply the action to (row, col) and return the new coordinates.
        """
        if action == "up":
            row = row - 1
        if action == "down":
            row = row + 1
        if action == "left":
            col = col - 1
        if action == "right":
            col = col + 1
        return row, col
    def canUp(self, row, col):
        row = row - 1
        return 0 <= row <= 4

    def canDown(self, row, col):
        row = row + 1
        return 0 <= row <= 4

    def canLeft(self, row, col):
        col = col - 1
        return 0 <= col <= 4

    def canRight(self, row, col):
        col = col + 1
        return 0 <= col <= 4
    def initStateAction(self):
        """
        Initialize the feasible action set for each state.
        """
        self.states_actions = dict()
        for state in self.states:
            self.states_actions[state] = []
            if state in self.terminate_states:
                continue
            row, col = self.translateStateToRowCol(state)
            if self.canUp(row, col):
                self.states_actions[state].append("up")
            if self.canDown(row, col):
                self.states_actions[state].append("down")
            if self.canLeft(row, col):
                self.states_actions[state].append('left')
            if self.canRight(row, col):
                self.states_actions[state].append('right')
        return
    def initQ_s_a(self):
        """
        Initialize the action-value function Q(s, a).
        """
        self.Q_s_a = dict()
        for state in self.states:
            if state in self.terminate_states:
                continue
            for action in self.states_actions[state]:
                self.Q_s_a["%d_%s" % (state, action)] = 0.0  # all action values start at 0.0
    def epsilon_greedy(self, state, epsilon):
        """
        Sample the next action at the given state from the epsilon-greedy distribution over Q.
        """
        action_size = len(self.states_actions[state])
        max_value_action = self.states_actions[state][0]
        for action in self.states_actions[state]:
            if self.Q_s_a["%d_%s" % (state, action)] > self.Q_s_a["%d_%s" % (state, max_value_action)]:
                max_value_action = action
        prob_list = [0.0 for _ in range(0, action_size)]
        for i in range(0, action_size):
            if self.states_actions[state][i] == max_value_action:
                prob_list[i] = 1 - epsilon + epsilon / action_size
            else:
                prob_list[i] = epsilon / action_size
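        # Inverse-transform sampling: walk the cumulative sum of prob_list
        # until it passes a uniform draw r, so action i is chosen with
        # probability prob_list[i].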
        r = random.random()
        s = 0.0
        for i in range(0, action_size):
            s += prob_list[i]
            if s >= r:
                return self.states_actions[state][i]
        return self.states_actions[state][-1]
    def greedy_action(self, state):
        """
        Return the greedy (highest-Q) action at the given state.
        """
        max_value_action = self.states_actions[state][0]
        for action in self.states_actions[state]:
            if self.Q_s_a["%d_%s" % (state, action)] > self.Q_s_a["%d_%s" % (state, max_value_action)]:
                max_value_action = action
        return max_value_action
    def initStatePolicyAction(self):
        """
        Initialize the current policy with a random feasible action per state.
        """
        self.states_policy_action = dict()
        for state in self.states:
            if state in self.terminate_states:
                self.states_policy_action[state] = None
            else:
                self.states_policy_action[state] = random.sample(self.states_actions[state], 1)[0]
        return
    def seed(self, seed=None):
        random.seed(seed)
        return [seed]
    def reset(self):
        """
        Reset to a random initial state and return it.
        """
        self.current_state = random.sample(self.states, 1)[0]
        return self.current_state
    def step(self, action):
        """
        Advance the environment by one action; returns (observation, reward, done, info).
        """
        cur_state = self.current_state
        if cur_state in self.terminate_states:
            return cur_state, 0, True, {}
        row, col = self.translateStateToRowCol(cur_state)
        n_row, n_col = self.actionRowCol(row, col, action)
        next_state = self.translateRowColToState(n_row, n_col)
        self.current_state = next_state
        if next_state in self.terminate_states:
            return next_state, 0, True, {}
        else:
            return next_state, 0, False, {}
    def policy_evaluate_sarsa(self):
        """
        Sweep the whole state-action space and apply one Sarsa update to each pair.
        """
        for state in self.states:
            if state in self.terminate_states:
                continue
            for action in self.states_actions[state]:
                self.current_state = state
                next_state, reward, isTerminate, info = self.step(action)
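                # The environment's reward is always 0, so the learning signal
                # comes entirely from the fixed terminal-state values: when the
                # move ends in a trap or on the gold, bootstrap from
                # value_of_state (-1.0 or 1.0) rather than Q(s', a'), since
                # terminal states have no actions of their own.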
                if isTerminate:
                    s_a = "%d_%s" % (state, action)
                    self.Q_s_a[s_a] = self.Q_s_a[s_a] + self.alpha * (reward + self.gamma * self.value_of_state[next_state] - self.Q_s_a[s_a])
                else:
                    s_a = "%d_%s" % (state, action)
                    n_action = self.epsilon_greedy(next_state, 0.3)
                    n_s_a = "%d_%s" % (next_state, n_action)
                    self.Q_s_a[s_a] = self.Q_s_a[s_a] + self.alpha * (reward + self.gamma * self.Q_s_a[n_s_a] - self.Q_s_a[s_a])
        return
    def policy_improve_sarsa(self):
        """
        Policy improvement: make the policy greedy with respect to the current Q.
        """
        for state in self.states:
            if state in self.terminate_states:
                continue
            self.states_policy_action[state] = self.greedy_action(state)
        return
    def createGrids(self):
        """
        Draw the grid lines.
        """
        start_x = 40
        start_y = 40
        line_length = 40
        for state in self.states:
            row, col = self.translateStateToRowCol(state)
            x = start_x + col * line_length
            y = start_y + row * line_length
            line = rendering.Line((x, y), (x + line_length, y))
            line.set_color(0, 0, 0)
            self.viewer.add_onetime(line)
            line = rendering.Line((x, y), (x, y + line_length))
            line.set_color(0, 0, 0)
            self.viewer.add_onetime(line)
            line = rendering.Line((x + line_length, y), (x + line_length, y + line_length))
            line.set_color(0, 0, 0)
            self.viewer.add_onetime(line)
            line = rendering.Line((x, y + line_length), (x + line_length, y + line_length))
            line.set_color(0, 0, 0)
            self.viewer.add_onetime(line)
    def createTraps(self):
        """
        Draw the traps; the gold's cell is drawn as a trap here first and
        painted over by createGold afterwards.
        """
        start_x = 40
        start_y = 40
        line_length = 40
        for state in self.terminate_states:
            row, col = self.translateStateToRowCol(state)
            trap = rendering.make_circle(20)
            trans = rendering.Transform()
            trap.add_attr(trans)
            trap.set_color(0, 0, 0)
            trans.set_translation(start_x + line_length * col + 20, start_y + line_length * row + 20)
            self.viewer.add_onetime(trap)
    def createGold(self):
        """
        Draw the gold, which serves as the exit in this problem.
        """
        start_x = 40
        start_y = 40
        line_length = 40
        state = 15
        row, col = self.translateStateToRowCol(state)
        gold = rendering.make_circle(20)
        trans = rendering.Transform()
        gold.add_attr(trans)
        gold.set_color(1, 0.9, 0)
        trans.set_translation(start_x + line_length * col + 20, start_y + line_length * row + 20)
        self.viewer.add_onetime(gold)
    def createRobot(self):
        """
        Draw the robot at the current state.
        """
        start_x = 40
        start_y = 40
        line_length = 40
        row, col = self.translateStateToRowCol(self.current_state)
        robot = rendering.make_circle(15)
        trans = rendering.Transform()
        robot.add_attr(trans)
        robot.set_color(1, 0, 1)
        trans.set_translation(start_x + line_length * col + 20, start_y + line_length * row + 20)
        self.viewer.add_onetime(robot)
    def render(self, mode="human", close=False):
        """
        Render the whole scene.
        """
        # Close the viewer
        if close:
            if self.viewer is not None:
                self.viewer.close()
                self.viewer = None
            return  # don't fall through and recreate the viewer we just closed
        # Viewer size
        screen_width = 280
        screen_height = 280
        if self.viewer is None:
            self.viewer = rendering.Viewer(screen_width, screen_height)
        # Draw the grid
        self.createGrids()
        # Draw the traps
        self.createTraps()
        # Draw the gold
        self.createGold()
        # Draw the robot
        self.createRobot()
        return self.viewer.render(return_rgb_array=(mode == 'rgb_array'))
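Before registering the environment, it can be worth sanity-checking the epsilon-greedy sampling in isolation. A minimal sketch (the env_check name and the loop are mine; with all Q values still 0.0, the tie-break makes 'down', the first feasible action in state 1, the greedy one):

env_check = GridWorldEnv()
counts = {a: 0 for a in env_check.states_actions[1]}
for _ in range(10000):
    counts[env_check.epsilon_greedy(1, 0.3)] += 1
print(counts)  # roughly {'down': 8500, 'right': 1500} for epsilon = 0.3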
Register the class with gym
from gym.envs.registration import register
try:
    register(id="GridWorld-v5", entry_point=GridWorldEnv, max_episode_steps=200, reward_threshold=100.0)
except Exception:
    pass  # ignore the error raised if the id has already been registered
Animate the simulation
from time import sleep
env = gym.make('GridWorld-v5')
env.reset()
# Policy evaluation and policy improvement
for _ in range(100000):
    env.env.policy_evaluate_sarsa()
    env.env.policy_improve_sarsa()
# Print some information to see what exactly env is
print(isinstance(env, GridWorldEnv))
print(type(env))
print(env.__dict__)
print(isinstance(env.env, GridWorldEnv))
env.reset()
for _ in range(1000):
    env.render()
    if env.env.states_policy_action[env.env.current_state] is not None:
        observation, reward, done, info = env.step(env.env.states_policy_action[env.env.current_state])
    else:
        done = True
    print(_)
    if done:
        sleep(0.5)
        env.render()
        env.reset()
        print("reset")
    sleep(0.5)
env.close()
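After training, the learned greedy policy can also be inspected directly, without the animation. A minimal sketch (the arrows dictionary is mine, purely for readability; terminal states print as '.'):

arrows = {'up': '^', 'down': 'v', 'left': '<', 'right': '>', None: '.'}
for row in range(5):
    line = [arrows[env.env.states_policy_action[env.env.translateRowColToState(row, col)]] for col in range(5)]
    print(' '.join(line))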