Java多線程 - Fork/Join框架原理解析

1、概述

Fork/Join Pool採用優良的設計、代碼實現和硬件原子操作機制等多種思路保證其執行性能。其中包括(但不限於):計算資源共享、高性能隊列、避免僞共享、工作竊取機制等。本文(以及後續文章)試圖和讀者一起分析JDK1.8中Fork/Join Pool的源代碼實現,去理解Fork/Join Pool是怎樣工作的。當然這裏要說明一下,起初本人在決定閱讀Fork/Join歸併計算相關類的源代碼時(ForkJoinPool、WorkQueue、ForkJoinTask、RecursiveTask、ForkJoinWorkerThread等),並不覺得這部分代碼比起LinkedList這樣的類來說有多少難度, 但其中大量使用位運算和位運算技巧,有大量Unsafe原子操作。博主能力有限,確實不能在短時間內將所有代碼一一詳細解讀,所以也希望各位讀者能幫助筆者一同完善。

2.  原理

基本思想

  • ForkJoinPool 的每個工作線程都維護着一個工作隊列WorkQueue),這是一個雙端隊列(Deque),裏面存放的對象是任務ForkJoinTask)。
  • 每個工作線程在運行中產生新的任務(通常是因爲調用了 fork())時,會放入工作隊列的隊尾,並且工作線程在處理自己的工作隊列時,使用的是 LIFO 方式,也就是說每次從隊尾取出任務來執行。
  • 每個工作線程在處理自己的工作隊列同時,會嘗試竊取一個任務(或是來自於剛剛提交到 pool 的任務,或是來自於其他工作線程的工作隊列),竊取的任務位於其他線程的工作隊列的隊首,也就是說工作線程在竊取其他工作線程的任務時,使用的是 FIFO 方式。
  • 在遇到 join() 時,如果需要 join 的任務尚未完成,則會先處理其他任務,並等待其完成。
  • 在既沒有自己的任務,也沒有可以竊取的任務時,進入休眠。

fork

fork() 做的工作只有一件事,既是把任務推入當前工作線程的工作隊列裏。可以參看以下的源代碼:

public final ForkJoinTask<V> fork() {
    Thread t;
    if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
        ((ForkJoinWorkerThread)t).workQueue.push(this);
    else
        ForkJoinPool.common.externalPush(this);
    return this;
}

join

join() 的工作則複雜得多,也是 join() 可以使得線程免於被阻塞的原因——不像同名的 Thread.join()

  1. 檢查調用 join() 的線程是否是 ForkJoinThread 線程。如果不是(例如 main 線程),則阻塞當前線程,等待任務完成。如果是,則不阻塞。
  2. 查看任務的完成狀態,如果已經完成,直接返回結果。
  3. 如果任務尚未完成,但處於自己的工作隊列內,則完成它。
  4. 如果任務已經被其他的工作線程偷走,則竊取這個小偷的工作隊列內的任務(以 FIFO 方式),執行,以期幫助它早日完成欲 join 的任務。
  5. 如果偷走任務的小偷也已經把自己的任務全部做完,正在等待需要 join 的任務時,則找到小偷的小偷,幫助它完成它的任務。
  6. 遞歸地執行第5步。

將上述流程畫成序列圖的話就是這個樣子:

以上就是 fork() 和 join() 的原理,這可以解釋 ForkJoinPool 在遞歸過程中的執行邏輯,但還有一個問題

最初的任務是 push 到哪個線程的工作隊列裏的?

這就涉及到 submit() 函數的實現方法了

submit

其實除了前面介紹過的每個工作線程自己擁有的工作隊列以外,ForkJoinPool 自身也擁有工作隊列,這些工作隊列的作用是用來接收由外部線程(非 ForkJoinThread 線程)提交過來的任務,而這些工作隊列被稱爲 submitting queue 。

