蒙特卡洛樹搜索及實現三子棋遊戲

蒙特卡洛樹搜索及實現三子棋遊戲

預備知識

雙人有限零和順序遊戲

MCTS運行所在的框架/環境是一個遊戲,它本身是一個非常抽象和寬泛的概念,因此這裏我們只關注一種遊戲類型:雙人有限零和順序遊戲。這個名詞一開始聽起來會有些複雜,但是實際上非常簡單,現在來讓我們將它分解一下:

遊戲:意味着我們在一種需要交互的情境中,交互通常會涉及一個或多個角色
有限:表明在任意時間點,角色之間存在的交互方式都是有限的
雙人:遊戲中只有兩個角色
順序:玩家依次交替進行他們的動作
零和:參與遊戲的兩方有完全相反的目標,換句話說就是,遊戲的任意結束狀態雙方的收益之和等於零
  我們可以很輕鬆的驗證,圍棋、國際象棋和井字棋都是雙人有限零和順序遊戲:有兩位玩家參與,玩家能進行的動作總是有限的,雙方的遊戲目標是完全相反的(所有遊戲的結果之和等於0)

原文鏈接:https://blog.csdn.net/qq_16137569/article/details/83543641

遊戲樹

遊戲樹是一種常見的數據結構,其中每一個節點代表遊戲的一個確定狀態,從一個節點到該節點的一個子節點(如果存在)是一個移動。節點的子節點數目稱爲分支因子。遊戲樹的根節點代表遊戲的初始狀態。遊戲樹的終端節點是沒有子節點的節點,至此遊戲結束,無法再進行移動。終端節點的狀態也就是遊戲的結果(輸/贏/平局)。

下面以井字棋遊戲爲例,形象地來看下什麼是遊戲樹。

每個父節點的子節點數量對應着本次可以執行的Action的數量

蒙特卡洛樹搜索

搜索流程圖

搜索步驟
  1. 選擇
    從根節點開始,我們選擇採用UCB計算得到的最大的值的孩子節點,如此向下搜索,直到我們來到樹的底部的葉子節點(沒有孩子節點的節點),若果該節點沒有子節點,就會去執行擴展

  2. 擴展
    到達葉子節點後,如果還沒有到達終止狀態,那麼我們就要對這個節點進行擴展(這裏是一個迭代過程),擴展出一個或多個節點。可以擴展一個節點也可以擴展多個節點.

  3. 模擬
    我們基於目前的這個狀態,根據某一種策略(例如random policy)進行模擬,直到遊戲結束爲止,產生結果,比如勝利或者失敗。此處的模擬可以指定模擬多少輪也可以指定模擬多少時間.所以模擬的本質還是用頻率去逼近概率

  4. 反向傳播

    根據模擬的結果,我們要自底向上,反向更新所有節點的信息.一般需要更新的值有該節點被訪問的次數和該節點的獎勵值.若模擬結果爲勝利,則獎勵爲正,模擬結果爲失敗,則獎勵爲負.獎勵函數也可以設計的很複雜

每次搜索步驟需要N次的模擬,但只對應了一次下棋,每次下棋後都會更新狀態,並從新狀態開始(人也下完了棋),進行下一次的搜索.(下一步棋)

具體案例可以看博客

節點狀態

某個節點的所有子節點全都被訪問過,則該節點稱作完全擴展,否則就是未完全擴展.

圖中灰色的節點表示被擴展出來但是還沒有被訪問過

UCT計算

UCT(vi,v)=Q(vi)N(vi)+clog(N(v))N(vi) \mathbb{U C} \mathbb{T}\left(v_{i}, v\right)=\frac{Q\left(v_{i}\right)}{N\left(v_{i}\right)}+c \sqrt{\frac{\log (N(v))}{N\left(v_{i}\right)}}

(N(vi)N{(vi)} 是節點被訪問的次數,而 N(v)N( v) 則是其父節點已經被訪問的總次數)

UCT的第一部分是(總收益/總次數=平均每次的收益),即優先選擇收益大的.但只有這一項是不夠的,那些未被選中的節點之後就再也無法選到了,

UCT的第二部分是傾向於那些未被探索的節點,(子節點被探索的越少則分母越小,)

