java併發-併發工具類

CountDownLatch

java.util.concurrent.CountDownLatch 發令槍,允許一個或多個線程等待其他線程完成操作
主線程需要等待所有的子線程執行完後進行彙總,join方法可以實現這一點,但是不夠靈活。

使用

public class CountDownLatchTest {
	// 計數器
    private static CountDownLatch countDownLatch = new CountDownLatch(2);
    public static void main(String[] args) throws InterruptedException {
        ExecutorService executorService = Executors.newFixedThreadPool(2);
        executorService.submit(()->{
            try {
                TimeUnit.SECONDS.sleep(3);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }finally { 
            	// 計數器減一
                countDownLatch.countDown();
            }
            System.out.println("Thread A Over");
        });
        executorService.submit(()->{
            try {
                TimeUnit.SECONDS.sleep(2);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }finally {
                countDownLatch.countDown();
            }
            System.out.println("Thread B Over");
        });
        // 等待子線池執行完
        countDownLatch.await();
        System.out.println("All Child thread Over!");
        executorService.shutdown();
    }
}

源碼探究

public class CountDownLatch {
	 private static final class Sync extends AbstractQueuedSynchronizer {
	    private static final long serialVersionUID = 4982264981922014374L;
        Sync(int count) {
            setState(count);
        }
        int getCount() {
            return getState();
        }

        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            // 循環進行CAS,直到當前線程成功完成CAS使計數器值(state)減1,並更新到state
            for (;;) {
                int c = getState();
                // 如果當前狀態值爲0則直接返回
                if (c == 0)
                    return false;
                // 使用CAS讓計數器值減一    
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }
    // 內部使用AQS實現 
	private final Sync sync;
	// 構造函數 傳入計數器
   	public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        // 把計數器器的值賦予AQS的狀態變量state 
        this.sync = new Sync(count);
    }
    // 線程調用await方法,當前線程會被阻塞,直到下面情況發生才返回
    // 1. 所有線程調用countDown方法,計數器的值爲0
    // 2. 其他線程調用用當前線程的interrupte()方法中斷了當前線程,當前線程拋出InterruptedException返回
     public void await() throws InterruptedException {
     	// 委託sysnc調用AOS的acquireSharedInterruptibly方法
        sync.acquireSharedInterruptibly(1);
    }
    // 帶有超時時間阻塞等待方法
    public boolean await(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }
    // 計算器的值遞減,遞減爲0則喚醒所有調用await方法而被阻塞的線程
    public void countDown() {
        sync.releaseShared(1);
    }
    // 返回當前計算器值
    public long getCount() {
        return sync.getCount();
    }
}

總結: 相比使用join方法實現線程間同步,CountDownLatch更具有靈活性和方便性,CountDownLatch使用AQS實現的,使用AQS的狀態變量來存放計數器的值,構造函數初始化狀態值(計數器值),當線程調用await()方法後當前線程會被放入AQS阻塞隊列等待計數器爲0在返回,多個線程調用countDown()時原子(cas)遞減AQS的狀態值,計數器值減1,當計數器值變爲0是,當前線程調用AQS的doReleaseShared方法激活由調用await()而被阻塞的線程

CyclicBarrier

java.util.concurrent.CyclicBarrier 可循環(Cyclic)的屏障(Barrier)。迴環屏障,讓一組線程到達屏障(屏障點/同步點)時,屏障纔會打開,所有被屏障攔截的線程纔會繼續工作

使用

import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public class CyclicBarrierTest {
    private static CyclicBarrier cyclicBarrier = new CyclicBarrier(2,()->{
        System.out.println(Thread.currentThread()+"=====達到屏障點執行=");
        // todo something
    });

    public static void main(String[] args) {
        ExecutorService executorService = Executors.newFixedThreadPool(2);
        // 線程A
        executorService.submit(()->{
            try {
                System.out.println(Thread.currentThread()+"step1");
                cyclicBarrier.await();
                System.out.println(Thread.currentThread()+"step2");
                cyclicBarrier.await();
                System.out.println(Thread.currentThread()+"step3");
            } catch (Exception e) {
                e.printStackTrace();
            }
        });
        // 線程B
        executorService.submit(()->{
            try {
                System.out.println(Thread.currentThread()+"step1");
                cyclicBarrier.await();
                System.out.println(Thread.currentThread()+"step2");
                cyclicBarrier.await();
                System.out.println(Thread.currentThread()+"step3");
            } catch (Exception e) {
                e.printStackTrace();
            }
        });
        executorService.shutdown();
    }
}

