針對Spark的RDD,API中有一個aggregate函數,本人理解起來費了很大勁,明白之後,mark一下,供以後參考。
首先,Spark文檔中aggregate函數定義如下
def aggregate[U](zeroValue: U)(seqOp: (U, T) ⇒ U, combOp: (U, U) ⇒ U)(implicit arg0: ClassTag[U]): U
Aggregate the elements of each partition, and then the results for all the partitions, using given combine functions and a neutral "zero value". This function can return a different result type, U, than the type of this RDD, T. Thus, we need one operation for merging a T into an U and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are allowed to modify and return their first argument instead of creating a new U to avoid memory allocation. seqOp操作會聚合各分區中的元素,然後combOp操作把所有分區的聚合結果再次聚合,兩個操作的初始值都是zeroValue. seqOp的操作是遍歷分區中的所有元素(T),第一個T跟zeroValue做操作,結果再作爲與第二個T做操作的zeroValue,直到遍歷完整個分區。combOp操作是把各分區聚合的結果,再聚合。aggregate函數返回一個跟RDD不同類型的值。因此,需要一個操作seqOp來把分區中的元素T合併成一個U,另外一個操作combOp把所有U聚合。
- zeroValue
-
the initial value for the accumulated result of each partition for the
seqOp
operator, and also the initial value for the combine results from different partitions for thecombOp
operator - this will typically be the neutral element (e.g.Nil
for list concatenation or0
for summation) - seqOp
-
an operator used to accumulate results within a partition
- combOp
-
an associative operator used to combine results from different partitions
C:\Windows\System32>scala
Welcome to Scala 2.11.8 (Java HotSpot(TM) Client VM, Java 1.8.0_91).
Type in expressions for evaluation. Or try :help.
scala> val rdd = List(1,2,3,4,5,6,7,8,9)
rdd: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9)
scala> rdd.par.aggregate((0,0))(
(acc,number) => (acc._1 + number, acc._2 + 1),
(par1,par2) => (par1._1 + par2._1, par1._2 + par2._2)
)
res0: (Int, Int) = (45,9)
scala> res0._1 / res0._2
res1: Int = 5
過程大概這樣:
首先,初始值是(0,0),這個值在後面2步會用到。
然後,(acc,number) => (acc._1 + number, acc._2 + 1),number即是函數定義中的T,這裏即是List中的元素。所以acc._1 + number, acc._2 + 1的過程如下。
1. 0+1, 0+1
2. 1+2, 1+1
3. 3+3, 2+1
4. 6+4, 3+1
5. 10+5, 4+1
6. 15+6, 5+1
7. 21+7, 6+1
8. 28+8, 7+1
9. 36+9, 8+1
結果即是(45,9)。這裏演示的是單線程計算過程,實際Spark執行中是分佈式計算,可能會把List分成多個分區,假如3個,p1(1,2,3,4),p2(5,6,7,8),p3(9),經過計算各分區的的結果(10,4),(26,4),(9,1),這樣,執行(par1,par2) => (par1._1 + par2._1, par1._2 + par2._2)就是(10+26+9,4+4+1)即(45,9).再求平均值就簡單了。