https://leetcode.com/problems/random-pick-index/
Given an array of integers with possible duplicates, randomly output the index of a given target number. You can assume that the given target number must exist in the array.
Note:
The array size can be very large. Solution that uses too much extra space will not pass the judge.
Example:
int[] nums = new int[] {1,2,3,3,3};
Solution solution = new Solution(nums);

// pick(3) should return either index 2, 3, or 4 randomly. Each index should have equal probability of returning.
solution.pick(3);

// pick(1) should return 0. Since in the array only nums[0] is equal to 1.
solution.pick(1);
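Straightforward approach:
Use a hash map to record every index at which each number occurs, then pick one of the target's indices uniformly at random. This stores all N indices, i.e. O(N) extra space, which is what the note above warns about.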
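#include <random>
#include <unordered_map>
#include <vector>
using namespace std;

class Solution {
public:
    default_random_engine E;              // default-seeded random engine
    unordered_map<int, vector<int>> M;    // value -> every index where it occurs

    Solution(vector<int>& nums) {
        for (int i = 0; i < (int)nums.size(); ++i) {
            M[nums[i]].push_back(i);
        }
    }

    int pick(int target) {
        const vector<int>& idx = M[target];   // target is guaranteed to exist
        uniform_int_distribution<int> U(0, (int)idx.size() - 1);
        return idx[U(E)];                     // each index equally likely
    }
};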
/**
* Your Solution object will be instantiated and called as such:
* Solution* obj = new Solution(nums);
* int param_1 = obj->pick(target);
*/
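Better approach (reservoir sampling):
Scan the array and count the occurrences of target seen so far; when the k-th occurrence is reached, make its index the current answer with probability 1/k. If target occurs n times in total, the k-th occurrence is the final answer only if it is chosen at step k and then not replaced at any later step k+1, ..., n, so its probability is: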
$$
\frac{1}{k}\left(1-\frac{1}{k+1}\right)\left(1-\frac{1}{k+2}\right)\cdots\left(1-\frac{1}{n}\right)
= \frac{1}{k}\cdot\frac{k}{k+1}\cdot\frac{k+1}{k+2}\cdots\frac{n-1}{n}
= \frac{1}{n}
$$
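As a quick instance of the formula, take pick(3) on the example array {1,2,3,3,3}: target 3 occurs n = 3 times (indices 2, 3 and 4), and each occurrence k = 1, 2, 3 ends up with the same probability:

$$
\begin{aligned}
P(k=1) &= \tfrac{1}{1}\left(1-\tfrac{1}{2}\right)\left(1-\tfrac{1}{3}\right) = 1\cdot\tfrac{1}{2}\cdot\tfrac{2}{3} = \tfrac{1}{3}\\
P(k=2) &= \tfrac{1}{2}\left(1-\tfrac{1}{3}\right) = \tfrac{1}{2}\cdot\tfrac{2}{3} = \tfrac{1}{3}\\
P(k=3) &= \tfrac{1}{3}
\end{aligned}
$$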
#include <cstdlib>
#include <vector>
using namespace std;

class Solution {
public:
    vector<int> V;
    int len;
    Solution(vector<int>& nums) : V(nums) {
        len = nums.size();
    }
    int pick(int target) {
        int count = 0, result = 0;   // count: occurrences of target seen so far
        for (int i = 0; i < len; ++i) {
            if (V[i] == target) {
                ++count;
                // Take index i with probability 1/count (reservoir of size 1).
                if (rand() % count == 0) result = i;
            }
        }
        return result;
    }
};
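To see the reservoir picker behave as the example promises, here is a minimal standalone sketch. It assumes a seeded std::mt19937 with std::uniform_int_distribution in place of rand(), so runs are reproducible; the names ReservoirPicker and pickHistogram are hypothetical helpers, not part of the LeetCode interface.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical reservoir-sampling picker: same idea as Solution::pick above,
// but using a seeded std::mt19937 instead of rand() (an assumption).
class ReservoirPicker {
public:
    ReservoirPicker(const std::vector<int>& nums, unsigned seed)
        : nums_(nums), engine_(seed) {}

    int pick(int target) {
        int count = 0;    // occurrences of target seen so far
        int result = -1;  // index currently held in the reservoir
        for (std::size_t i = 0; i < nums_.size(); ++i) {
            if (nums_[i] != target) continue;
            ++count;
            // Replace the held index with probability 1/count.
            std::uniform_int_distribution<int> coin(0, count - 1);
            if (coin(engine_) == 0) result = static_cast<int>(i);
        }
        assert(count > 0 && "target must exist in the array");
        return result;
    }

private:
    std::vector<int> nums_;
    std::mt19937 engine_;
};

// Hypothetical helper: call pick(target) `trials` times and count how often
// each index of nums was returned.
std::vector<int> pickHistogram(const std::vector<int>& nums, int target,
                               int trials, unsigned seed) {
    ReservoirPicker picker(nums, seed);
    std::vector<int> hist(nums.size(), 0);
    for (int t = 0; t < trials; ++t) ++hist[picker.pick(target)];
    return hist;
}
```

With nums = {1,2,3,3,3}, pickHistogram(nums, 1, 1000, 42) should put every hit on index 0, and pickHistogram(nums, 3, 30000, 7) should put roughly 10000 hits on each of indices 2, 3 and 4 and none elsewhere. The uniform_int_distribution coin also avoids the (small) modulo bias of rand() % count.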
To check the randomness empirically, I wrote a small Python script:
import random

def get_random():
    # Random integer in [0, 65535]
    return int(random.random() * 65536)

test_cnt = int(input("how many times: "))
num_cnt = int(input("how many numbers: "))
cnt_list = [0] * num_cnt  # cnt_list[j]: how often position j was picked

for _ in range(test_cnt):
    cur_result = 0
    # Reservoir sampling over num_cnt candidates: keep j with probability 1/(j+1)
    for j in range(num_cnt):
        if get_random() % (j + 1) == 0:
            cur_result = j
    cnt_list[cur_result] += 1

stand = 1.0 / num_cnt  # expected frequency of each position
print(stand)
for j in range(num_cnt):
    cur_num = float(cnt_list[j]) / test_cnt
    print(j, ":\t", cur_num, "\t", cur_num - stand)
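Sampling only estimates the frequencies. As a complementary, exact check, the sketch below propagates the probabilities step by step, mirroring the product in the proof. It assumes an ideal 1/k coin, so it does not model the slight modulo bias of get_random() % (i+1) when i+1 does not divide 65536; reservoirDistribution and maxDeviationFromUniform are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical helper: exact final distribution of reservoir sampling over
// n candidates (n >= 1). prob[j] is the probability that candidate j wins.
std::vector<double> reservoirDistribution(int n) {
    assert(n >= 1);
    std::vector<double> prob(n, 0.0);
    for (int k = 1; k <= n; ++k) {
        double keep = 1.0 - 1.0 / k;  // held answer survives step k
        for (int j = 0; j < k - 1; ++j) prob[j] *= keep;
        prob[k - 1] = 1.0 / k;        // k-th candidate taken with prob 1/k
    }
    return prob;
}

// Hypothetical helper: largest |prob[j] - 1/n|; should be ~0 up to
// floating-point rounding if every candidate is equally likely.
double maxDeviationFromUniform(int n) {
    std::vector<double> prob = reservoirDistribution(n);
    double worst = 0.0;
    for (double p : prob) worst = std::max(worst, std::fabs(p - 1.0 / n));
    return worst;
}
```

For example, reservoirDistribution(3) should give {1/3, 1/3, 1/3}, matching the proof and the pick(3) case of the example.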