注意:

  1. 對於指定計數器值parties,若由於某種原因,沒有足夠的線程調用CyclicBarrier的await,則所有調用await的線程都會被阻塞
  2. CyclicBarrier也可以調用await(timeout,unit)設置超時時間,在設定時間內,如果沒有足夠線程到達,則解除阻塞狀態,繼續工作
  3. 通過reset重置計數,會使得進入await的線程出現BrokenBarrierException
    4)如果採用是CyclicBarrier(int parties,Runnable barrierAction)構造方法,執行barrierAction操作的是最後一個到達的線程

源碼探究

CyclicBarrier 基於ReetrantLockCondition實現,CyclicBarrier 可以有不止一個柵欄,因爲它的柵欄(Barrier)可以重複使用(Cyclic)

public class CyclicBarrier {
 
    private static class Generation { 
    	//記錄當前屏障是否被打破 由於在鎖內使用,所以不需要申明volatile
        boolean broken = false;
    }

    /** The lock for guarding barrier entry */
    private final ReentrantLock lock = new ReentrantLock();
    /** Condition to wait on until tripped */
    // lock的條件變量支持線程間使用await和sigal操作進行同步
    private final Condition trip = lock.newCondition();
    /** The number of parties  
     *  記錄線程的個數,表示多少個線程調用await後,所有線程纔會衝破屏障繼續往下運行
     * */
    private final int parties;
    /* The command to run when tripped */
    // 當所有的線程到達了屏障點,最後一個線程執行
    private final Runnable barrierCommand;
    /** The current generation */
    private Generation generation = new Generation();

    /** count一開始等於parties,每當線程調用await方法就遞減1,當count爲0就表示所有線程到了屏障點*/
    private int count;

    /**
     * Updates state on barrier trip and wakes up everyone.
     * Called only while holding lock.
     */
    private void nextGeneration() {
        // signal completion of last generation 喚醒條件隊列裏面阻塞線程
        trip.signalAll();
        // set up next generation 重置CyclicBarrier
        count = parties;
        generation = new Generation();
    }

  
    private void breakBarrier() {
        generation.broken = true;
        count = parties;
        trip.signalAll();
    }
	// dowait實現了CyclicBarrier的核心功能 
	// timed 是否設置超時時間  
    private int dowait(boolean timed, long nanos)
        throws InterruptedException, BrokenBarrierException,
               TimeoutException {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            final Generation g = generation;
            if (g.broken)
                throw new BrokenBarrierException();
            if (Thread.interrupted()) {
                breakBarrier();
                throw new InterruptedException();
            }
            int index = --count;
      		// (1) index==0則說明所有線程都到了屏障點,此時執行初始化時傳遞的任務時傳遞任務
            if (index == 0) {  // tripped
                boolean ranAction = false;
                try {
                    final Runnable command = barrierCommand;
                    // (2) 執行任務
                    if (command != null)
                        command.run();
                    ranAction = true;
                    // (3) 激活其他因調用await方法而被阻塞的線程,並重置CyclicBarrier
                    nextGeneration();
                    return 0;
                } finally {
                    if (!ranAction)
                        breakBarrier();
                }
            }
            // loop until tripped, broken, interrupted, or timed out
            // (4) 如果index != 0
            for (;;) {
                try {
                	// (5) 沒有設置超時時間
                    if (!timed)
                        trip.await();
                    // (6) 設置了超時時間    
                    else if (nanos > 0L)
                        nanos = trip.awaitNanos(nanos);
                } catch (InterruptedException ie) {
                    if (g == generation && ! g.broken) {
                        breakBarrier();
                        throw ie;
                    } else {
                        // We're about to finish waiting even if we had not
                        // been interrupted, so this interrupt is deemed to
                        // "belong" to subsequent execution.
                        Thread.currentThread().interrupt();
                    }
                }

                if (g.broken)
                    throw new BrokenBarrierException();

                if (g != generation)
                    return index;

                if (timed && nanos <= 0L) {
                    breakBarrier();
                    throw new TimeoutException();
                }
            }
        } finally {
            lock.unlock();
        }
    }

    public CyclicBarrier(int parties, Runnable barrierAction) {
        if (parties <= 0) throw new IllegalArgumentException();
        this.parties = parties;
        this.count = parties;
        this.barrierCommand = barrierAction;
    }

    public CyclicBarrier(int parties) {
        this(parties, null);
    }

    public int getParties() {
        return parties;
    }
    // 阻塞方法
    public int await() throws InterruptedException, BrokenBarrierException {
        try {
            return dowait(false, 0L);
        } catch (TimeoutException toe) {
            throw new Error(toe); // cannot happen
        }
    }
	// 待超時時間的阻塞方法 
	// parties個線程都調用了await()方法,也就是線程都到了屏障點,這時候返回true;
	// 設置的超時時間到了後返回false;
	// 其他線程調用當前線程的interrupt()方法中斷了當前線程,則當前線程會拋出InterruptedException異常然後返回;
	// 與當前屏障點關聯的Generation對象的broken標誌被設置爲true時,會拋出BrokenBarrierException異常,然後返回
    public int await(long timeout, TimeUnit unit)
        throws InterruptedException,
               BrokenBarrierException,
               TimeoutException {
        return dowait(true, unit.toNanos(timeout));
    }

    public boolean isBroken() {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            return generation.broken;
        } finally {
            lock.unlock();
        }
    }
    // 重置Barrier
    public void reset() {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            breakBarrier();   // break the current generation
            nextGeneration(); // start a new generation
        } finally {
            lock.unlock();
        }
    }
    public int getNumberWaiting() {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            return parties - count;
        } finally {
            lock.unlock();
        }
    }
}

