強化學習有兩種常見迭代訓練算法:策略迭代算法和值迭代算法。本文中主要講述策略迭代算法。
先從一個簡單的問題開始,下圖爲一個四方格子,每個位置的狀態空間分別爲{1, 2, 3, 4}, 其中 3 的位置是個陷阱, 4的位置有個金幣。有一個機器人從狀態1的位置開始尋找金幣。落入陷阱的回報爲-1,找到金幣的回報爲1,在其他位置間移動回報爲0,可選的動作空間爲{上,下,左,右}, 通過這個簡單的問題,來學習強化學習的學習原理。
強化學習的學習過程,個人理解就是通過不斷的嘗試,去更新每個狀態的值函數(每個狀態的值代表了當前狀態的優劣,如果狀態值很大,從其他狀態選擇一個動作,轉移到該狀態便是一個正確的選擇),然後通過更新後的值函數去動態的調整策略,在調整策略後,又去更新值函數,不斷的迭代更新,最後訓練完成一個滿足要求的策略。在這個過程中,抽象出兩個主要的過程,第一個叫策略評估,第二個叫策略改善。
針對上面給出的簡單問題,先說明一些簡單的概念:
每個狀態的值函數:
代表機器人處於該狀態時的優劣值。
針對問題的當前策略:
代表機器人處於某狀態時,選擇的下一步動作。對於選擇的下一步動作,可以是確定式的,比如當機器人處於1位置的時候,確定的只選擇往右走。也可以是概率式的,可以0.5的概率選擇往右走, 0.5的概率選擇往下走。當然確定式策略選擇是概率式的策略選擇的一種特例。下文中採用確定式策略進行描述
策略評估:
策略評估就是通過某種方式,計算狀態空間中每個狀態的值函數。由於狀態空間之間存在很多轉移關係,要直接計算某個狀態的值函數,是很困難的,一般採用
迭代方法。
策略改善:
對策略的改善,即通過當前擁有的信息,對當前策略進行優化,修改當前策略。
############################## 策略評估的過程
初始化的策略和值函數。
對於這個簡單的例子,通過一步計算便得到了穩定的值函數,但是對於大多數的問題,都需要通過多步的迭代,才能得到穩定的值函數。
############################## 策略改善的過程
對於這個簡單的例子,採用貪心的方式對策略進行改善,通過上一步策略評估過程計算出的穩定的值函數,讓每個狀態在選擇下一步動作的時候,選擇使動作收益最大的動作。
總結
強化學習策略迭代算法的過程就是不斷的重複 策略評估 和 策略改善的過程,直到整個策略收斂(值函數和策略不再發生大的變化)
gym的樣例代碼展示
上述的簡單示例,簡單描述了強化學習的策略迭代算法的過程,下面將問題搞複雜一點,通過對該問題進行編程,加深對策略迭代算法的理解。新問題的空間狀態如下圖所示。
該圖的狀態空間由下至上,從左到右,分別爲1 – 36
import numpy as np
import random
from gym import spaces
import gym
from gym.envs.classic_control import rendering
#模擬環境類
class GridWorldEnv(gym.Env):
    """6x6 grid-world environment used to demonstrate policy iteration.

    States are numbered 1..36. A fixed subset of states are terminal
    (traps plus the gold state 34). `step` always returns reward 0; the
    learning signal comes from the pre-set values of the terminal states
    in `v_states` (-1.0 for traps, 1.0 for gold).
    """
    # Global rendering configuration (legacy gym metadata keys).
    metadata = {
        'render.modes':['human', 'rgb_array'],
        'video.frames_per_second': 2
    }

    def __init__(self):
        self.states = [i for i in range(1, 37)]  # state space: 1..36
        self.terminate_states = [3, 7, 11, 15, 19, 20, 23, 30, 33, 34]  # terminal states (traps + gold)
        self.actions = ['up', 'down', 'left', 'right']  # action space
        self.v_states = dict()  # value function, one entry per state
        for state in self.states:
            self.v_states[state] = 0.0
        for state in self.terminate_states:  # first set every trap AND the gold state to -1.0
            self.v_states[state] = -1.0
        self.v_states[34] = 1.0  # then overwrite the gold state's value with 1
        self.initStateAction()  # build each state's feasible action set
        self.initStatePolicyAction()  # randomly initialize the current policy
        self.gamma = 0.8  # discount factor used in value computations
        self.viewer = None  # rendering viewer object (created lazily in render)
        self.current_state = None  # current state
        return

    def translateStateToRowCol(self, state):
        """Convert a 1-based state id to (row, col) coordinates on the 6x6 grid."""
        row = (state - 1) // 6
        col = (state - 1) % 6
        return row, col

    def translateRowColToState(self, row, col):
        """Convert (row, col) coordinates back to a 1-based state id."""
        return row * 6 + col + 1

    def actionRowCol(self, row, col, action):
        """Apply `action` to (row, col) and return the resulting coordinates.

        No bounds checking is done here; callers must ensure the action is
        feasible (see canUp/canDown/canLeft/canRight).
        """
        if action == "up":
            row = row - 1
        if action == "down":
            row = row + 1
        if action == "left":
            col = col - 1
        if action == "right":
            col = col + 1
        return row, col

    def canUp(self, row, col):
        # "up" decrements row; feasible iff the new row stays in [0, 5].
        row = row - 1
        return 0 <= row <= 5

    def canDown(self, row, col):
        # "down" increments row; feasible iff the new row stays in [0, 5].
        row = row + 1
        return 0 <= row <= 5

    def canLeft(self, row, col):
        # "left" decrements col; feasible iff the new col stays in [0, 5].
        col = col - 1
        return 0 <= col <= 5

    def canRight(self, row, col):
        # "right" increments col; feasible iff the new col stays in [0, 5].
        col = col + 1
        return 0 <= col <= 5

    def initStateAction(self):
        """Initialize each state's feasible action set (empty for terminal states)."""
        self.states_actions = dict()
        for state in self.states:
            self.states_actions[state] = []
            if state in self.terminate_states:
                continue
            row, col = self.translateStateToRowCol(state)
            if self.canUp(row, col):
                self.states_actions[state].append("up")
            if self.canDown(row, col):
                self.states_actions[state].append("down")
            if self.canLeft(row, col):
                self.states_actions[state].append('left')
            if self.canRight(row, col):
                self.states_actions[state].append('right')
        return

    def initStatePolicyAction(self):
        """Initialize the current (deterministic) policy with a random feasible
        action per state; terminal states get None."""
        self.states_policy_action = dict()
        for state in self.states:
            if state in self.terminate_states:
                self.states_policy_action[state] = None
            else:
                self.states_policy_action[state] = random.sample(self.states_actions[state], 1)[0]
        return

    def seed(self, seed = None):
        """Seed the global `random` module; returns [seed] per the gym API."""
        random.seed(seed)
        return [seed]

    def reset(self):
        """Reset to a uniformly random state (terminal states included).

        NOTE(review): unlike the standard gym API, this does not return the
        initial observation.
        """
        self.current_state = random.sample(self.states, 1)[0]

    def step(self, action):
        """Advance one step from `current_state` using `action`.

        Returns (next_state, reward, done, info). Reward is always 0 here;
        the value of terminal states carries the learning signal instead.
        """
        cur_state = self.current_state
        if cur_state in self.terminate_states:
            return cur_state, 0, True, {}
        row, col = self.translateStateToRowCol(cur_state)
        n_row, n_col = self.actionRowCol(row, col, action)
        next_state = self.translateRowColToState(n_row, n_col)
        self.current_state = next_state
        if next_state in self.terminate_states:
            return next_state, 0, True, {}
        else:
            return next_state, 0, False, {}

    def policy_evaluate(self):
        """Policy evaluation: sweep all non-terminal states, updating v_states
        under the current policy, until the largest per-sweep change falls
        below a small tolerance (or 1000 sweeps elapse)."""
        error = 0.000001  # convergence tolerance
        for _ in range(1000):
            max_error = 0.0  # largest absolute value change in this sweep
            for state in self.states:
                if state in self.terminate_states:
                    continue
                action = self.states_policy_action[state]
                self.current_state = state  # step() reads current_state
                next_state, reward, isTerminate, info = self.step(action)
                old_value = self.v_states[state]
                self.v_states[state] = reward + self.gamma * self.v_states[next_state]
                abs_error = abs(self.v_states[state] - old_value)
                max_error = abs_error if abs_error > max_error else max_error  # track the max change
            if max_error < error:
                break

    def policy_improve(self):
        """Greedy policy improvement using the evaluated v_states.

        Returns True if any state's policy action changed."""
        changed = False
        for state in self.states:
            if state in self.terminate_states:
                continue
            max_value_action = self.states_actions[state][0]  # best action found so far
            max_value = -1000000000000.0  # best action-value found so far
            for action in self.states_actions[state]:
                self.current_state = state  # step() reads current_state
                next_state, reward, isTerminate, info = self.step(action)
                q_reward = reward + self.gamma * self.v_states[next_state]
                if q_reward > max_value:
                    max_value_action = action
                    max_value = q_reward
            if self.states_policy_action[state] != max_value_action:
                changed = True
            self.states_policy_action[state] = max_value_action
        return changed

    def createGrids(self):
        """Draw the grid cell borders (four line segments per cell)."""
        start_x = 40
        start_y = 40
        line_length = 40
        for state in self.states:
            row, col = self.translateStateToRowCol(state)
            x = start_x + col * line_length
            y = start_y + row * line_length
            line = rendering.Line((x, y), (x + line_length, y))
            line.set_color(0, 0, 0)
            self.viewer.add_onetime(line)
            line = rendering.Line((x, y), (x, y + line_length))
            line.set_color(0, 0, 0)
            self.viewer.add_onetime(line)
            line = rendering.Line((x + line_length, y), (x + line_length, y + line_length))
            line.set_color(0, 0, 0)
            self.viewer.add_onetime(line)
            line = rendering.Line((x, y + line_length), (x + line_length, y + line_length))
            line.set_color(0, 0, 0)
            self.viewer.add_onetime(line)

    def createTraps(self):
        """Draw the traps; the gold state is drawn as a trap here too and
        later painted over by createGold."""
        start_x = 40
        start_y = 40
        line_length = 40
        for state in self.terminate_states:
            row, col = self.translateStateToRowCol(state)
            trap = rendering.make_circle(20)
            trans = rendering.Transform()
            trap.add_attr(trans)
            trap.set_color(0, 0, 0)
            trans.set_translation(start_x + line_length * col + 20, start_y + line_length * row + 20)
            self.viewer.add_onetime(trap)

    def createGold(self):
        """Draw the gold circle at state 34."""
        start_x = 40
        start_y = 40
        line_length = 40
        state = 34
        row, col = self.translateStateToRowCol(state)
        gold = rendering.make_circle(20)
        trans = rendering.Transform()
        gold.add_attr(trans)
        gold.set_color(1, 0.9, 0)
        trans.set_translation(start_x + line_length * col + 20, start_y + line_length * row + 20)
        self.viewer.add_onetime(gold)

    def createRobot(self):
        """Draw the robot at the current state."""
        start_x = 40
        start_y = 40
        line_length = 40
        row, col = self.translateStateToRowCol(self.current_state)
        robot = rendering.make_circle(15)
        trans = rendering.Transform()
        robot.add_attr(trans)
        robot.set_color(1, 0, 1)
        trans.set_translation(start_x + line_length * col + 20, start_y + line_length * row + 20)
        self.viewer.add_onetime(robot)

    def render(self, mode="human", close=False):
        """Render the whole scene: grid, traps, gold, then the robot."""
        # close the viewer if requested
        if close:
            if self.viewer is not None:
                self.viewer.close()
                self.viewer = None
        # viewer dimensions
        screen_width = 320
        screen_height = 320
        if self.viewer is None:
            self.viewer = rendering.Viewer(screen_width, screen_height)
        # draw the grid
        self.createGrids()
        # draw the traps
        self.createTraps()
        # draw the gold
        self.createGold()
        # draw the robot
        self.createRobot()
        return self.viewer.render(return_rgb_array= mode == 'rgb_array')