c是一個常數,用於平衡兩部分的值

何時停止

原則上,模擬的次數越多則結果越好,但在實際中往往會指定一個時間限制或是模擬次數限制,防止運行時間過長(比如跟你對戰的ai遲遲不下棋).在模擬結束後,最佳的移動通常是訪問次數最多的那個節點.

代碼實現

實現一個三子棋程序

其中蒙特卡洛樹代碼來自git

蒙特卡洛核心類
  1. mcts類:

    search方法對應模擬方法

    executeRound方法定義了一次模擬流程

    selectNode對應節點選擇,

    • 該節點若有子節點,則使用getBestChild方法獲得UCT值最大的節點
    • 若無子節點,則使用expand方法擴展子節點

    rollout方法將在選擇的節點上隨機執行一種Action

    backpropogate方法對應反向傳播

    getBestChild,在n次executeRound執行完後,選擇子節點中最優的

    getAction方法是從子節點中獲取其動作(下到哪裏)

  2. treeNode類,用於構建樹形結構,存儲當前節點的狀態

  3. randomPolicy方法:規定了rollout時使用哪種方式,一般使用隨機選擇的方式

from __future__ import division

import time
import math
import random


def randomPolicy(state):
    while not state.isTerminal():
        try:
            action = random.choice(state.getPossibleActions())
        except IndexError:
            raise Exception("Non-terminal state has no possible actions: " + str(state))
        state = state.takeAction(action)
    return state.getReward()


class treeNode():
    def __init__(self, state, parent):
        self.state = state
        self.isTerminal = state.isTerminal()
        self.isFullyExpanded = self.isTerminal
        self.parent = parent
        self.numVisits = 0
        self.totalReward = 0
        self.children = {}


class mcts():
    def __init__(self, timeLimit=None, iterationLimit=None, explorationConstant=1 / math.sqrt(2),
                 rolloutPolicy=randomPolicy):
        if timeLimit != None:
            if iterationLimit != None:
                raise ValueError("Cannot have both a time limit and an iteration limit")
            # time taken for each MCTS search in milliseconds
            self.timeLimit = timeLimit
            self.limitType = 'time'
        else:
            if iterationLimit == None:
                raise ValueError("Must have either a time limit or an iteration limit")
            # number of iterations of the search
            if iterationLimit < 1:
                raise ValueError("Iteration limit must be greater than one")
            self.searchLimit = iterationLimit
            self.limitType = 'iterations'
        self.explorationConstant = explorationConstant
        self.rollout = rolloutPolicy

    def search(self, initialState):
        self.root = treeNode(initialState, None)

        if self.limitType == 'time':  # 時間限制
            timeLimit = time.time() + self.timeLimit / 1000
            while time.time() < timeLimit:
                self.executeRound()
        else:  # 次數限制
            for i in range(self.searchLimit):
                self.executeRound()
		# executeRound執行完後,其葉子節點就存放了他們的信息
        bestChild = self.getBestChild(self.root, 0)
        return self.getAction(self.root, bestChild)

    def executeRound(self):
        node = self.selectNode(self.root) 
        reward = self.rollout(node.state)
        self.backpropogate(node, reward)

    def selectNode(self, node):
        while not node.isTerminal: # 這裏會一直找到遊戲結束,即最後一個節點
            if node.isFullyExpanded:
                node = self.getBestChild(node, self.explorationConstant)
            else:
                return self.expand(node)  # 每次把所有的孩子都擴展出來
        return node

    def expand(self, node):
        actions = node.state.getPossibleActions()
        for action in actions:
            if action not in node.children:
                newNode = treeNode(node.state.takeAction(action), node) 
                node.children[action] = newNode
                if len(actions) == len(node.children):
                    node.isFullyExpanded = True
                return newNode

        raise Exception("Should never reach here")

    def backpropogate(self, node, reward):
        while node is not None:
            node.numVisits += 1
            node.totalReward += reward
            node = node.parent

    def getBestChild(self, node, explorationValue):
        bestValue = float("-inf")
        bestNodes = []
        for child in node.children.values():
            nodeValue = child.totalReward / child.numVisits + explorationValue * math.sqrt(
                2 * math.log(node.numVisits) / child.numVisits)
            if nodeValue > bestValue:
                bestValue = nodeValue
                bestNodes = [child]
            elif nodeValue == bestValue:
                bestNodes.append(child)
        return random.choice(bestNodes)

    def getAction(self, root, bestChild):
        for action, node in root.children.items():
            if node is bestChild:
                return action