Semaphore

Semaphore(信號量)是用來控制同時訪問特定資源的線程數量,它通過協調各個線程,以保證合理的使用公共資源

使用

場景: Semaphore可以用於做流量控制,特別是公用資源有限的應用場景,比如數據庫連接。

import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;

public class SemaphoreTest {
    static class Car extends Thread {
        private int num;
        private Semaphore semaphore;
        public Car(int num, Semaphore semaphore) {
            this.num = num;
            this.semaphore = semaphore;
        }
        @Override
        public void run() {
            try {
                // 獲得一個令牌,如果拿不到令牌,就會阻塞
                semaphore.acquire();
                System.out.println("第"+num+" 搶佔一個車位");
                TimeUnit.SECONDS.sleep(2);
                System.out.println("第"+num+" 開走");
                semaphore.release();
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
    }
    public static void main(String[] args) {
        Semaphore semaphore = new Semaphore(5);
//        ExecutorService executorService = Executors.newFixedThreadPool(5);
        for (int i = 0; i < 10; i++) {
            new Car(i,semaphore).start();
//            executorService.submit(new Car(i,semaphore));
        }
//        executorService.shutdown();
    }
}

源碼探究

public class Semaphore implements java.io.Serializable {
    private static final long serialVersionUID = -3222578661600680210L;
    private final Sync sync;
    abstract static class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 1192457210091910933L;
        Sync(int permits) {
            setState(permits);
        }
        final int getPermits() {
            return getState();
        }
        final int nonfairTryAcquireShared(int acquires) {
            for (;;) {
               // 獲取當前信號量值    
                int available = getState();
                // 計算當前剩餘值
                int remaining = available - acquires;
                // 如果剩餘值小於0說明當前信號量個數不滿足需求
                // 大於0且CAS操作成功,返回剩餘值
                if (remaining < 0 ||
                    compareAndSetState(available, remaining))
                    return remaining;
            }
        }
        protected final boolean tryReleaseShared(int releases) {
            for (;;) {
                int current = getState();
                int next = current + releases;
                if (next < current) // overflow
                    throw new Error("Maximum permit count exceeded");
                if (compareAndSetState(current, next))
                    return true;
            }
        }
        final void reducePermits(int reductions) {
            for (;;) {
                int current = getState();
                // 不允許縮減值爲負數
                int next = current - reductions;
                if (next > current) // underflow
                    throw new Error("Permit count underflow");
                if (compareAndSetState(current, next)) // CAS 設置縮減後的許可證數量
                    return;
            }
        }
        final int drainPermits() {
            for (;;) {
                int current = getState();
                // 直接把剩餘的許可證數量設置爲0
                if (current == 0 || compareAndSetState(current, 0))
                    return current;
            }
        }
    }
    static final class NonfairSync extends Sync {
        private static final long serialVersionUID = -2694183684443567898L;
        NonfairSync(int permits) {
            super(permits);
        }
        protected int tryAcquireShared(int acquires) {
            return nonfairTryAcquireShared(acquires);
        }
    }
	// 公平策略
    static final class FairSync extends Sync {
        private static final long serialVersionUID = 2014338818796000944L;
        FairSync(int permits) {
            super(permits);
        }
        protected int tryAcquireShared(int acquires) {
            for (;;) {
                // 如果當前線程不位於對頭,則阻塞	
            	// hasQueuedPredecessors 來保證公平性
                if (hasQueuedPredecessors())
                    return -1;
                int available = getState();
                int remaining = available - acquires;
                if (remaining < 0 ||
                    compareAndSetState(available, remaining))
                    return remaining;
            }
        }
    }
	
