洗牌程序

前兩天微博上有人討論洗牌程序,沒細看內容但感覺似乎有點意思,今天自己嘗試一下。

所謂洗牌程序就是把一個序列的元素位置打亂,這在 Python 裏有一個標準函數:random.shuffle。在開始動手之前我們先簡化描述一下需求:

洗牌後的每個元素隨機出現在每個位置,且概率相同

從這個結論也可以推導出一點:

元素洗牌後的位置與洗牌前無關

第一點是第二點的充分條件,因此測試函數只需要測第一點就夠了。測試方法爲重複執行洗牌函數,統計每個位置每個元素出現的次數。這樣測試函數會輸出一個矩陣,如果我們不考慮針對性搗亂情況的話,測試函數可以簡化爲計算每個位置的值的和,這個和的方差能夠近似體現洗牌函數的正確性。

def test(f, n=10000):
    dataset = [list(range(10)) for i in range(n)]
    for row in dataset:
        f(row)
    stat = []
    for i in range(10):
        stat.append(sum([row[i] for row in dataset]))
    return stat

先測試一下標準函數的結果:

In [72]: test(random.shuffle)
Out[72]: [44909, 44468, 45506, 45184, 45086, 44883, 44574, 45145, 44656, 45589]

然後我們分析一下這個問題:

前提:在本文中不考慮僞隨機數的問題,並且除了 random.random 外不能使用其他函數

每個元素分配一個隨機位置的話,等於說每個元素的分配過程應該是互相獨立的,與現在位置無關的。因此最簡單的方法就是爲每個元素分配一個隨機數,然後按隨機數的值進行排序這樣他們的位置就該是完全隨機的。

def myshuffle(array):
    sort_map = {x: random.random() for x in array}
    new_array = sorted(array, key=lambda x: sort_map[x])
    for i in range(len(array)):
        array[i] = new_array[i]

測試結果:

In [88]: test(myshuffle)
Out[88]: [45243, 44605, 45217, 45030, 45129, 44362, 45246, 45033, 45037, 45098]

據說洗牌程序業界有一個標準的 fisher_yates 算法,翻譯成 Python 是這樣的:

def fisher_yates(array):
    for i in reversed(range(1, len(array))):
        j = int(random() * (i + 1))  # j = random.choice(list(range(i + 1)))
        array[i], array[j] = array[j], array[i]

即從待處理序列中隨機抽出一個元素放到隊尾,然後將待處理序列的尾部邊界向前挪一位(如果 j 的生成不好理解就看註釋的那行等價代碼)。因爲最後剩一個元素的時候沒必要再抽,所以這個算法比上面的 myshuffle 少進行一次 random 運算,而且因爲是直接調換位置,空間消耗小得多,還少了一次排序計算。

測試結果:

In [148]: test(fisher_yates)
Out[148]: [44996, 45089, 45322, 44926, 44888, 45023, 44896, 45407, 44873, 44580]

這個算法也是 Python 內建的 random.shuffle 使用的方法。

那麼最後一個問題,是否還存在其他的算法,使得調用 random.random 的次數更少呢?即,是否存在算法,使序列長度爲 N(N≧3) 時調用隨機數生成函數的次數 k ≦ N-2?

回想隨機數的用法的話,會發現我們通常都是用它來生成一個樣本空間的隨機下標的。即使簡單如 N=2 的情況:

x = [1, 0]

def shuffle_2(x):
	assert len(X) == 2
	if random.random() < 0.5:
		x.reverse()

我們做這個小於 0.5 的判斷,其本質也是在大小爲 2 的空間裏選擇下標。如果要使用一個隨機數洗牌 3 個以上(含)的元素的話,我們就需要構建一個空間,空間裏的每一個元素都含有一種唯一的元素排列形式。即該空間爲序列的全排列。

def wired_shuffle(array):
    all_permutations = some_function(array)
    rand_index = int(len(all_permutations) * random.random())
    new_array = all_permutations[rand_index]
    for i in range(len(array)):
        array[i] = new_array[i]

好,現在問題轉化了。通過這種方式我們節省了隨機函數的調用時間,卻不得不生成一次序列的全排列。這帶來了兩個問題:

  1. 是否存在一種方法可以直接從下標計算出某種排列,而非全部生成一遍
  2. 隨機函數返回的浮點數是有限的,那麼這個算法能處理的序列長度也就是有限的。

其實在現有的隨機函數實現下,問題2 就已經判了這個算法死刑了。但爲了好玩,我們繼續思考一下問題1。上面圖省事沒有去實現這個 some_function,現在不得不先實現一下看看邏輯:

def some_function(array):
    if not array:
        return [[]]
    all_permutations = []
    for i, key in enumerate(array):
        remaining = array[:i] + array[i + 1:]
        all_permutations.extend([[key] + _array for _array in some_function(remaining)])
    return all_permutations

(因爲是洗牌程序用的,所以無需去重,默認每個元素都不一樣就可以了。)

那麼問題1的答案可以是:

def the_permutation(array, index):
    _array = copy.deepcopy(array)
    for i in range(len(_array)):
        left_permutation_count = math.factorial(len(_array) - i - 1)
        j = (index // left_permutation_count)  # 計算這一位的係數
        _array.insert(i, _array.pop(j + i))
        index -= (j * left_permutation_count)
    return _array

好了,現在我們有了一個新的 wired_shuffle:

def new_wired_shuffle(array):
    index = int(math.factorial(len(array)) * random.random())
    new_array = the_permutation(array, index)
    for i in range(len(array)):
        array[i] = new_array[i]

整合簡化一下:

def new_wired_shuffle(array):
    index = int(math.factorial(len(array)) * random.random())
    for i in range(len(array)):
        left_permutation_count = math.factorial(len(array) - i - 1)
        j = (index // left_permutation_count)
        array.insert(i, array.pop(j + i))
        index -= (j * left_permutation_count)

測試一下:

In [171]: test(new_wired_shuffle)
Out[171]: [44967, 45259, 44984, 45141, 44940, 44820, 45165, 44865, 44854, 45005]

這次我們調用 random 的次數縮減到了 1 次,卻增加了 N 次 math.factorial 調用。階乘函數的速度會隨着 N 變大而越來越慢,且 insert/pop 也比直接賦值要慢。所以,這個函數的性能到底怎麼樣呢?

我們拿他和 random.shuffle 對比一下(通過 timeit):

In [200]: for N in (2, 4, 6, 8, 12, 15, 25, 50):
     ...:     print('%d' % N)
     ...:     compare([random.shuffle, new_wired_shuffle], N)
     ...:     
2
             shuffle: 0.020176645999526954
   new_wired_shuffle: 0.032341828000426176
4
             shuffle: 0.03858100000070408
   new_wired_shuffle: 0.04256904300018505
6
             shuffle: 0.056119916998795816
   new_wired_shuffle: 0.047379309000461944
8
             shuffle: 0.07575537699995039
   new_wired_shuffle: 0.06619280999984767
12
             shuffle: 0.09467284899983497
   new_wired_shuffle: 0.09349796399874322
15
             shuffle: 0.11433432200101379
   new_wired_shuffle: 0.11504831499951251
25
             shuffle: 0.1927355459993123
   new_wired_shuffle: 0.21537912900021183
50
             shuffle: 0.35780333500042616
   new_wired_shuffle: 0.535689784999704

N 在 [6, 12] 區間內小勝!耶~

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章