狀態類
  1. Action類是動作類,封裝了執行的動作,比如下棋到哪個位置
  2. NaughtsAndCrossesState類是狀態類,要提供以下方法
    1. 維護玩家狀態: currentPlayer
    2. 維護棋盤狀態: board
    3. 提供一個獲得所有可行狀態的方法getPossibleActions
    4. 提供一個執行Action的方法takeAction,並且要更新自己的狀態
    5. 提供一個isTerminal函數,用於判斷遊戲是否結束
    6. 提供一個getReward方法,用於計算獎勵
from __future__ import division

from copy import deepcopy
from mcts import mcts
from functools import reduce
import operator


class NaughtsAndCrossesState(object):
    def __init__(self):
        self.target_num = 3  # 最終目標
        self.board_width = 3
        self.board = [[0] * self.board_width for _ in range(self.board_width)]
        self.currentPlayer = 1

    def getPossibleActions(self):
        possibleActions = []
        for i in range(len(self.board)):
            for j in range(len(self.board[i])):
                if self.board[i][j] == 0:
                    possibleActions.append(Action(player=self.currentPlayer, x=i, y=j))
        return possibleActions

    def takeAction(self, action):
        newState = deepcopy(self)
        newState.board[action.x][action.y] = action.player
        newState.currentPlayer = self.currentPlayer * -1
        return newState

    def isTerminal(self):
        for row in self.board:
            if abs(sum(row)) == self.target_num:
                return True
        for column in list(map(list, zip(*self.board))):
            if abs(sum(column)) == self.target_num:
                return True
        for diagonal in [[self.board[i][i] for i in range(len(self.board))],
                         [self.board[i][len(self.board) - i - 1] for i in range(len(self.board))]]:
            if abs(sum(diagonal)) == self.target_num:
                return True
        return reduce(operator.mul, sum(self.board, []), 1)

    def getReward(self):
        for row in self.board:
            if abs(sum(row)) == self.target_num:
                return sum(row) / self.target_num
        for column in list(map(list, zip(*self.board))):
            if abs(sum(column)) == self.target_num:
                return sum(column) / self.target_num
        for diagonal in [[self.board[i][i] for i in range(len(self.board))],
                         [self.board[i][len(self.board) - i - 1] for i in range(len(self.board))]]:
            if abs(sum(diagonal)) == self.target_num:
                return sum(diagonal) / self.target_num
        return False


class Action():
    def __init__(self, player, x, y):
        self.player = player
        self.x = x
        self.y = y

    def __str__(self):
        return str((self.x, self.y))

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        return self.__class__ == other.__class__ and self.x == other.x and self.y == other.y and self.player == other.player

    def __hash__(self):
        return hash((self.x, self.y, self.player))


if __name__ == '__main__':
    import numpy as np

    s = NaughtsAndCrossesState()
    tree = mcts(timeLimit=1000)
    while True:
        # 機器下棋
        action = tree.search(initialState=s)
        s = s.takeAction(action)
        print("ai:", action)
        print(np.array(s.board))
        if s.isTerminal():
            print("ai win")
            break
        # 人下棋
        x, y = list(map(int, input().split()))
        action = Action(-1, x, y)
        s = s.takeAction(action)
        print("人:", action)
        print(np.array(s.board))
        print(s.isTerminal())
        if s.isTerminal():
            print("human win")
            break

我們在實際使用中,只需定義一個合適的State和Action類並實現其方法,就可以應用到mcts中.不要將State和Action耦合在MCTS中,這樣就沒有擴展性了

參考博客

面向初學者的蒙特卡洛樹搜索MCTS詳解及其實現

MCTS蒙特卡洛搜索樹實現井字棋遊戲

蒙特卡洛樹搜索(新手教程)

博客

視頻地址

git代碼地址

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章