2048遊戲DQN實驗

背景

我已經做過一些強化學習相關項目,本科的時候也用min-max搜索做過2048,一直覺得2048應該是適合被強化學習解決的,但是查詢發現並沒有比較合適靠譜的實現代碼,於是完成並開源了我的一部分實現工作,供RL learner 參考,github鏈接 https://github.com/YangRui2015/2048_env
在這裏插入圖片描述

工作

  1. 修改了https://github.com/rgal/gym-2048的gym封裝的2048環境,增加最大步數和最大非法步數限制,能降低訓練難度,增加info輸出;
  2. DQN算法實現,訓練和測試,模型保存和加載(使用pytorch);
  3. logger日誌代碼實現,包括控制檯、txt文件、tensorboard等數據格式的日誌;

DQN實現功能或trick有:

  • CNN input or flattened input
  • randomly fill buffer first
  • soft target replacing
  • linear epsilon decay
  • clip gradient norm
  • Double DQN
  • priority experience replay

分析問題

一個典型的深度強化學習問題,主要有以下幾個基本點:

  1. 狀態、動作、獎勵的設計與表徵;
  2. 強化學習算法的實現和參數選擇;
  3. 神經網絡的設計和調參;

在以上幾點完成的基礎上,算法的提升主要有三個方面:

  1. 提高訓練的穩定性;
  2. 提高訓練的速度;
  3. 提高算法的performance;

狀態表徵

環境的狀態輸出是4*4的矩陣,針對這種狀態我們通常選擇flatten成一維向量做全連接輸入或者使用CNN輸入。CNN雖然更有利於提取一些空間的特徵,但是flatten後只要網絡擬合能力足夠也是能夠學習到這些非線性特徵的。實驗表明兩者效果接近,但是flatten最開始階段學習會比CNN快一些,符合我們的預期。

此外,由於狀態矩陣中的值以及獎勵值2~1024甚至更大,直接輸入網絡很容易爆炸,需要對輸入狀態和獎勵值做預處理,這裏簡單的使用了log(x+1)/16實現歸一化。

強化學習算法

使用DQN算法,我也嘗試過A2C、PPO,但是訓練效果都不好,猜想隨機策略在這個問題上表現不如確定性策略。
在實現基本DQN算法上,我還實現了DDQN、Priority DQN(參考莫凡強化學習代碼),以及加上了一些提高訓練穩定性的方法——soft replacing、epsilon decay、clip gradient norm。

參數設計

DQN參數如下:

batch_size = 128
lr = 1e-4
epsilon = 0.15
memory_capacity = int(1e4)
gamma = 0.99
q_network_iteration = 200
soft_update_theta = 0.1
clip_norm_max = 1
train_interval = 5
conv_size = (32, 64) # num filters
fc_size = (512, 128)

代碼實現

DQN需要的網絡模塊如下,主要區別是CNN輸入還是flatten輸入(具體見NN_module.py)。

