深讀源碼-java同步系列之Phaser源碼解析

問題

(1)Phaser是什麼?

(2)Phaser具有哪些特性?

(3)Phaser相對於CyclicBarrier和CountDownLatch的優勢?

簡介

Phaser,翻譯爲階段,它適用於這樣一種場景,一個大任務可以分爲多個階段完成,且每個階段的任務可以多個線程併發執行,但是必須上一個階段的任務都完成了纔可以執行下一個階段的任務。

這種場景雖然使用CyclicBarrier或者CountryDownLatch也可以實現,但是要複雜的多。首先,具體需要多少個階段是可能會變的,其次,每個階段的任務數也可能會變的。相比於CyclicBarrier和CountDownLatch,Phaser更加靈活更加方便。

使用方法

下面我們看一個最簡單的使用案例:

public class PhaserTest {

    public static final int PARTIES = 3;
    public static final int PHASES = 4;

    public static void main(String[] args) {

        Phaser phaser = new Phaser(PARTIES) {
            @Override
            protected boolean onAdvance(int phase, int registeredParties) {
                System.out.println("=======phase: " + phase + " finished=============");
                return super.onAdvance(phase, registeredParties);
            }
        };

        for (int i = 0; i < PARTIES; i++) {
            new Thread(()->{
                for (int j = 0; j < PHASES; j++) {
                    System.out.println(String.format("%s: phase: %d", Thread.currentThread().getName(), j));
                    phaser.arriveAndAwaitAdvance();
                }
            }, "Thread " + i).start();
        }
    }
}

這裏我們定義一個需要4個階段完成的大任務,每個階段需要3個小任務,針對這些小任務,我們分別起3個線程來執行這些小任務,查看輸出結果爲:

Thread 0: phase: 0
Thread 2: phase: 0
Thread 1: phase: 0
=======phase: 0 finished=============
Thread 2: phase: 1
Thread 0: phase: 1
Thread 1: phase: 1
=======phase: 1 finished=============
Thread 1: phase: 2
Thread 0: phase: 2
Thread 2: phase: 2
=======phase: 2 finished=============
Thread 0: phase: 3
Thread 2: phase: 3
Thread 1: phase: 3
=======phase: 3 finished=============

可以看到,每個階段都是三個線程都完成了才進入下一個階段。這是怎麼實現的呢,讓我們一起來學習吧。

原理猜測

根據我們前面學習AQS的原理,大概猜測一下Phaser的實現原理。

首先,需要存儲當前階段phase、當前階段的任務數(參與者)parties、未完成參與者的數量,這三個變量我們可以放在一個變量state中存儲。

其次,需要一個隊列存儲先完成的參與者,當最後一個參與者完成任務時,需要喚醒隊列中的參與者。

嗯,差不多就是這樣子。

結合上面的案例帶入:

初始時當前階段爲0,參與者數爲3個,未完成參與者數爲3;

第一個線程執行到phaser.arriveAndAwaitAdvance();時進入隊列;

第二個線程執行到phaser.arriveAndAwaitAdvance();時進入隊列;

第三個線程執行到phaser.arriveAndAwaitAdvance();時先執行這個階段的總結onAdvance(),再喚醒前面兩個線程繼續執行下一個階段的任務。

嗯,整體能說得通,至於是不是這樣呢,讓我們一起來看源碼吧。

源碼分析

主要內部類

static final class QNode implements ForkJoinPool.ManagedBlocker {
    final Phaser phaser;
    final int phase;
    final boolean interruptible;
    final boolean timed;
    boolean wasInterrupted;
    long nanos;
    final long deadline;
    volatile Thread thread; // nulled to cancel wait
    QNode next;

    QNode(Phaser phaser, int phase, boolean interruptible,
          boolean timed, long nanos) {
        this.phaser = phaser;
        this.phase = phase;
        this.interruptible = interruptible;
        this.nanos = nanos;
        this.timed = timed;
        this.deadline = timed ? System.nanoTime() + nanos : 0L;
        thread = Thread.currentThread();
    }
}

先完成的參與者放入隊列中的節點,這裏我們只需要關注threadnext兩個屬性即可,很明顯這是一個單鏈表,存儲着入隊的線程。

主要屬性

