蒙特卡洛樹搜索及實現三子棋遊戲
預備知識
雙人有限零和順序遊戲
MCTS運行所在的框架/環境是一個遊戲,它本身是一個非常抽象和寬泛的概念,因此這裏我們只關注一種遊戲類型:雙人有限零和順序遊戲。這個名詞一開始聽起來會有些複雜,但是實際上非常簡單,現在來讓我們將它分解一下:
遊戲:意味着我們在一種需要交互的情境中,交互通常會涉及一個或多個角色
有限:表明在任意時間點,角色之間存在的交互方式都是有限的
雙人:遊戲中只有兩個角色
順序:玩家依次交替進行他們的動作
零和:參與遊戲的兩方有完全相反的目標,換句話說就是,遊戲的任意結束狀態雙方的收益之和等於零
我們可以很輕鬆的驗證,圍棋、國際象棋和井字棋都是雙人有限零和順序遊戲:有兩位玩家參與,玩家能進行的動作總是有限的,雙方的遊戲目標是完全相反的(所有遊戲的結果之和等於0)原文鏈接:https://blog.csdn.net/qq_16137569/article/details/83543641
遊戲樹
遊戲樹是一種常見的數據結構,其中每一個節點代表遊戲的一個確定狀態,從一個節點到該節點的一個子節點(如果存在)是一個移動。節點的子節點數目稱爲分支因子。遊戲樹的根節點代表遊戲的初始狀態。遊戲樹的終端節點是沒有子節點的節點,至此遊戲結束,無法再進行移動。終端節點的狀態也就是遊戲的結果(輸/贏/平局)。
下面以井字棋遊戲爲例,形象地來看下什麼是遊戲樹。
每個父節點的子節點數量對應着本次可以執行的Action的數量
蒙特卡洛樹搜索
搜索流程圖
搜索步驟
-
選擇
從根節點開始,我們選擇採用UCB計算得到的最大的值的孩子節點,如此向下搜索,直到我們來到樹的底部的葉子節點(沒有孩子節點的節點),若果該節點沒有子節點,就會去執行擴展 -
擴展
到達葉子節點後,如果還沒有到達終止狀態,那麼我們就要對這個節點進行擴展(這裏是一個迭代過程),擴展出一個或多個節點。可以擴展一個節點也可以擴展多個節點. -
模擬
我們基於目前的這個狀態,根據某一種策略(例如random policy)進行模擬,直到遊戲結束爲止,產生結果,比如勝利或者失敗。此處的模擬可以指定模擬多少輪也可以指定模擬多少時間.所以模擬的本質還是用頻率去逼近概率 -
反向傳播
根據模擬的結果,我們要自底向上,反向更新所有節點的信息.一般需要更新的值有該節點被訪問的次數和該節點的獎勵值.若模擬結果爲勝利,則獎勵爲正,模擬結果爲失敗,則獎勵爲負.獎勵函數也可以設計的很複雜
每次搜索步驟需要N次的模擬,但只對應了一次下棋,每次下棋後都會更新狀態,並從新狀態開始(人也下完了棋),進行下一次的搜索.(下一步棋)
具體案例可以看博客
節點狀態
某個節點的所有子節點全都被訪問過,則該節點稱作完全擴展,否則就是未完全擴展.
圖中灰色的節點表示被擴展出來但是還沒有被訪問過
UCT計算
( 是節點被訪問的次數,而 則是其父節點已經被訪問的總次數)
UCT的第一部分是(總收益/總次數=平均每次的收益),即優先選擇收益大的.但只有這一項是不夠的,那些未被選中的節點之後就再也無法選到了,
UCT的第二部分是傾向於那些未被探索的節點,(子節點被探索的越少則分母越小,)
c是一個常數,用於平衡兩部分的值
何時停止
原則上,模擬的次數越多則結果越好,但在實際中往往會指定一個時間限制或是模擬次數限制,防止運行時間過長(比如跟你對戰的ai遲遲不下棋).在模擬結束後,最佳的移動通常是訪問次數最多的那個節點.
代碼實現
實現一個三子棋程序
其中蒙特卡洛樹代碼來自git
蒙特卡洛核心類
-
mcts類:
search方法對應模擬方法
executeRound方法定義了一次模擬流程
selectNode對應節點選擇,
- 該節點若有子節點,則使用getBestChild方法獲得UCT值最大的節點
- 若無子節點,則使用expand方法擴展子節點
rollout方法將在選擇的節點上隨機執行一種Action
backpropogate方法對應反向傳播
getBestChild,在n次executeRound執行完後,選擇子節點中最優的
getAction方法是從子節點中獲取其動作(下到哪裏)
-
treeNode類,用於構建樹形結構,存儲當前節點的狀態
-
randomPolicy方法:規定了rollout時使用哪種方式,一般使用隨機選擇的方式
from __future__ import division
import time
import math
import random
def randomPolicy(state):
while not state.isTerminal():
try:
action = random.choice(state.getPossibleActions())
except IndexError:
raise Exception("Non-terminal state has no possible actions: " + str(state))
state = state.takeAction(action)
return state.getReward()
class treeNode():
def __init__(self, state, parent):
self.state = state
self.isTerminal = state.isTerminal()
self.isFullyExpanded = self.isTerminal
self.parent = parent
self.numVisits = 0
self.totalReward = 0
self.children = {}
class mcts():
def __init__(self, timeLimit=None, iterationLimit=None, explorationConstant=1 / math.sqrt(2),
rolloutPolicy=randomPolicy):
if timeLimit != None:
if iterationLimit != None:
raise ValueError("Cannot have both a time limit and an iteration limit")
# time taken for each MCTS search in milliseconds
self.timeLimit = timeLimit
self.limitType = 'time'
else:
if iterationLimit == None:
raise ValueError("Must have either a time limit or an iteration limit")
# number of iterations of the search
if iterationLimit < 1:
raise ValueError("Iteration limit must be greater than one")
self.searchLimit = iterationLimit
self.limitType = 'iterations'
self.explorationConstant = explorationConstant
self.rollout = rolloutPolicy
def search(self, initialState):
self.root = treeNode(initialState, None)
if self.limitType == 'time': # 時間限制
timeLimit = time.time() + self.timeLimit / 1000
while time.time() < timeLimit:
self.executeRound()
else: # 次數限制
for i in range(self.searchLimit):
self.executeRound()
# executeRound執行完後,其葉子節點就存放了他們的信息
bestChild = self.getBestChild(self.root, 0)
return self.getAction(self.root, bestChild)
def executeRound(self):
node = self.selectNode(self.root)
reward = self.rollout(node.state)
self.backpropogate(node, reward)
def selectNode(self, node):
while not node.isTerminal: # 這裏會一直找到遊戲結束,即最後一個節點
if node.isFullyExpanded:
node = self.getBestChild(node, self.explorationConstant)
else:
return self.expand(node) # 每次把所有的孩子都擴展出來
return node
def expand(self, node):
actions = node.state.getPossibleActions()
for action in actions:
if action not in node.children:
newNode = treeNode(node.state.takeAction(action), node)
node.children[action] = newNode
if len(actions) == len(node.children):
node.isFullyExpanded = True
return newNode
raise Exception("Should never reach here")
def backpropogate(self, node, reward):
while node is not None:
node.numVisits += 1
node.totalReward += reward
node = node.parent
def getBestChild(self, node, explorationValue):
bestValue = float("-inf")
bestNodes = []
for child in node.children.values():
nodeValue = child.totalReward / child.numVisits + explorationValue * math.sqrt(
2 * math.log(node.numVisits) / child.numVisits)
if nodeValue > bestValue:
bestValue = nodeValue
bestNodes = [child]
elif nodeValue == bestValue:
bestNodes.append(child)
return random.choice(bestNodes)
def getAction(self, root, bestChild):
for action, node in root.children.items():
if node is bestChild:
return action
狀態類
- Action類是動作類,封裝了執行的動作,比如下棋到哪個位置
- NaughtsAndCrossesState類是狀態類,要提供以下方法
- 維護玩家狀態: currentPlayer
- 維護棋盤狀態: board
- 提供一個獲得所有可行狀態的方法getPossibleActions
- 提供一個執行Action的方法takeAction,並且要更新自己的狀態
- 提供一個isTerminal函數,用於判斷遊戲是否結束
- 提供一個getReward方法,用於計算獎勵
from __future__ import division
from copy import deepcopy
from mcts import mcts
from functools import reduce
import operator
class NaughtsAndCrossesState(object):
def __init__(self):
self.target_num = 3 # 最終目標
self.board_width = 3
self.board = [[0] * self.board_width for _ in range(self.board_width)]
self.currentPlayer = 1
def getPossibleActions(self):
possibleActions = []
for i in range(len(self.board)):
for j in range(len(self.board[i])):
if self.board[i][j] == 0:
possibleActions.append(Action(player=self.currentPlayer, x=i, y=j))
return possibleActions
def takeAction(self, action):
newState = deepcopy(self)
newState.board[action.x][action.y] = action.player
newState.currentPlayer = self.currentPlayer * -1
return newState
def isTerminal(self):
for row in self.board:
if abs(sum(row)) == self.target_num:
return True
for column in list(map(list, zip(*self.board))):
if abs(sum(column)) == self.target_num:
return True
for diagonal in [[self.board[i][i] for i in range(len(self.board))],
[self.board[i][len(self.board) - i - 1] for i in range(len(self.board))]]:
if abs(sum(diagonal)) == self.target_num:
return True
return reduce(operator.mul, sum(self.board, []), 1)
def getReward(self):
for row in self.board:
if abs(sum(row)) == self.target_num:
return sum(row) / self.target_num
for column in list(map(list, zip(*self.board))):
if abs(sum(column)) == self.target_num:
return sum(column) / self.target_num
for diagonal in [[self.board[i][i] for i in range(len(self.board))],
[self.board[i][len(self.board) - i - 1] for i in range(len(self.board))]]:
if abs(sum(diagonal)) == self.target_num:
return sum(diagonal) / self.target_num
return False
class Action():
def __init__(self, player, x, y):
self.player = player
self.x = x
self.y = y
def __str__(self):
return str((self.x, self.y))
def __repr__(self):
return str(self)
def __eq__(self, other):
return self.__class__ == other.__class__ and self.x == other.x and self.y == other.y and self.player == other.player
def __hash__(self):
return hash((self.x, self.y, self.player))
if __name__ == '__main__':
import numpy as np
s = NaughtsAndCrossesState()
tree = mcts(timeLimit=1000)
while True:
# 機器下棋
action = tree.search(initialState=s)
s = s.takeAction(action)
print("ai:", action)
print(np.array(s.board))
if s.isTerminal():
print("ai win")
break
# 人下棋
x, y = list(map(int, input().split()))
action = Action(-1, x, y)
s = s.takeAction(action)
print("人:", action)
print(np.array(s.board))
print(s.isTerminal())
if s.isTerminal():
print("human win")
break
我們在實際使用中,只需定義一個合適的State和Action類並實現其方法,就可以應用到mcts中.不要將State和Action耦合在MCTS中,這樣就沒有擴展性了