numpy.random.permutation() ,使用numpy隨機打散訓練數據,同時保持訓練數據與標籤的對齊

如果訓練數據之間相關性很大,比如配對是按照從1到9開始的順序,則用這樣的訓練數據訓練時很可能導致訓練的泛華能力不足,所以有必要訓練前把訓練數據打亂,同時還要保持打亂前訓練數據和訓練標籤的對應關係。

numpy.random.permutation(length)用來產生一個隨機序列作爲索引,再使用這個序列從原來的數據集中按照新的隨機順序產生隨機數據集。length 爲訓練數據的個數。

import numpy as np
indices = numpy.random.permutation(data_x.shape[0]) # shape[0]表示第0軸的長度,通常是訓練數據的數量
rand_data_x = data_x[indices]
rand_data_y = data_y[indices] # data_y就是標記(label)


 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章