// 狀態變量,用於存儲當前階段phase、參與者數parties、未完成的參與者數unarrived_count
private volatile long state;
// 最多可以有多少個參與者,即每個階段最多有多少個任務
private static final int  MAX_PARTIES     = 0xffff;
// 最多可以有多少階段
private static final int  MAX_PHASE       = Integer.MAX_VALUE;
// 參與者數量的偏移量
private static final int  PARTIES_SHIFT   = 16;
// 當前階段的偏移量
private static final int  PHASE_SHIFT     = 32;
// 未完成的參與者數的掩碼,低16位
private static final int  UNARRIVED_MASK  = 0xffff;      // to mask ints
// 參與者數,中間16位
private static final long PARTIES_MASK    = 0xffff0000L; // to mask longs
// counts的掩碼,counts等於參與者數和未完成的參與者數的'|'操作
private static final long COUNTS_MASK     = 0xffffffffL;
private static final long TERMINATION_BIT = 1L << 63;

// 一次一個參與者完成
private static final int  ONE_ARRIVAL     = 1;
// 增加減少參與者時使用
private static final int  ONE_PARTY       = 1 << PARTIES_SHIFT;
// 減少參與者時使用
private static final int  ONE_DEREGISTER  = ONE_ARRIVAL|ONE_PARTY;
// 沒有參與者時使用
private static final int  EMPTY           = 1;

// 用於求未完成參與者數量
private static int unarrivedOf(long s) {
    int counts = (int)s;
    return (counts == EMPTY) ? 0 : (counts & UNARRIVED_MASK);
}
// 用於求參與者數量(中間16位),注意int的位置
private static int partiesOf(long s) {
    return (int)s >>> PARTIES_SHIFT;
}
// 用於求階段數(高32位),注意int的位置
private static int phaseOf(long s) {
    return (int)(s >>> PHASE_SHIFT);
}
// 已完成參與者的數量
private static int arrivedOf(long s) {
    int counts = (int)s; // 低32位
    return (counts == EMPTY) ? 0 :
        (counts >>> PARTIES_SHIFT) - (counts & UNARRIVED_MASK);
}
// 用於存儲已完成參與者所在的線程,根據當前階段的奇偶性選擇不同的隊列
private final AtomicReference<QNode> evenQ;
private final AtomicReference<QNode> oddQ;

主要屬性爲stateevenQoddQ

(1)state,狀態變量,高32位存儲當前階段phase,中間16位存儲參與者的數量,低16位存儲未完成參與者的數量;

Phaser

(2)evenQ和oddQ,已完成的參與者存儲的隊列,當最後一個參與者完成任務後喚醒隊列中的參與者繼續執行下一個階段的任務,或者結束任務。

構造方法

public Phaser() {
    this(null, 0);
}

public Phaser(int parties) {
    this(null, parties);
}

public Phaser(Phaser parent) {
    this(parent, 0);
}

public Phaser(Phaser parent, int parties) {
    if (parties >>> PARTIES_SHIFT != 0)
        throw new IllegalArgumentException("Illegal number of parties");
    int phase = 0;
    this.parent = parent;
    if (parent != null) {
        final Phaser root = parent.root;
        this.root = root;
        this.evenQ = root.evenQ;
        this.oddQ = root.oddQ;
        if (parties != 0)
            phase = parent.doRegister(1);
    }
    else {
        this.root = this;
        this.evenQ = new AtomicReference<QNode>();
        this.oddQ = new AtomicReference<QNode>();
    }
    // 狀態變量state的存儲分爲三段
    this.state = (parties == 0) ? (long)EMPTY :
        ((long)phase << PHASE_SHIFT) |
        ((long)parties << PARTIES_SHIFT) |
        ((long)parties);
}

構造函數中還有一個parent和root,這是用來構造多層級階段的,不在本文的討論範圍之內,忽略之。

重點還是看state的賦值方式,高32位存儲當前階段phase,中間16位存儲參與者的數量,低16位存儲未完成參與者的數量。

下面我們一起來看看幾個主要方法的源碼:

register()方法

註冊一個參與者,如果調用該方法時,onAdvance()方法正在執行,則該方法等待其執行完畢。