    public Semaphore(int permits) {
        sync = new NonfairSync(permits);
    }
    
    public Semaphore(int permits, boolean fair) {
        sync = fair ? new FairSync(permits) : new NonfairSync(permits);
    }
	// 獲取一個信號量值,未獲取到會阻塞
    public void acquire() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }
    public void acquireUninterruptibly() {
        sync.acquireShared(1);
    }
    public boolean tryAcquire() {
        return sync.nonfairTryAcquireShared(1) >= 0;
    }

    public boolean tryAcquire(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }

    public void release() {
        sync.releaseShared(1);
    }
	
    public void acquire(int permits) throws InterruptedException {
        if (permits < 0) throw new IllegalArgumentException();
        sync.acquireSharedInterruptibly(permits);
    }

    public void acquireUninterruptibly(int permits) {
        if (permits < 0) throw new IllegalArgumentException();
        sync.acquireShared(permits);
    }

    public boolean tryAcquire(int permits) {
        if (permits < 0) throw new IllegalArgumentException();
        return sync.nonfairTryAcquireShared(permits) >= 0;
    }
    public boolean tryAcquire(int permits, long timeout, TimeUnit unit)
        throws InterruptedException {
        if (permits < 0) throw new IllegalArgumentException();
        return sync.tryAcquireSharedNanos(permits, unit.toNanos(timeout));
    }
    // 釋放信號量
    public void release(int permits) {
        if (permits < 0) throw new IllegalArgumentException();
        sync.releaseShared(permits);
    }
	// 返回信號量中對當前可用的許可證數
    public int availablePermits() {
        return sync.getPermits();
    }
	// 獲取立即可用的所有許可證個數,並將可用許可證置0	
    public int drainPermits() {
        return sync.drainPermits();
    }

    protected void reducePermits(int reduction) {
        if (reduction < 0) throw new IllegalArgumentException();
        sync.reducePermits(reduction);
    }
	// 是否是公平策略
    public boolean isFair() {
        return sync instanceof FairSync;
    }
	// 是否有線程正在等待獲取許可證
    public final boolean hasQueuedThreads() {
        return sync.hasQueuedThreads();
    }
    // 返回正則等待許可證的線程數
    public final int getQueueLength() {
        return sync.getQueueLength();
    }
    // 返回所有等待許可證的線程集合
    protected Collection<Thread> getQueuedThreads() {
        return sync.getQueuedThreads();
    }
    public String toString() {
        return super.toString() + "[Permits = " + sync.getPermits() + "]";
    }
}

Exchanger

Exchanger(交換者)是一個用於線程間協作的工具類。Exchanger用於進行線程間的數據交換。它提供一個同步點,在這個同步點,兩個線程可以交換彼此的數據。這兩個線程通過exchange方法交換數據,如果第一個線程先執行exchange()方法,它會一直等待第二個線程也執行exchange方法,當兩個線程都到達同步點時,這兩個線程就可以交換數據,將本線程生產出來的數據傳遞給對方

import java.util.concurrent.Exchanger;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public class ExchangerTest {
    private static final Exchanger<String> exgr = new Exchanger<>();
    public static void main(String[] args) {
        ExecutorService service = Executors.newFixedThreadPool(2);
        service.submit(()->{
            try {
                String A = "銀行流水A";
                System.out.println(Thread.currentThread()+"交換前:"+A);
                // 同步點:等待另一個線程到達此交換點(除非當前線程被中斷),然後將給定的對象傳送給該線程,並接收該線程的對象
                String data = exgr.exchange(A);
                System.out.println(Thread.currentThread()+"交換後:"+data);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        });
        service.submit(()->{
            try {
                String B = "銀行流水B";
                System.out.println(Thread.currentThread()+"交換前:"+B);
                String data =  exgr.exchange(B);
                System.out.println(Thread.currentThread()+"交換前:"+data);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        });
        service.shutdown();
    }
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章