CountDownLatch
java.util.concurrent.CountDownLatch
發令槍,允許一個或多個線程等待其他線程完成操作
主線程需要等待所有的子線程執行完後進行彙總,join
方法可以實現這一點,但是不夠靈活。
使用
public class CountDownLatchTest {
// 計數器
private static CountDownLatch countDownLatch = new CountDownLatch(2);
public static void main(String[] args) throws InterruptedException {
ExecutorService executorService = Executors.newFixedThreadPool(2);
executorService.submit(()->{
try {
TimeUnit.SECONDS.sleep(3);
} catch (InterruptedException e) {
e.printStackTrace();
}finally {
// 計數器減一
countDownLatch.countDown();
}
System.out.println("Thread A Over");
});
executorService.submit(()->{
try {
TimeUnit.SECONDS.sleep(2);
} catch (InterruptedException e) {
e.printStackTrace();
}finally {
countDownLatch.countDown();
}
System.out.println("Thread B Over");
});
// 等待子線池執行完
countDownLatch.await();
System.out.println("All Child thread Over!");
executorService.shutdown();
}
}
源碼探究
public class CountDownLatch {
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
Sync(int count) {
setState(count);
}
int getCount() {
return getState();
}
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
// 循環進行CAS,直到當前線程成功完成CAS使計數器值(state)減1,並更新到state
for (;;) {
int c = getState();
// 如果當前狀態值爲0則直接返回
if (c == 0)
return false;
// 使用CAS讓計數器值減一
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
// 內部使用AQS實現
private final Sync sync;
// 構造函數 傳入計數器
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
// 把計數器器的值賦予AQS的狀態變量state
this.sync = new Sync(count);
}
// 線程調用await方法,當前線程會被阻塞,直到下面情況發生才返回
// 1. 所有線程調用countDown方法,計數器的值爲0
// 2. 其他線程調用用當前線程的interrupte()方法中斷了當前線程,當前線程拋出InterruptedException返回
public void await() throws InterruptedException {
// 委託sysnc調用AOS的acquireSharedInterruptibly方法
sync.acquireSharedInterruptibly(1);
}
// 帶有超時時間阻塞等待方法
public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
// 計算器的值遞減,遞減爲0則喚醒所有調用await方法而被阻塞的線程
public void countDown() {
sync.releaseShared(1);
}
// 返回當前計算器值
public long getCount() {
return sync.getCount();
}
}
總結: 相比使用join方法實現線程間同步,CountDownLatch更具有靈活性和方便性,CountDownLatch使用AQS實現的,使用AQS的狀態變量來存放計數器的值,構造函數初始化狀態值(計數器值),當線程調用await()方法後當前線程會被放入AQS阻塞隊列等待計數器爲0在返回,多個線程調用countDown()時原子(cas)遞減AQS的狀態值,計數器值減1,當計數器值變爲0是,當前線程調用AQS的doReleaseShared方法激活由調用await()而被阻塞的線程
CyclicBarrier
java.util.concurrent.CyclicBarrier
可循環(Cyclic)的屏障(Barrier)。迴環屏障,讓一組線程到達屏障(屏障點/同步點)時,屏障纔會打開,所有被屏障攔截的線程纔會繼續工作
使用
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public class CyclicBarrierTest {
private static CyclicBarrier cyclicBarrier = new CyclicBarrier(2,()->{
System.out.println(Thread.currentThread()+"=====達到屏障點執行=");
// todo something
});
public static void main(String[] args) {
ExecutorService executorService = Executors.newFixedThreadPool(2);
// 線程A
executorService.submit(()->{
try {
System.out.println(Thread.currentThread()+"step1");
cyclicBarrier.await();
System.out.println(Thread.currentThread()+"step2");
cyclicBarrier.await();
System.out.println(Thread.currentThread()+"step3");
} catch (Exception e) {
e.printStackTrace();
}
});
// 線程B
executorService.submit(()->{
try {
System.out.println(Thread.currentThread()+"step1");
cyclicBarrier.await();
System.out.println(Thread.currentThread()+"step2");
cyclicBarrier.await();
System.out.println(Thread.currentThread()+"step3");
} catch (Exception e) {
e.printStackTrace();
}
});
executorService.shutdown();
}
}
注意:
- 對於指定計數器值parties,若由於某種原因,沒有足夠的線程調用CyclicBarrier的await,則所有調用await的線程都會被阻塞
- CyclicBarrier也可以調用await(timeout,unit)設置超時時間,在設定時間內,如果沒有足夠線程到達,則解除阻塞狀態,繼續工作
- 通過reset重置計數,會使得進入await的線程出現
BrokenBarrierException
4)如果採用是CyclicBarrier(int parties,Runnable barrierAction)
構造方法,執行barrierAction
操作的是最後一個到達的線程
源碼探究
CyclicBarrier
基於ReetrantLock
和Condition
實現,CyclicBarrier 可以有不止一個柵欄,因爲它的柵欄(Barrier)可以重複使用(Cyclic)
public class CyclicBarrier {
private static class Generation {
//記錄當前屏障是否被打破 由於在鎖內使用,所以不需要申明volatile
boolean broken = false;
}
/** The lock for guarding barrier entry */
private final ReentrantLock lock = new ReentrantLock();
/** Condition to wait on until tripped */
// lock的條件變量支持線程間使用await和sigal操作進行同步
private final Condition trip = lock.newCondition();
/** The number of parties
* 記錄線程的個數,表示多少個線程調用await後,所有線程纔會衝破屏障繼續往下運行
* */
private final int parties;
/* The command to run when tripped */
// 當所有的線程到達了屏障點,最後一個線程執行
private final Runnable barrierCommand;
/** The current generation */
private Generation generation = new Generation();
/** count一開始等於parties,每當線程調用await方法就遞減1,當count爲0就表示所有線程到了屏障點*/
private int count;
/**
* Updates state on barrier trip and wakes up everyone.
* Called only while holding lock.
*/
private void nextGeneration() {
// signal completion of last generation 喚醒條件隊列裏面阻塞線程
trip.signalAll();
// set up next generation 重置CyclicBarrier
count = parties;
generation = new Generation();
}
private void breakBarrier() {
generation.broken = true;
count = parties;
trip.signalAll();
}
// dowait實現了CyclicBarrier的核心功能
// timed 是否設置超時時間
private int dowait(boolean timed, long nanos)
throws InterruptedException, BrokenBarrierException,
TimeoutException {
final ReentrantLock lock = this.lock;
lock.lock();
try {
final Generation g = generation;
if (g.broken)
throw new BrokenBarrierException();
if (Thread.interrupted()) {
breakBarrier();
throw new InterruptedException();
}
int index = --count;
// (1) index==0則說明所有線程都到了屏障點,此時執行初始化時傳遞的任務時傳遞任務
if (index == 0) { // tripped
boolean ranAction = false;
try {
final Runnable command = barrierCommand;
// (2) 執行任務
if (command != null)
command.run();
ranAction = true;
// (3) 激活其他因調用await方法而被阻塞的線程,並重置CyclicBarrier
nextGeneration();
return 0;
} finally {
if (!ranAction)
breakBarrier();
}
}
// loop until tripped, broken, interrupted, or timed out
// (4) 如果index != 0
for (;;) {
try {
// (5) 沒有設置超時時間
if (!timed)
trip.await();
// (6) 設置了超時時間
else if (nanos > 0L)
nanos = trip.awaitNanos(nanos);
} catch (InterruptedException ie) {
if (g == generation && ! g.broken) {
breakBarrier();
throw ie;
} else {
// We're about to finish waiting even if we had not
// been interrupted, so this interrupt is deemed to
// "belong" to subsequent execution.
Thread.currentThread().interrupt();
}
}
if (g.broken)
throw new BrokenBarrierException();
if (g != generation)
return index;
if (timed && nanos <= 0L) {
breakBarrier();
throw new TimeoutException();
}
}
} finally {
lock.unlock();
}
}
public CyclicBarrier(int parties, Runnable barrierAction) {
if (parties <= 0) throw new IllegalArgumentException();
this.parties = parties;
this.count = parties;
this.barrierCommand = barrierAction;
}
public CyclicBarrier(int parties) {
this(parties, null);
}
public int getParties() {
return parties;
}
// 阻塞方法
public int await() throws InterruptedException, BrokenBarrierException {
try {
return dowait(false, 0L);
} catch (TimeoutException toe) {
throw new Error(toe); // cannot happen
}
}
// 待超時時間的阻塞方法
// parties個線程都調用了await()方法,也就是線程都到了屏障點,這時候返回true;
// 設置的超時時間到了後返回false;
// 其他線程調用當前線程的interrupt()方法中斷了當前線程,則當前線程會拋出InterruptedException異常然後返回;
// 與當前屏障點關聯的Generation對象的broken標誌被設置爲true時,會拋出BrokenBarrierException異常,然後返回
public int await(long timeout, TimeUnit unit)
throws InterruptedException,
BrokenBarrierException,
TimeoutException {
return dowait(true, unit.toNanos(timeout));
}
public boolean isBroken() {
final ReentrantLock lock = this.lock;
lock.lock();
try {
return generation.broken;
} finally {
lock.unlock();
}
}
// 重置Barrier
public void reset() {
final ReentrantLock lock = this.lock;
lock.lock();
try {
breakBarrier(); // break the current generation
nextGeneration(); // start a new generation
} finally {
lock.unlock();
}
}
public int getNumberWaiting() {
final ReentrantLock lock = this.lock;
lock.lock();
try {
return parties - count;
} finally {
lock.unlock();
}
}
}
Semaphore
Semaphore(信號量)是用來控制同時訪問特定資源的線程數量,它通過協調各個線程,以保證合理的使用公共資源
使用
場景: Semaphore可以用於做流量控制,特別是公用資源有限的應用場景,比如數據庫連接。
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
public class SemaphoreTest {
static class Car extends Thread {
private int num;
private Semaphore semaphore;
public Car(int num, Semaphore semaphore) {
this.num = num;
this.semaphore = semaphore;
}
@Override
public void run() {
try {
// 獲得一個令牌,如果拿不到令牌,就會阻塞
semaphore.acquire();
System.out.println("第"+num+" 搶佔一個車位");
TimeUnit.SECONDS.sleep(2);
System.out.println("第"+num+" 開走");
semaphore.release();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
public static void main(String[] args) {
Semaphore semaphore = new Semaphore(5);
// ExecutorService executorService = Executors.newFixedThreadPool(5);
for (int i = 0; i < 10; i++) {
new Car(i,semaphore).start();
// executorService.submit(new Car(i,semaphore));
}
// executorService.shutdown();
}
}
源碼探究
public class Semaphore implements java.io.Serializable {
private static final long serialVersionUID = -3222578661600680210L;
private final Sync sync;
abstract static class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 1192457210091910933L;
Sync(int permits) {
setState(permits);
}
final int getPermits() {
return getState();
}
final int nonfairTryAcquireShared(int acquires) {
for (;;) {
// 獲取當前信號量值
int available = getState();
// 計算當前剩餘值
int remaining = available - acquires;
// 如果剩餘值小於0說明當前信號量個數不滿足需求
// 大於0且CAS操作成功,返回剩餘值
if (remaining < 0 ||
compareAndSetState(available, remaining))
return remaining;
}
}
protected final boolean tryReleaseShared(int releases) {
for (;;) {
int current = getState();
int next = current + releases;
if (next < current) // overflow
throw new Error("Maximum permit count exceeded");
if (compareAndSetState(current, next))
return true;
}
}
final void reducePermits(int reductions) {
for (;;) {
int current = getState();
// 不允許縮減值爲負數
int next = current - reductions;
if (next > current) // underflow
throw new Error("Permit count underflow");
if (compareAndSetState(current, next)) // CAS 設置縮減後的許可證數量
return;
}
}
final int drainPermits() {
for (;;) {
int current = getState();
// 直接把剩餘的許可證數量設置爲0
if (current == 0 || compareAndSetState(current, 0))
return current;
}
}
}
static final class NonfairSync extends Sync {
private static final long serialVersionUID = -2694183684443567898L;
NonfairSync(int permits) {
super(permits);
}
protected int tryAcquireShared(int acquires) {
return nonfairTryAcquireShared(acquires);
}
}
// 公平策略
static final class FairSync extends Sync {
private static final long serialVersionUID = 2014338818796000944L;
FairSync(int permits) {
super(permits);
}
protected int tryAcquireShared(int acquires) {
for (;;) {
// 如果當前線程不位於對頭,則阻塞
// hasQueuedPredecessors 來保證公平性
if (hasQueuedPredecessors())
return -1;
int available = getState();
int remaining = available - acquires;
if (remaining < 0 ||
compareAndSetState(available, remaining))
return remaining;
}
}
}
public Semaphore(int permits) {
sync = new NonfairSync(permits);
}
public Semaphore(int permits, boolean fair) {
sync = fair ? new FairSync(permits) : new NonfairSync(permits);
}
// 獲取一個信號量值,未獲取到會阻塞
public void acquire() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
public void acquireUninterruptibly() {
sync.acquireShared(1);
}
public boolean tryAcquire() {
return sync.nonfairTryAcquireShared(1) >= 0;
}
public boolean tryAcquire(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
public void release() {
sync.releaseShared(1);
}
public void acquire(int permits) throws InterruptedException {
if (permits < 0) throw new IllegalArgumentException();
sync.acquireSharedInterruptibly(permits);
}
public void acquireUninterruptibly(int permits) {
if (permits < 0) throw new IllegalArgumentException();
sync.acquireShared(permits);
}
public boolean tryAcquire(int permits) {
if (permits < 0) throw new IllegalArgumentException();
return sync.nonfairTryAcquireShared(permits) >= 0;
}
public boolean tryAcquire(int permits, long timeout, TimeUnit unit)
throws InterruptedException {
if (permits < 0) throw new IllegalArgumentException();
return sync.tryAcquireSharedNanos(permits, unit.toNanos(timeout));
}
// 釋放信號量
public void release(int permits) {
if (permits < 0) throw new IllegalArgumentException();
sync.releaseShared(permits);
}
// 返回信號量中對當前可用的許可證數
public int availablePermits() {
return sync.getPermits();
}
// 獲取立即可用的所有許可證個數,並將可用許可證置0
public int drainPermits() {
return sync.drainPermits();
}
protected void reducePermits(int reduction) {
if (reduction < 0) throw new IllegalArgumentException();
sync.reducePermits(reduction);
}
// 是否是公平策略
public boolean isFair() {
return sync instanceof FairSync;
}
// 是否有線程正在等待獲取許可證
public final boolean hasQueuedThreads() {
return sync.hasQueuedThreads();
}
// 返回正則等待許可證的線程數
public final int getQueueLength() {
return sync.getQueueLength();
}
// 返回所有等待許可證的線程集合
protected Collection<Thread> getQueuedThreads() {
return sync.getQueuedThreads();
}
public String toString() {
return super.toString() + "[Permits = " + sync.getPermits() + "]";
}
}
Exchanger
Exchanger(交換者)是一個用於線程間協作的工具類。Exchanger用於進行線程間的數據交換。它提供一個同步點,在這個同步點,兩個線程可以交換彼此的數據。這兩個線程通過exchange方法交換數據,如果第一個線程先執行exchange()方法,它會一直等待第二個線程也執行exchange方法,當兩個線程都到達同步點時,這兩個線程就可以交換數據,將本線程生產出來的數據傳遞給對方
import java.util.concurrent.Exchanger;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public class ExchangerTest {
private static final Exchanger<String> exgr = new Exchanger<>();
public static void main(String[] args) {
ExecutorService service = Executors.newFixedThreadPool(2);
service.submit(()->{
try {
String A = "銀行流水A";
System.out.println(Thread.currentThread()+"交換前:"+A);
// 同步點:等待另一個線程到達此交換點(除非當前線程被中斷),然後將給定的對象傳送給該線程,並接收該線程的對象
String data = exgr.exchange(A);
System.out.println(Thread.currentThread()+"交換後:"+data);
} catch (InterruptedException e) {
e.printStackTrace();
}
});
service.submit(()->{
try {
String B = "銀行流水B";
System.out.println(Thread.currentThread()+"交換前:"+B);
String data = exgr.exchange(B);
System.out.println(Thread.currentThread()+"交換前:"+data);
} catch (InterruptedException e) {
e.printStackTrace();
}
});
service.shutdown();
}
}