隨機取樣的實現

閱讀《編程珠璣》取樣問題,有感,遂Java實現。

需求

程序的輸入包含兩個整數m和n,其中 m <n 。輸出是 0~n-1 範圍內 m 個隨機整數的有序列表,不允許重複。從概率的角度說,我們希望得到沒有重複的有序選擇,其中每個選擇出現的概率相等。

簡單來說,就是從n個樣本中隨機抽取m個。

思路

隨機取樣,大致有兩種思路。僞代碼如下:

// 思路一
while(已抽取樣本數 < 需要抽取的樣本數){
      隨機抽取樣本a
      if(a不在已抽取樣本中){
            將a加入已抽取樣本
            已抽取樣本數++
      }
}
 
// 思路二
將所有樣本順序打亂
按順序取走需要的樣本數

思路一通過循環隨機直至樣本數滿足條件,思路二通過打亂樣本順序的方式取樣。

源碼

用Java代碼實現後,自測在各種情況下,思路一性能都好於思路二。下面是源碼。

經優化後的思路一(性能非常好,所以分享,哈哈~)。
主要優化點:

  • 利用數組的快速定位來校驗某個樣本是否已被抽取;
  • 如果取樣數大於總樣本數的一半,那就隨機抽取其補集(另一小半)。
    /**
     * 隨機取樣
     *
     * @param bound 樣本總數
     * @param count 需要抽取的樣本數
     * @return 返回一個有序數組
     */
    private static int[] getRandomSamples(int bound, int count) {
        if (bound < 1 || count < 1 || bound <= count)
            return null;

        boolean[] fillArray = new boolean[bound];
        for (int i = 0; i < bound; i++) {
            fillArray[i] = false; //用false標示未填充,true表示已填充。
        }

        Random random = new Random();
        int fillCount = 0;
        final int randomNumCount = Math.min(count, bound - count); //隨機填充的數目不超過一半
        while (fillCount < randomNumCount) {
            int num = random.nextInt(bound);
            if (!fillArray[num]) {
                fillArray[num] = true;
                fillCount++;
            }
        }

        int[] samples = new int[count];
        //如果隨機抽取的數量與所需相等,則取該集合;否則取補集。
        if (randomNumCount == count) {
            int index = 0;
            for (int i = 0; i < bound; i++) {
                if (fillArray[i])
                    samples[index++] = i;
            }
        } else {
            //取補集
            int index = 0;
            for (int i = 0; i < bound; i++) {
                if (!fillArray[i])
                    samples[index++] = i;
            }
        }
        return samples;
    }

思路二,調用java默認的洗牌方法來實現,性能不如思路一的實現(常見數據量下耗時大概是上面代碼的2~10倍;對於極大範圍取樣,比如1億樣本里隨機抽取500萬,耗時是上面代碼的100倍)。

    /**
     * 通過洗牌的方式隨機取樣
     */
    private static int[] getRandomSamples2(int bound, int count) {
        if (bound < 1 || count < 1 || bound <= count)
            return null;
        List<Integer> list = new ArrayList<>(bound);
        for (int i = 0; i < bound; i++) {
            list.add(i);
        }
        Collections.shuffle(list);
        int[] samples = new int[count];
        for (int i = 0; i < count; i++) {
            samples[i] = list.get(i);
        }
        return samples;
    }

Gist 隨機取樣Java源碼

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章