submit() 和 fork() 其實沒有本質區別,只是提交對象變成了 submitting queue 而已(還有一些同步,初始化的操作)。submitting queue 和其他 work queue 一樣,是工作線程”竊取“的對象,因此當其中的任務被一個工作線程成功竊取時,就意味着提交的任務真正開始進入執行階段。

3. 示例

package ThreadTest.demo;

import lombok.Data;

import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.*;

/**
 * Created by lizq on 2020/3/1.
 */
public class ForkJoinTest {

    public static void main(String[] args) throws ExecutionException, InterruptedException {

        long[] arrs = RandomArr.createLongArr(30, 0, 100);
        Arrays.stream(arrs).forEach(i -> {
            System.out.print(i + " ");
        });
        long rlt = Arrays.stream(arrs).sum();
        System.out.println("sum " + rlt);

        // ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
        ForkJoinPool forkJoinPool = new ForkJoinPool(5);
        // ForkJoinTask<Long> forkJoinTask = forkJoinPool.submit(new SumTask(arrs, 0, arrs.length - 1));
        //  rlt = forkJoinTask.get();
        // System.out.println("sum " + rlt);
        ForkJoinTask forkJoinTask = forkJoinPool.submit(new orderTask(arrs, 0, arrs.length - 1));
        forkJoinTask.get();
        Arrays.stream(arrs).forEach(i -> {
            System.out.print(i + " ");
        });

    }
}

@Data
class SumTask extends RecursiveTask<Long> {

    private long[] arr;
    private int from;
    private int to;

    public SumTask(long[] arr, int from, int to) {
        this.arr = arr;
        this.from = from;
        this.to = to;

    }

    @Override
    protected Long compute() {

        System.out.println("thread begin : " + Thread.currentThread().getName() + " from= " + from + "; to=" + to);
        Long l = 0l;

        if (to - from < 3) {
            for (int i = from; i <= to; i++) {
                l += this.arr[i];
            }
        } else {
            int m = (from + to) / 2;
            SumTask leftSum = new SumTask(this.arr, from, m);
            SumTask rightSum = new SumTask(this.arr, m + 1, to);
            leftSum.fork();
            rightSum.fork();
            l = leftSum.join() + rightSum.join();
        }
        System.out.println("thread end : " + Thread.currentThread().getName() + " from= " + from + "; to=" + to);
        return l;
    }


}

enum Type {
    DESC, ASC;
}


class orderTask extends RecursiveAction {

    private long[] arr;
    private int from;
    private int to;
    private Type type = Type.ASC;


    public orderTask(long[] arr, int from, int to) {
        this.arr = arr;
        this.from = from;
        this.to = to;

    }

    @Override
    protected void compute() {

        //   System.out.print("thread begin : " + Thread.currentThread().getName() + " from= " + from + "; to=" + to);
        if (to - from < 2) {
            if (arr[from] > arr[to]) {
                long tmp = this.arr[from];
                this.arr[from] = this.arr[to];
                this.arr[to] = tmp;
            }
        } else {
            int m = (from + to) / 2;
            orderTask leftSum = new orderTask(this.arr, from, m);
            orderTask rightSum = new orderTask(this.arr, m + 1, to);
            leftSum.fork();
            rightSum.fork();
            leftSum.join();
            rightSum.join();
            // 組合排序
            for (int l = from, r = m + 1; l < to && r <= to; l++) {
                if (arr[l] > arr[r]) {
                    long tmp = arr[r];
                    for (int i = r; i > l; ) {
                        arr[i] = arr[--i];
                    }
                    arr[l] = tmp;
                    r++;
                }
            }
        }
        //  System.out.print("thread end : " + Thread.currentThread().getName() + " from= " + from + "; to=" + to);

    }
}

class RandomArr {
    public static long[] createLongArr(int length, int low, int high) {
        if (length < 1) {
            throw new RuntimeException("length < 1");
        }
        return new Random().longs(low, high).limit(length).toArray();
    }
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章