問題描述:
本節要求以最壞情況下O(n)的時間複雜度找到長度爲n的數組中第 i 大的數。
解決方案:
《算法導論》上提供了一個算法,該算法實質上是利用了快排中劃分的思想,但其通過一些比較複雜的預處理工作保證了快排劃分的均勻,
並且能夠從理論上證明其最壞情況下的時間複雜度可以達到O(n)。
算法步驟:
1、如圖所示,將n個數分成5個一組,共有⌊n/5⌋組。
2、對⌈n/5⌉組(包括可能不到5個數的那組)的組內數據進行直接插入排序,排序完成之後,圖中白色的數據即爲組內數據的中位數。
並將這⌈n/5⌉箇中位數挑出來,具體做法見後面的代碼。
3、遞歸調用本算法找到這⌈n/5⌉箇中位數的中位數,假設它爲圖中的x。
4、假定n個數是互不相同的(後面會討論一般情況),以x爲樞軸對n個數進行一趟劃分,
使得x左邊的數<x,x右邊的數>x。假設劃分完成之後,x在數組中從左到右排在第k個。
5、如果 i == k,那麼x就是我們要找的數,return x; 即可;
如果 i < k,那麼遞歸調用本算法在x左邊的數中繼續找第 i 個數;
如果 i > k,那麼遞歸調用本算法在x右邊的數中繼續找第 i - k 個數;
爲什麼這個算法是O(n)的?《算法導論》上給予了證明。
證明步驟:
①第1、2、4步都是O(n)的。
其中第2步O(n)是因爲,對n/5組進行組內直接插入排序的時間複雜度是(n/5) * (5^2) = 5n,也就是O(n)。
②根據第1、2、3步我們可以知道,由於n個數是互不相同的,且x是中值的中值,
則圖中陰影部分的數據肯定都比x大,具體來說,比x大的數據至少有:
其中,⌈n/5⌉是組的個數(包括可能不到5個數的那組),-2是去掉x所在的組以及有可能不到5個數的那組。
同理,在x左上角的那部分數據也肯定比x小,並且也至少有這麼多。
所以無論在第5步中是哪種情況,到下一次遞歸時最多有7n/10+6個數。
③假設本算法的最壞時間複雜度是T(n)的,那麼第3步的時間複雜度是T(⌈n/5⌉),第5步的時間複雜度是T(7n/10+6)。
假設對於n<140的情況,找第 i 個數是O(1)的。(後面會解釋爲什麼有這樣的假設,以及爲什麼是140而不是410)
④現在我們只要證明,對於任意的n>0,都成找到一個常數 c 使得T(n) ≤ cn,那麼,這個算法就是O(n)的。
假設式中O(n)項的常數因子爲a,則有:
T(n) ≤ c ⌈n/5⌉ + c(7n/10 + 6) + an
≤ cn/5 + c + 7cn/10 + 6c + an
= 9cn/10 + 7c + an
= cn + (-cn/10 + 7c + an)
如果(-cn/10 + 7c + an) ≤ 0,那麼T(n) ≤ cn。
(-cn/10 + 7c + an) ≤ 0 ⇒ c ≥ 10a(n/(n - 70))
當n≥140時,n/(n - 70) ≤ 2,只要取c = 20a即可使T(n) ≤ cn,亦即本算法是O(n)的。
下面討論如果n個數中有重複數的情況
由於在算法步驟中的第四步以及證明步驟中的第二步都假定個n個數互不相同,這樣才能保證能有3n/10-6個數一定比x小,
同時有3n/10-6個數一定比x大,這樣無論i和k(x在數組中從左到右排在第k個)之間的大小關係是什麼,總能下進入下一次
遞歸時排除掉3n/10-6個數。
而事實上,如果n個數中有重複元素與x相同,比如100個數(中間省略的都是2):
2 2 1 2 2 0 2 …… 2 2 1 2 3
這樣≤2的數在前99個,大於2的數(只有3)在第100個,如果我們要找第5個數,那麼按照前面的算法進入下次遞歸的是前99個數,
這顯然不能滿足“至少除掉3n/10-6個數”的假設。爲了實現在這種情況下仍然是O(n),需要DIY一下劃分算法,具體來說:
將n個數劃分成三部分,第一部分<x,第二部分=x,第三部分>x。按上面的例子劃分結果是(中間省略的都是2):
1 0 1 2 2 2 2 …… 2 2 2 2 3
我們知道從左至右第一個出現的x在第4個,最後出現的x在第99個,那麼:
如果 i >= 4 && i <= 99,那麼x就是我們要找的數,return x; 即可;
如果 i < 4,那麼遞歸調用本算法在x左邊的數中繼續找第 i 個數;
如果 i > 99,那麼遞歸調用本算法在x右邊的數中繼續找第 i - 99 個數;
這樣,就能夠保證進入每次“至少除掉3n/10-6個數”了。
實現代碼:
1 int partitionSpecifyPivot(int a[], int beg, int end, int pivotloc, int *pivotNum) 2 { 3 int pivot = a[pivotloc]; 4 int i = beg - 1; 5 int j = beg - 1; 6 7 for (int k = beg; k <= end; ++k) 8 { 9 if (a[k] <= pivot) 10 { 11 swap(a, ++j, k); 12 } 13 } 14 for (int k = beg; k <= j; ++k) 15 { 16 if (a[k] < pivot) 17 { 18 swap(a, ++i, k); 19 } 20 } 21 //pivotNum is the number of the elements that equal to the pivot 22 if (pivotNum != NULL) 23 { 24 *pivotNum = j - i; 25 } 26 return j; 27 }
1 int ithSmallestLinear(int a[], int beg, int end, int i) 2 { 3 int len = end - beg + 1; 4 5 if (len < 140) 6 { 7 insertionSort(a, beg, end); 8 return beg + i - 1; 9 } 10 //divide the n elements into ⌊n/5⌋ groups and sort each group 11 for (int j = 0; j != len / 5; ++j) 12 { 13 int b = beg + j * 5; 14 int e = b + 4; 15 insertionSort(a, b, e); 16 //move the median of each group to the front of the array 17 swap(a, beg + j, b + 2); 18 } 19 //find the median of each median 20 int pivotLoc = ithSmallestLinear(a, beg, len / 5 - 1, (len / 5 + 1) / 2); 21 //the number of the elements that equal to the pivot 22 int pivotNum = 0; 23 int pivotEndIndex = partitionSpecifyPivot(a, beg, end, pivotLoc, &pivotNum); 24 int n = pivotEndIndex - beg + 1; 25 int m = n - pivotNum + 1; 26 27 if (i >= m && i <= n) 28 { 29 //return the index of the ith smallest element 30 return pivotEndIndex; 31 } 32 else if (i < m) 33 { 34 return ithSmallestLinear(a, beg, pivotEndIndex - pivotNum, i); 35 } 36 else 37 { 38 return ithSmallestLinear(a, pivotEndIndex + 1, end, i - n); 39 } 40 }
測試:
首先測試算法的正確性,代碼如下:
1 #define ARRAY_SIZE 500 2 #define COUNT 10 3 4 int a[ARRAY_SIZE]; 5 int b[ARRAY_SIZE]; 6 7 int main(void) 8 { 9 for (int z = 0; z != COUNT; ++z) 10 { 11 result.open("result.txt"); 12 randArray(a, ARRAY_SIZE, 1, 9999); 13 copyArray(a, 0, b, 0, ARRAY_SIZE); 14 quickSort(a, 0, ARRAY_SIZE - 1); 15 16 for (int i = 1; i <= ARRAY_SIZE; ++i) 17 { 18 int resultStd = a[i - 1]; 19 int resultTest = b[ithSmallestLinear(b, 0, ARRAY_SIZE - 1, i)]; 20 21 // std::cout << "i = " << i << " resultTest = " << resultTest 22 // << " resultStd = " << resultStd << std::endl; 23 if (resultTest != resultStd) 24 { 25 std::cout << "Error" << std::endl; 26 return - 1; 27 } 28 } 29 std::cout << "test " << z << " done." << std::endl; 30 } 31 32 return 0; 33 }
關於該算法最壞情況下能保證O(n)的時間複雜度,只需測試其在數組元素隨機、有序、相同這三種情況下的時間,代碼如下:
(1000000數量級的測試,根據機器實際情況可以把這個調小點)
1 #define ARRAY_SIZE 1000000 2 #define COUNT 10 3 4 int a[ARRAY_SIZE]; 5 int b[ARRAY_SIZE]; 6 int c[ARRAY_SIZE]; 7 8 int main(void) 9 { 10 for (int z = 0; z != COUNT; ++z) 11 { 12 randArray(a, ARRAY_SIZE, 1, ARRAY_SIZE * 2); 13 randArray(c, ARRAY_SIZE, 1, 1); 14 copyArray(a, 0, b, 0, ARRAY_SIZE); 15 quickSort(b, 0, ARRAY_SIZE - 1); 16 17 for (int i = 1; i <= ARRAY_SIZE; ++i) 18 { 19 clock_t start = clock(); 20 ithSmallestLinear(a, 0, ARRAY_SIZE - 1, i); 21 std::cout << clock() - start << "ms "; 22 start = clock(); 23 ithSmallestLinear(a, 0, ARRAY_SIZE - 1, i); 24 std::cout << clock() - start << "ms "; 25 start = clock(); 26 ithSmallestLinear(a, 0, ARRAY_SIZE - 1, i); 27 std::cout << clock() - start << "ms" << std::endl; 28 } 29 std::cout << "test " << z << " done" << std::endl; 30 } 31 return 0; 32 }
測試結果:
可以看到在三種情況下的時間是一個數量級的。
原文地址:http://www.cnblogs.com/goagent/archive/2012/11/01/2747222.html