註冊環境模擬類到gym
from gym.envs.registration import register
# Register the environment with gym so gym.make('GridWorld-v3') works.
# Re-running this cell/module raises because the id is already registered;
# catch Exception (not a bare except, which would also swallow
# KeyboardInterrupt/SystemExit) and ignore the duplicate registration.
try:
    register(id = "GridWorld-v3", entry_point=GridWorldEnv, max_episode_steps = 200, reward_threshold=100.0)
except Exception:
    pass
進行策略迭代算法的過程和模擬動畫的代碼
from time import sleep

env = gym.make('GridWorld-v3')
env.reset()

# Policy iteration: alternate policy evaluation and policy improvement
# (methods live on the wrapped env, hence env.env).
not_changed_count = 0
for _ in range(10000):
    env.env.policy_evaluate()
    changed = env.env.policy_improve()
    if changed:
        not_changed_count = 0
    else:
        not_changed_count += 1
    if not_changed_count == 10:  # policy unchanged for 10 rounds: it has stabilized
        break

# Debug prints to inspect what gym.make actually returned
# (a wrapper whose .env attribute is the GridWorldEnv instance).
print(isinstance(env, GridWorldEnv))
print(type(env))
print(env.__dict__)
print(isinstance(env.env, GridWorldEnv))

# Animate the learned policy: follow the greedy action from each state,
# resetting whenever a terminal state is reached.
env.reset()
for _ in range(1000):
    env.render()
    if env.env.states_policy_action[env.env.current_state] is not None:
        observation,reward,done,info = env.step(env.env.states_policy_action[env.env.current_state])
    else:
        done = True  # reset() may land on a terminal state, which has no policy action
    print(_)
    if done:
        sleep(0.5)
        env.render()
        env.reset()
        print("reset")
    sleep(0.5)
env.close()