1、概述
Fork/Join Pool採用優良的設計、代碼實現和硬件原子操作機制等多種思路保證其執行性能。其中包括(但不限於):計算資源共享、高性能隊列、避免僞共享、工作竊取機制等。本文(以及後續文章)試圖和讀者一起分析JDK1.8中Fork/Join Pool的源代碼實現,去理解Fork/Join Pool是怎樣工作的。當然這裏要說明一下,起初本人在決定閱讀Fork/Join歸併計算相關類的源代碼時(ForkJoinPool、WorkQueue、ForkJoinTask、RecursiveTask、ForkJoinWorkerThread等),並不覺得這部分代碼比起LinkedList這樣的類來說有多少難度, 但其中大量使用位運算和位運算技巧,有大量Unsafe原子操作。博主能力有限,確實不能在短時間內將所有代碼一一詳細解讀,所以也希望各位讀者能幫助筆者一同完善。
2. 原理
基本思想
ForkJoinPool
的每個工作線程都維護着一個工作隊列(WorkQueue
),這是一個雙端隊列(Deque),裏面存放的對象是任務(ForkJoinTask
)。- 每個工作線程在運行中產生新的任務(通常是因爲調用了
fork()
)時,會放入工作隊列的隊尾,並且工作線程在處理自己的工作隊列時,使用的是 LIFO 方式,也就是說每次從隊尾取出任務來執行。 - 每個工作線程在處理自己的工作隊列同時,會嘗試竊取一個任務(或是來自於剛剛提交到 pool 的任務,或是來自於其他工作線程的工作隊列),竊取的任務位於其他線程的工作隊列的隊首,也就是說工作線程在竊取其他工作線程的任務時,使用的是 FIFO 方式。
- 在遇到
join()
時,如果需要 join 的任務尚未完成,則會先處理其他任務,並等待其完成。 - 在既沒有自己的任務,也沒有可以竊取的任務時,進入休眠。
fork
fork()
做的工作只有一件事,既是把任務推入當前工作線程的工作隊列裏。可以參看以下的源代碼:
public final ForkJoinTask<V> fork() {
Thread t;
if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
((ForkJoinWorkerThread)t).workQueue.push(this);
else
ForkJoinPool.common.externalPush(this);
return this;
}
join
join()
的工作則複雜得多,也是 join()
可以使得線程免於被阻塞的原因——不像同名的 Thread.join()
。
- 檢查調用
join()
的線程是否是 ForkJoinThread 線程。如果不是(例如 main 線程),則阻塞當前線程,等待任務完成。如果是,則不阻塞。 - 查看任務的完成狀態,如果已經完成,直接返回結果。
- 如果任務尚未完成,但處於自己的工作隊列內,則完成它。
- 如果任務已經被其他的工作線程偷走,則竊取這個小偷的工作隊列內的任務(以 FIFO 方式),執行,以期幫助它早日完成欲 join 的任務。
- 如果偷走任務的小偷也已經把自己的任務全部做完,正在等待需要 join 的任務時,則找到小偷的小偷,幫助它完成它的任務。
- 遞歸地執行第5步。
將上述流程畫成序列圖的話就是這個樣子:
以上就是 fork()
和 join()
的原理,這可以解釋 ForkJoinPool 在遞歸過程中的執行邏輯,但還有一個問題
最初的任務是 push 到哪個線程的工作隊列裏的?
這就涉及到 submit()
函數的實現方法了
submit
其實除了前面介紹過的每個工作線程自己擁有的工作隊列以外,ForkJoinPool
自身也擁有工作隊列,這些工作隊列的作用是用來接收由外部線程(非 ForkJoinThread
線程)提交過來的任務,而這些工作隊列被稱爲 submitting queue 。
submit()
和 fork()
其實沒有本質區別,只是提交對象變成了 submitting queue 而已(還有一些同步,初始化的操作)。submitting queue 和其他 work queue 一樣,是工作線程”竊取“的對象,因此當其中的任務被一個工作線程成功竊取時,就意味着提交的任務真正開始進入執行階段。
3. 示例
package ThreadTest.demo;
import lombok.Data;
import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.*;
/**
* Created by lizq on 2020/3/1.
*/
public class ForkJoinTest {
public static void main(String[] args) throws ExecutionException, InterruptedException {
long[] arrs = RandomArr.createLongArr(30, 0, 100);
Arrays.stream(arrs).forEach(i -> {
System.out.print(i + " ");
});
long rlt = Arrays.stream(arrs).sum();
System.out.println("sum " + rlt);
// ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
ForkJoinPool forkJoinPool = new ForkJoinPool(5);
// ForkJoinTask<Long> forkJoinTask = forkJoinPool.submit(new SumTask(arrs, 0, arrs.length - 1));
// rlt = forkJoinTask.get();
// System.out.println("sum " + rlt);
ForkJoinTask forkJoinTask = forkJoinPool.submit(new orderTask(arrs, 0, arrs.length - 1));
forkJoinTask.get();
Arrays.stream(arrs).forEach(i -> {
System.out.print(i + " ");
});
}
}
@Data
class SumTask extends RecursiveTask<Long> {
private long[] arr;
private int from;
private int to;
public SumTask(long[] arr, int from, int to) {
this.arr = arr;
this.from = from;
this.to = to;
}
@Override
protected Long compute() {
System.out.println("thread begin : " + Thread.currentThread().getName() + " from= " + from + "; to=" + to);
Long l = 0l;
if (to - from < 3) {
for (int i = from; i <= to; i++) {
l += this.arr[i];
}
} else {
int m = (from + to) / 2;
SumTask leftSum = new SumTask(this.arr, from, m);
SumTask rightSum = new SumTask(this.arr, m + 1, to);
leftSum.fork();
rightSum.fork();
l = leftSum.join() + rightSum.join();
}
System.out.println("thread end : " + Thread.currentThread().getName() + " from= " + from + "; to=" + to);
return l;
}
}
enum Type {
DESC, ASC;
}
class orderTask extends RecursiveAction {
private long[] arr;
private int from;
private int to;
private Type type = Type.ASC;
public orderTask(long[] arr, int from, int to) {
this.arr = arr;
this.from = from;
this.to = to;
}
@Override
protected void compute() {
// System.out.print("thread begin : " + Thread.currentThread().getName() + " from= " + from + "; to=" + to);
if (to - from < 2) {
if (arr[from] > arr[to]) {
long tmp = this.arr[from];
this.arr[from] = this.arr[to];
this.arr[to] = tmp;
}
} else {
int m = (from + to) / 2;
orderTask leftSum = new orderTask(this.arr, from, m);
orderTask rightSum = new orderTask(this.arr, m + 1, to);
leftSum.fork();
rightSum.fork();
leftSum.join();
rightSum.join();
// 組合排序
for (int l = from, r = m + 1; l < to && r <= to; l++) {
if (arr[l] > arr[r]) {
long tmp = arr[r];
for (int i = r; i > l; ) {
arr[i] = arr[--i];
}
arr[l] = tmp;
r++;
}
}
}
// System.out.print("thread end : " + Thread.currentThread().getName() + " from= " + from + "; to=" + to);
}
}
class RandomArr {
public static long[] createLongArr(int length, int low, int high) {
if (length < 1) {
throw new RuntimeException("length < 1");
}
return new Random().longs(low, high).limit(length).toArray();
}
}