In a previous post we implemented maze solving with Q-Learning; this post implements it with SARSA.
Q-Learning is an off-policy algorithm: the action executed in the environment comes from the behavior policy (e.g. ε-greedy), but the Q-table update backs up the maximum Q-value over the next state's actions, regardless of which action is actually taken next.

Update formula: Q(s, a) ← Q(s, a) + α [r + γ · max_a' Q(s', a') − Q(s, a)]
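As a minimal sketch of this update rule (the names `q_table`, `alpha`, and `gamma` are illustrative, matching `q_table`, `rate`, and `factor` in the code below):

```python
import numpy as np

def q_learning_update(q_table, s, a, r, s_next, alpha=0.5, gamma=0.9):
    # Off-policy: the TD target uses the max over next-state actions,
    # regardless of which action will actually be taken next.
    td_target = r + gamma * np.max(q_table[s_next])
    q_table[s][a] += alpha * (td_target - q_table[s][a])

q = np.zeros((2, 2))   # 2 states x 2 actions
q[1] = [1.0, 3.0]      # next state already has some values
q_learning_update(q, 0, 0, -1, 1)
print(q[0, 0])         # 0.5 * (-1 + 0.9 * 3.0 - 0) = 0.85
```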
SARSA is an on-policy algorithm: the action chosen by the current policy both acts on the environment to produce the new state and is the action used to update the Q-table.

Update formula: Q(s, a) ← Q(s, a) + α [r + γ · Q(s', a') − Q(s, a)]

where s is the current state, a the current action, s' the next state, and a' the next action; α is the learning rate and γ the discount factor.
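The SARSA update differs from Q-Learning in only one term of the target: it uses the Q-value of the next action actually chosen, not the maximum. A minimal sketch (same illustrative names as above):

```python
import numpy as np

def sarsa_update(q_table, s, a, r, s_next, a_next, alpha=0.5, gamma=0.9):
    # On-policy: the TD target uses a', the action actually selected
    # for the next step, not the greedy maximum.
    td_target = r + gamma * q_table[s_next][a_next]
    q_table[s][a] += alpha * (td_target - q_table[s][a])

q = np.zeros((2, 2))   # 2 states x 2 actions
q[1] = [1.0, 3.0]
sarsa_update(q, 0, 0, -1, 1, 0)   # next action a'=0 was an exploratory pick
print(q[0, 0])                    # 0.5 * (-1 + 0.9 * 1.0 - 0) = -0.05
```

Had the greedy action a'=1 been chosen, this update would coincide with Q-Learning's.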
The code is as follows:
```python
import numpy as np
import random
import matplotlib.pyplot as plt
from PIL import Image
import imageio
import io

H = 30
W = 40
start = (0, random.randint(0, H - 1))
goal = (W - 1, random.randint(0, H - 1))
img = Image.new('RGB', (W, H), (255, 255, 255))
pixels = img.load()

# Randomly place obstacles in about 10% of the cells
maze = np.zeros((W, H))
for h in range(H):
    for w in range(W):
        if random.random() < 0.1:
            maze[w, h] = -1

actions_num = 4
actions = [0, 1, 2, 3]
q_table = np.zeros((W, H, actions_num))
rate = 0.5     # learning rate (alpha)
factor = 0.9   # discount factor (gamma)
images = []

for i in range(2000):
    state = start
    path = [start]
    action = np.random.choice(actions)
    while True:
        # Execute the action (stay in place when blocked by the wall)
        if action == 0 and state[0] > 0:
            next_state = (state[0] - 1, state[1])
        elif action == 1 and state[0] < W - 1:
            next_state = (state[0] + 1, state[1])
        elif action == 2 and state[1] > 0:
            next_state = (state[0], state[1] - 1)
        elif action == 3 and state[1] < H - 1:
            next_state = (state[0], state[1] + 1)
        else:
            next_state = state

        if next_state == goal:
            reward = 100    # large positive reward for reaching the goal
        elif maze[next_state] == -1:
            reward = -100   # large negative reward for hitting an obstacle
        else:
            reward = -1     # small negative reward per step, so shorter paths cost less

        done = (next_state == goal)

        # epsilon-greedy with decaying epsilon: random action, or the
        # max-Q action for the next state
        if np.random.rand() < 1.0 / (i + 1):
            next_action = np.random.choice(actions)
        else:
            next_action = np.argmax(q_table[next_state])

        # Update the Q-table with the SARSA rule
        current_q = q_table[state][action]
        q_table[state][action] += rate * (
            reward + factor * q_table[next_state][next_action] - current_q)

        state = next_state
        action = next_action
        path.append(state)
        if done:
            break

    if i % 10 == 0:  # visualize every 10 episodes
        for h in range(H):
            for w in range(W):
                if maze[w, h] == -1:
                    pixels[w, h] = (0, 0, 0)
                else:
                    pixels[w, h] = (255, 255, 255)
        for x, y in path:
            pixels[x, y] = (0, 0, 255)
        pixels[start] = (255, 0, 0)
        pixels[goal] = (0, 255, 0)
        plt.clf()                           # clear the current figure
        plt.imshow(img)
        plt.pause(0.1)                      # pause 0.1 s for an animated effect
        buf = io.BytesIO()
        plt.savefig(buf, format='png')      # save the frame to an in-memory buffer
        buf.seek(0)                         # rewind to the start of the buffer
        images.append(imageio.imread(buf))  # read the frame back and collect it

plt.show()
imageio.mimsave('result.gif', images, fps=3)  # save as a GIF at 3 fps
```
The result seems worse than Q-Learning's. This is plausible: SARSA is on-policy, so its updates incorporate the exploratory (sometimes bad) actions actually taken, which tends to make learning more conservative and noisier than Q-Learning's greedy backup.
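The difference can be made concrete with a toy example (all values hypothetical): when ε-greedy exploration can still pick a disastrous next action, SARSA's expected TD target is dragged down by it, while Q-Learning's target ignores exploration entirely.

```python
import numpy as np

# Hypothetical Q-values of the next state: action 0 leads into an
# obstacle, action 1 is safe.
q_next = np.array([-100.0, 10.0])
r, gamma = -1.0, 0.9

# Q-Learning target always backs up the best next action.
q_learning_target = r + gamma * np.max(q_next)  # -1 + 0.9 * 10 = 8.0

# With epsilon-greedy (here epsilon = 0.5: uniform random with prob 0.5,
# greedy otherwise), SARSA's target averaged over exploration is:
epsilon = 0.5
expected_sarsa_target = r + gamma * (
    epsilon * np.mean(q_next) + (1 - epsilon) * np.max(q_next))

print(q_learning_target)      # 8.0
print(expected_sarsa_target)  # -16.75
```

This is why SARSA tends to learn paths that steer clear of obstacles while exploration is still high, at the cost of slower convergence to the shortest path.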