# CNN網絡
class CNN_Net(nn.Module):
    def __init__(self, input_len, output_num, conv_size=(32, 64), fc_size=(1024, 128), out_softmax=False):
        super(CNN_Net, self).__init__()
        self.input_len = input_len
        self.output_num = output_num
        self.out_softmax = out_softmax 

        self.conv1 = nn.Sequential(
            nn.Conv2d(1, conv_size[0], kernel_size=3, stride=1, padding=1),
            # nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(conv_size[0], conv_size[1], kernel_size=3, stride=1, padding=1),
            # nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        self.fc1 = nn.Linear(conv_size[1] * self.input_len * self.input_len, fc_size[0])
        self.fc2 = nn.Linear(fc_size[0], fc_size[1])
        self.head = nn.Linear(fc_size[1], self.output_num)

    def forward(self, x):
        x = x.reshape(-1,1,self.input_len, self.input_len)
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        output = self.head(x)
        if self.out_softmax:
            output = F.softmax(output, dim=1)   #值函數估計不應該有softmax
        return output


# 全連接網絡
class FC_Net(nn.Module):
    def __init__(self, input_num, output_num, fc_size=(1024, 128), out_softmax=False):
        super(FC_Net, self).__init__()
        self.input_num = input_num
        self.output_num = output_num
        self.out_softmax = out_softmax 

        self.fc1 = nn.Linear(self.input_num, fc_size[0])
        self.fc2 = nn.Linear(fc_size[0], fc_size[1])
        self.head = nn.Linear(fc_size[1], self.output_num)

    def forward(self, x):
        x = x.reshape(-1, self.input_num)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        output = self.head(x)
        if self.out_softmax:
            output = F.softmax(output, dim=1)   #值函數估計不應該有softmax
        return output

DQN代碼實現如下(具體見DQN_agent.py):

class DQN():
    batch_size = 128
    lr = 1e-4
    epsilon = 0.15   
    memory_capacity =  int(1e4)
    gamma = 0.99
    q_network_iteration = 200
    save_path = "./save/"
    soft_update_theta = 0.1
    clip_norm_max = 1
    train_interval = 5
    conv_size = (32, 64)   # num filters
    fc_size = (512, 128)

    def __init__(self, num_state, num_action, enable_double=False, enable_priority=True):
        super(DQN, self).__init__()
        self.num_state = num_state
        self.num_action = num_action
        self.state_len = int(np.sqrt(self.num_state))
        self.enable_double = enable_double
        self.enable_priority = enable_priority

        self.eval_net, self.target_net = CNN_Net(self.state_len, num_action,self.conv_size, self.fc_size), CNN_Net(self.state_len, num_action, self.conv_size, self.fc_size)
        # self.eval_net, self.target_net = FC_Net(self.num_state, self.num_action), FC_Net(self.num_state, self.num_action)

        self.learn_step_counter = 0
        self.buffer = Buffer(self.num_state, 'priority', self.memory_capacity)
        # self.memory = np.zeros((self.memory_capacity, num_state * 2 + 2))     
        self.initial_epsilon = self.epsilon
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=self.lr)


    def select_action(self, state, random=False, deterministic=False):
        state = torch.unsqueeze(torch.FloatTensor(state), 0) 
        if not random and np.random.random() > self.epsilon or deterministic:  # greedy policy
            action_value = self.eval_net.forward(state)
            action = torch.max(action_value.reshape(-1,4), 1)[1].data.numpy()
        else: # random policy
            action = np.random.randint(0,self.num_action)
        return action


    def store_transition(self, state, action, reward, next_state):
        state = state.reshape(-1)
        next_state = next_state.reshape(-1)

        transition = np.hstack((state, [action, reward], next_state))
        self.buffer.store(transition)
        # index = self.memory_counter % self.memory_capacity
        # self.memory[index, :] = transition
        # self.memory_counter += 1


    def update(self):
        #soft update the parameters
        if self.learn_step_counter % self.q_network_iteration ==0 and self.learn_step_counter:
            for p_e, p_t in zip(self.eval_net.parameters(), self.target_net.parameters()):
                p_t.data = self.soft_update_theta * p_e.data + (1 - self.soft_update_theta) * p_t.data
                
        self.learn_step_counter+=1

        #sample batch from memory
        if self.enable_priority:
            batch_memory, (tree_idx, ISWeights) = self.buffer.sample(self.batch_size)
        else:
            batch_memory, _ = self.buffer.sample(self.batch_size)

        batch_state = torch.FloatTensor(batch_memory[:, :self.num_state])
        batch_action = torch.LongTensor(batch_memory[:, self.num_state: self.num_state+1].astype(int))
        batch_reward = torch.FloatTensor(batch_memory[:, self.num_state+1: self.num_state+2])
        batch_next_state = torch.FloatTensor(batch_memory[:,-self.num_state:])

        #q_eval
        q_eval_total = self.eval_net(batch_state)
        q_eval = q_eval_total.gather(1, batch_action)
        q_next = self.target_net(batch_next_state).detach()

        if self.enable_double:
            q_eval_argmax = q_eval_total.max(1)[1].view(self.batch_size, 1)
            q_max = q_next.gather(1, q_eval_argmax).view(self.batch_size, 1)
        else:
            q_max = q_next.max(1)[0].view(self.batch_size, 1)

        q_target = batch_reward + self.gamma * q_max

        if self.enable_priority:
            abs_errors = (q_target - q_eval.data).abs()
            self.buffer.update(tree_idx, abs_errors)
            # loss = (torch.FloatTensor(ISWeights) * (q_target - q_eval).pow(2)).mean()   
            loss = (q_target - q_eval).pow(2).mean() # 可能去掉ISweight更好??

            
            # print(ISWeights)
            # print(loss)

            # import pdb; pdb.set_trace()
        else:
            loss = F.mse_loss(q_eval, q_target)
        

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.eval_net.parameters(), self.clip_norm_max)
        self.optimizer.step()

        return loss

    
    def save(self, path=None, name='dqn_net.pkl'):
        path = self.save_path if not path else path
        utils.check_path_exist(path)
        torch.save(self.eval_net.state_dict(), path + name)

    def load(self, path=None, name='dqn_net.pkl'):
        path = self.save_path if not path else path
        self.eval_net.load_state_dict(torch.load(path + name))


    def epsilon_decay(self, episode, total_episode):
        self.epsilon = self.initial_epsilon * (1 - episode / total_episode)

其中buffer類實現了普通版和priority版,具體見Buffer_module.py,主函數實現main_dqn.py

實驗結果

CNN輸入

在這裏插入圖片描述

全連接輸入

在這裏插入圖片描述在這裏插入圖片描述

CNN input + Priority

在這裏插入圖片描述
在這裏插入圖片描述

總結

實驗最好的結果已經到了平均得分6000分,繼續訓練還能繼續增長,但是太花費時間和計算資源,於是我沒有繼續實驗了。整個過程對解決實際強化學習問題能有更好的認識,並且實現通過實驗能更好的理解一些方法或trick提出的原因。

希望能對大家有幫助。

發佈了4 篇原創文章 · 獲贊 4 · 訪問量 269
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章