背景
我已經做過一些強化學習相關項目,本科的時候也用min-max搜索做過2048,一直覺得2048應該是適合被強化學習解決的,但是查詢發現並沒有比較合適靠譜的實現代碼,於是完成並開源了我的一部分實現工作,供RL learner 參考,github鏈接 https://github.com/YangRui2015/2048_env。
工作
- 修改了https://github.com/rgal/gym-2048的gym封裝的2048環境,增加最大步數和最大非法步數限制,能降低訓練難度,增加info輸出;
- DQN算法實現,訓練和測試,模型保存和加載(使用pytorch);
- logger日誌代碼實現,包括控制檯、txt文件、tensorboard等數據格式的日誌;
DQN實現功能或trick有:
- CNN input or flattened input
- randomly fill buffer first
- soft target replacing
- linear epsilon decay
- clip gradient norm
- Double DQN
- priority experience replay
分析問題
一個典型的深度強化學習問題,主要有以下幾個基本點:
- 狀態、動作、獎勵的設計與表徵;
- 強化學習算法的實現和參數選擇;
- 神經網絡的設計和調參;
在以上幾點完成的基礎上,算法的提升主要有三個方面:
- 提高訓練的穩定性;
- 提高訓練的速度;
- 提高算法的performance;
狀態表徵
環境的狀態輸出是4*4的矩陣,針對這種狀態我們通常選擇flatten成一維向量做全連接輸入或者使用CNN輸入。CNN雖然更有利於提取一些空間的特徵,但是flatten後只要網絡擬合能力足夠也是能夠學習到這些非線性特徵的。實驗表明兩者效果接近,但是flatten最開始階段學習會比CNN快一些,符合我們的預期。
此外,由於狀態矩陣中的值以及獎勵值從2到1024甚至更大,直接輸入網絡很容易導致數值爆炸,需要對輸入狀態和獎勵值做預處理,這裏簡單地使用了 log(x+1)/16 實現歸一化。
強化學習算法
使用DQN算法,我也嘗試過A2C、PPO,但是訓練效果都不好,猜想隨機策略在這個問題上表現不如確定性策略。
在實現基本DQN算法上,我還實現了DDQN、Priority DQN(參考莫凡強化學習代碼),以及加上了一些提高訓練穩定性的方法——soft replacing、epsilon decay、clip gradient norm。
參數設計
DQN參數如下:
batch_size = 128
lr = 1e-4
epsilon = 0.15
memory_capacity = int(1e4)
gamma = 0.99
q_network_iteration = 200
soft_update_theta = 0.1
clip_norm_max = 1
train_interval = 5
conv_size = (32, 64) # num filters
fc_size = (512, 128)
代碼實現
DQN需要的網絡模塊如下,主要區別是CNN輸入還是flatten輸入(具體見NN_module.py)。
# CNN網絡
class CNN_Net(nn.Module):
    """Convolutional Q-network for a flattened square board.

    The input vector is reshaped into a 1-channel ``input_len x input_len``
    image, passed through two 3x3 conv+ReLU stages (spatial size preserved
    by padding=1), flattened, and mapped through two fully-connected layers
    to one score per action.
    """

    def __init__(self, input_len, output_num, conv_size=(32, 64), fc_size=(1024, 128), out_softmax=False):
        super(CNN_Net, self).__init__()
        self.input_len = input_len
        self.output_num = output_num
        self.out_softmax = out_softmax
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, conv_size[0], kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(conv_size[0], conv_size[1], kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )
        # padding=1 keeps H=W=input_len, so the flattened feature count is known.
        flat_features = conv_size[1] * self.input_len * self.input_len
        self.fc1 = nn.Linear(flat_features, fc_size[0])
        self.fc2 = nn.Linear(fc_size[0], fc_size[1])
        self.head = nn.Linear(fc_size[1], self.output_num)

    def forward(self, x):
        """Return per-action scores of shape (batch, output_num)."""
        image = x.reshape(-1, 1, self.input_len, self.input_len)
        features = self.conv2(self.conv1(image))
        flat = features.view(features.size(0), -1)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        scores = self.head(hidden)
        if self.out_softmax:
            # A value-function estimate should not be softmaxed; this flag
            # exists only for policy-style (probability) outputs.
            scores = F.softmax(scores, dim=1)
        return scores