public int register() {
    return doRegister(1);
}
private int doRegister(int registrations) {
    // state應該加的值,注意這裏是相當於同時增加parties和unarrived
    long adjust = ((long)registrations << PARTIES_SHIFT) | registrations;
    final Phaser parent = this.parent;
    int phase;
    for (;;) {
        // state的值
        long s = (parent == null) ? state : reconcileState();
        // state的低32位,也就是parties和unarrived的值
        int counts = (int)s;
        // parties的值
        int parties = counts >>> PARTIES_SHIFT;
        // unarrived的值
        int unarrived = counts & UNARRIVED_MASK;
        // 檢查是否溢出
        if (registrations > MAX_PARTIES - parties)
            throw new IllegalStateException(badRegister(s));
        // 當前階段phase
        phase = (int)(s >>> PHASE_SHIFT);
        if (phase < 0)
            break;
        // 不是第一個參與者
        if (counts != EMPTY) {                  // not 1st registration
            if (parent == null || reconcileState() == s) {
                // unarrived等於0說明當前階段正在執行onAdvance()方法,等待其執行完畢
                if (unarrived == 0)             // wait out advance
                    root.internalAwaitAdvance(phase, null);
                // 否則就修改state的值,增加adjust,如果成功就跳出循環
                else if (UNSAFE.compareAndSwapLong(this, stateOffset,
                                                   s, s + adjust))
                    break;
            }
        }
        // 是第一個參與者
        else if (parent == null) {              // 1st root registration
            // 計算state的值
            long next = ((long)phase << PHASE_SHIFT) | adjust;
            // 修改state的值,如果成功就跳出循環
            if (UNSAFE.compareAndSwapLong(this, stateOffset, s, next))
                break;
        }
        else {
            // 多層級階段的處理方式
            synchronized (this) {               // 1st sub registration
                if (state == s) {               // recheck under lock
                    phase = parent.doRegister(1);
                    if (phase < 0)
                        break;
                    // finish registration whenever parent registration
                    // succeeded, even when racing with termination,
                    // since these are part of the same "transaction".
                    while (!UNSAFE.compareAndSwapLong
                           (this, stateOffset, s,
                            ((long)phase << PHASE_SHIFT) | adjust)) {
                        s = state;
                        phase = (int)(root.state >>> PHASE_SHIFT);
                        // assert (int)s == EMPTY;
                    }
                    break;
                }
            }
        }
    }
    return phase;
}
// 等待onAdvance()方法執行完畢
// 原理是先自旋一定次數,如果進入下一個階段,這個方法直接就返回了,
// 如果自旋一定次數後還沒有進入下一個階段,則當前線程入隊列,等待onAdvance()執行完畢喚醒
private int internalAwaitAdvance(int phase, QNode node) {
    // 保證隊列爲空
    releaseWaiters(phase-1);          // ensure old queue clean
    boolean queued = false;           // true when node is enqueued
    int lastUnarrived = 0;            // to increase spins upon change
    // 自旋的次數
    int spins = SPINS_PER_ARRIVAL;
    long s;
    int p;
    // 檢查當前階段是否變化,如果變化了說明進入下一個階段了,這時候就沒有必要自旋了
    while ((p = (int)((s = state) >>> PHASE_SHIFT)) == phase) {
        // 如果node爲空,註冊的時候傳入的爲空
        if (node == null) {           // spinning in noninterruptible mode
            // 未完成的參與者數量
            int unarrived = (int)s & UNARRIVED_MASK;
            // unarrived有變化,增加自旋次數
            if (unarrived != lastUnarrived &&
                (lastUnarrived = unarrived) < NCPU)
                spins += SPINS_PER_ARRIVAL;
            boolean interrupted = Thread.interrupted();
            // 自旋次數完了,則新建一個節點
            if (interrupted || --spins < 0) { // need node to record intr
                node = new QNode(this, phase, false, false, 0L);
                node.wasInterrupted = interrupted;
            }
        }
        else if (node.isReleasable()) // done or aborted
            break;
        else if (!queued) {           // push onto queue
            // 節點入隊列
            AtomicReference<QNode> head = (phase & 1) == 0 ? evenQ : oddQ;
            QNode q = node.next = head.get();
            if ((q == null || q.phase == phase) &&
                (int)(state >>> PHASE_SHIFT) == phase) // avoid stale enq
                queued = head.compareAndSet(q, node);
        }
        else {
            try {
                // 當前線程進入阻塞狀態,跟調用LockSupport.park()一樣,等待被喚醒
                ForkJoinPool.managedBlock(node);
            } catch (InterruptedException ie) {
                node.wasInterrupted = true;
            }
        }
    }
    
    // 到這裏說明節點所在線程已經被喚醒了
    if (node != null) {
        // 置空節點中的線程
        if (node.thread != null)
            node.thread = null;       // avoid need for unpark()
        if (node.wasInterrupted && !node.interruptible)
            Thread.currentThread().interrupt();
        if (p == phase && (p = (int)(state >>> PHASE_SHIFT)) == phase)
            return abortWait(phase); // possibly clean up on abort
    }
    // 喚醒當前階段阻塞着的線程
    releaseWaiters(phase);
    return p;
}

增加一個參與者總體的邏輯爲:

(1)增加一個參與者,需要同時增加parties和unarrived兩個數值,也就是state的中16位和低16位;

(2)如果是第一個參與者,則嘗試原子更新state的值,如果成功了就退出;

(3)如果不是第一個參與者,則檢查是不是在執行onAdvance(),如果是等待onAdvance()執行完成,如果否則嘗試原子更新state的值,直到成功退出;

(4)等待onAdvance()完成是採用先自旋後進入隊列排隊的方式等待,減少線程上下文切換;

arriveAndAwaitAdvance()方法

當前線程當前階段執行完畢,等待其它線程完成當前階段。

如果當前線程是該階段最後一個到達的,則當前線程會執行onAdvance()方法,並喚醒其它線程進入下一個階段。

public int arriveAndAwaitAdvance() {
    // Specialization of doArrive+awaitAdvance eliminating some reads/paths
    final Phaser root = this.root;
    for (;;) {
        // state的值
        long s = (root == this) ? state : reconcileState();
        // 當前階段
        int phase = (int)(s >>> PHASE_SHIFT);
        if (phase < 0)
            return phase;
        // parties和unarrived的值
        int counts = (int)s;
        // unarrived的值(state的低16位)
        int unarrived = (counts == EMPTY) ? 0 : (counts & UNARRIVED_MASK);
        if (unarrived <= 0)
            throw new IllegalStateException(badArrive(s));
        // 修改state的值
        if (UNSAFE.compareAndSwapLong(this, stateOffset, s,
                                      s -= ONE_ARRIVAL)) {
            // 如果不是最後一個到達的,則調用internalAwaitAdvance()方法自旋或進入隊列等待
            if (unarrived > 1)
                // 這裏是直接返回了,internalAwaitAdvance()方法的源碼見register()方法解析
                return root.internalAwaitAdvance(phase, null);
            
            // 到這裏說明是最後一個到達的參與者
            if (root != this)
                return parent.arriveAndAwaitAdvance();
            // n只保留了state中parties的部分,也就是中16位
            long n = s & PARTIES_MASK;  // base of next state
            // parties的值,即下一次需要到達的參與者數量
            int nextUnarrived = (int)n >>> PARTIES_SHIFT;
            // 執行onAdvance()方法,返回true表示下一階段參與者數量爲0了,也就是結束了
            if (onAdvance(phase, nextUnarrived))
                n |= TERMINATION_BIT;
            else if (nextUnarrived == 0)
                n |= EMPTY;
            else
                // n 加上unarrived的值
                n |= nextUnarrived;
            // 下一個階段等待當前階段加1
            int nextPhase = (phase + 1) & MAX_PHASE;
            // n 加上下一階段的值
            n |= (long)nextPhase << PHASE_SHIFT;
            // 修改state的值爲n
            if (!UNSAFE.compareAndSwapLong(this, stateOffset, s, n))
                return (int)(state >>> PHASE_SHIFT); // terminated
            // 喚醒其它參與者並進入下一個階段
            releaseWaiters(phase);
            // 返回下一階段的值
            return nextPhase;
        }
    }
}

arriveAndAwaitAdvance的大致邏輯爲:

(1)修改state中unarrived部分的值減1;

(2)如果不是最後一個到達的,則調用internalAwaitAdvance()方法自旋或排隊等待;

(3)如果是最後一個到達的,則調用onAdvance()方法,然後修改state的值爲下一階段對應的值,並喚醒其它等待的線程;

(4)返回下一階段的值;

總結

(1)Phaser適用於多階段多任務的場景,每個階段的任務都可以控制得很細;

(2)Phaser內部使用state變量及隊列實現整個邏輯;

(3)state的高32位存儲當前階段phase,中16位存儲當前階段參與者(任務)的數量parties,低16位存儲未完成參與者的數量;

(4)隊列會根據當前階段的奇偶性選擇不同的隊列;

(5)當不是最後一個參與者到達時,會自旋或者進入隊列排隊來等待所有參與者完成任務;

(6)當最後一個參與者完成任務時,會喚醒隊列中的線程並進入下一個階段;

彩蛋

Phaser相對於CyclicBarrier和CountDownLatch的優勢?

答:優勢主要有兩點:

(1)Phaser可以完成多階段,而一個CyclicBarrier或者CountDownLatch一般只能控制一到兩個階段的任務;

(2)Phaser每個階段的任務數量可以控制,而一個CyclicBarrier或者CountDownLatch任務數量一旦確定不可修改。


原文鏈接:https://www.cnblogs.com/tong-yuan/p/11614755.html 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章