# 全連接網絡
class FC_Net(nn.Module):
    """Plain MLP Q-network: flattened state -> two hidden layers -> action scores."""

    def __init__(self, input_num, output_num, fc_size=(1024, 128), out_softmax=False):
        super(FC_Net, self).__init__()
        self.input_num = input_num
        self.output_num = output_num
        self.out_softmax = out_softmax
        self.fc1 = nn.Linear(self.input_num, fc_size[0])
        self.fc2 = nn.Linear(fc_size[0], fc_size[1])
        self.head = nn.Linear(fc_size[1], self.output_num)

    def forward(self, x):
        """Return per-action scores of shape (batch, output_num)."""
        flat = x.reshape(-1, self.input_num)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        out = self.head(hidden)
        # A value-function estimate should not be softmaxed; the flag exists
        # only for policy-style (probability) outputs.
        return F.softmax(out, dim=1) if self.out_softmax else out
DQN代碼實現如下(具體見DQN_agent.py):
class DQN():
    """Deep Q-Network agent for the 2048 environment.

    Supports optional Double-DQN action selection (``enable_double``) and
    prioritized experience replay (``enable_priority``), with periodic soft
    target-network updates, gradient-norm clipping, and linear epsilon decay.
    """
    # Hyperparameters (class attributes, shared by all instances).
    batch_size = 128
    lr = 1e-4
    epsilon = 0.15                  # initial exploration probability
    memory_capacity = int(1e4)      # replay-buffer size
    gamma = 0.99                    # discount factor
    q_network_iteration = 200       # learn steps between target-net soft updates
    save_path = "./save/"
    soft_update_theta = 0.1         # mixing coefficient for soft target updates
    clip_norm_max = 1               # max gradient norm for clipping
    train_interval = 5
    conv_size = (32, 64)  # num filters
    fc_size = (512, 128)

    def __init__(self, num_state, num_action, enable_double=False, enable_priority=True):
        super(DQN, self).__init__()
        self.num_state = num_state
        self.num_action = num_action
        # Board is square: side length recovered from the flattened state size.
        self.state_len = int(np.sqrt(self.num_state))
        self.enable_double = enable_double
        self.enable_priority = enable_priority
        # Online (eval) and target networks share architecture but not weights.
        self.eval_net, self.target_net = CNN_Net(self.state_len, num_action,self.conv_size, self.fc_size), CNN_Net(self.state_len, num_action, self.conv_size, self.fc_size)
        # self.eval_net, self.target_net = FC_Net(self.num_state, self.num_action), FC_Net(self.num_state, self.num_action)
        self.learn_step_counter = 0
        # NOTE(review): the buffer is always created in 'priority' mode even
        # when enable_priority is False — confirm Buffer handles that case.
        self.buffer = Buffer(self.num_state, 'priority', self.memory_capacity)
        # self.memory = np.zeros((self.memory_capacity, num_state * 2 + 2))
        self.initial_epsilon = self.epsilon  # kept for linear epsilon decay
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=self.lr)

    def select_action(self, state, random=False, deterministic=False):
        """Epsilon-greedy action selection.

        ``random=True`` forces a random action (unless ``deterministic``);
        ``deterministic=True`` forces the greedy action.
        """
        state = torch.unsqueeze(torch.FloatTensor(state), 0)
        # Precedence: (not random and rand > epsilon) or deterministic.
        if not random and np.random.random() > self.epsilon or deterministic: # greedy policy
            action_value = self.eval_net.forward(state)
            # NOTE(review): greedy branch returns a numpy array while the
            # random branch returns a Python int — callers must accept both.
            action = torch.max(action_value.reshape(-1,4), 1)[1].data.numpy()
        else: # random policy
            action = np.random.randint(0,self.num_action)
        return action

    def store_transition(self, state, action, reward, next_state):
        """Flatten (s, a, r, s') into one row and push it into the buffer."""
        state = state.reshape(-1)
        next_state = next_state.reshape(-1)
        transition = np.hstack((state, [action, reward], next_state))
        self.buffer.store(transition)
        # index = self.memory_counter % self.memory_capacity
        # self.memory[index, :] = transition
        # self.memory_counter += 1

    def update(self):
        """Sample a batch, compute the TD loss, and take one optimizer step.

        Returns the scalar loss tensor.
        """
        # Soft-update the target network every q_network_iteration steps
        # (skipping step 0): target <- theta * eval + (1 - theta) * target.
        if self.learn_step_counter % self.q_network_iteration ==0 and self.learn_step_counter:
            for p_e, p_t in zip(self.eval_net.parameters(), self.target_net.parameters()):
                p_t.data = self.soft_update_theta * p_e.data + (1 - self.soft_update_theta) * p_t.data
        self.learn_step_counter+=1

        # Sample a batch from replay memory.
        if self.enable_priority:
            batch_memory, (tree_idx, ISWeights) = self.buffer.sample(self.batch_size)
        else:
            batch_memory, _ = self.buffer.sample(self.batch_size)
        # Rows are laid out as [state | action | reward | next_state].
        batch_state = torch.FloatTensor(batch_memory[:, :self.num_state])
        batch_action = torch.LongTensor(batch_memory[:, self.num_state: self.num_state+1].astype(int))
        batch_reward = torch.FloatTensor(batch_memory[:, self.num_state+1: self.num_state+2])
        batch_next_state = torch.FloatTensor(batch_memory[:,-self.num_state:])

        # Q(s, a) under the online network for the taken actions.
        q_eval_total = self.eval_net(batch_state)
        q_eval = q_eval_total.gather(1, batch_action)
        q_next = self.target_net(batch_next_state).detach()
        if self.enable_double:
            # Double DQN: online net picks argmax action, target net evaluates it.
            # NOTE(review): the argmax is taken over Q(s), not Q(s') — confirm
            # this is intentional; standard Double DQN uses eval_net(s').
            q_eval_argmax = q_eval_total.max(1)[1].view(self.batch_size, 1)
            q_max = q_next.gather(1, q_eval_argmax).view(self.batch_size, 1)
        else:
            q_max = q_next.max(1)[0].view(self.batch_size, 1)
        # No terminal masking: the buffer rows carry no done flag.
        q_target = batch_reward + self.gamma * q_max

        if self.enable_priority:
            # Feed absolute TD errors back to the priority tree.
            abs_errors = (q_target - q_eval.data).abs()
            self.buffer.update(tree_idx, abs_errors)
            # loss = (torch.FloatTensor(ISWeights) * (q_target - q_eval).pow(2)).mean()
            # NOTE(review): the importance-sampling weights are deliberately
            # dropped here (biased gradient); author suspects it may work better??
            loss = (q_target - q_eval).pow(2).mean()
            # print(ISWeights)
            # print(loss)
            # import pdb; pdb.set_trace()
        else:
            loss = F.mse_loss(q_eval, q_target)

        self.optimizer.zero_grad()
        loss.backward()
        # Clip gradient norm for training stability.
        nn.utils.clip_grad_norm_(self.eval_net.parameters(), self.clip_norm_max)
        self.optimizer.step()
        return loss

    def save(self, path=None, name='dqn_net.pkl'):
        """Save the online network's weights to path + name."""
        path = self.save_path if not path else path
        utils.check_path_exist(path)
        torch.save(self.eval_net.state_dict(), path + name)

    def load(self, path=None, name='dqn_net.pkl'):
        """Load weights into the online network from path + name."""
        path = self.save_path if not path else path
        self.eval_net.load_state_dict(torch.load(path + name))

    def epsilon_decay(self, episode, total_episode):
        """Linearly anneal epsilon from its initial value to 0 over training."""
        self.epsilon = self.initial_epsilon * (1 - episode / total_episode)
其中buffer類實現了普通版和priority版,具體見Buffer_module.py,主函數實現main_dqn.py
實驗結果
CNN輸入
全連接輸入
CNN input + Priority
總結
實驗最好的結果已經到了平均得分6000分,繼續訓練還能繼續增長,但是太花費時間和計算資源,於是我沒有繼續實驗了。整個過程讓我對解決實際強化學習問題有了更好的認識,並且通過親自實現和實驗,能更好地理解一些方法或trick被提出的原因。
希望能對大家有幫助。