ConcurrentHashMap源碼分析

前言

ConcurrentHashMap是java.util.concurrent包下的一個類,它設計出來是用來在某些情況下替換Hashtable的。相比Hashtable它能夠更加高效的進行多線程操作,並不一定需要像Hashtable一樣,當一個線程佔有鎖的時候其他的線程都必須進入阻塞狀態,因此在多線程環境下它更加的高效。至於,ConcurrentHashMap是否能夠完全替代Hashtable這個問題,博文中有分析到:https://my.oschina.net/hosee/blog/675423。但是,同時它也降低了對數據一致性的要求。在這裏額外提一下,java.util.concurrent包中的併發容器,設計出來是用來替換同步容器(多線程環境下,一個線程佔有鎖的時候其他線程必需進入等待狀態,比如Hashtable,Vector),以提供更加高效的併發。

本文將基於JDK1.7的源碼進行分析,JDK1.8之後再寫。

設計思路

CocurrentHashMap使用分段鎖(segment)的方式來減少鎖的粒度,它有一個Segment[]屬性。不同的線程訪問不同segment裏面的數據不會產生阻塞,只有多個線程訪問同一個segment纔會產生鎖的競爭。segment中有一個HashEntry數組,同時它還繼承了ReentrantLock,所以它就類似於一個Hashtable。

/**
 * The segments, each of which is a specialized hash table.
 */
final Segment<K,V>[] segments;

static final class Segment<K,V> extends ReentrantLock implements Serializable
    transient volatile HashEntry<K,V>[] table;

static final class HashEntry<K,V> {
        final int hash;
        final K key;
        volatile V value;
        volatile HashEntry<K,V> next;
}

HashEntry節點和HashMap中的Entry稍有不同的就是,value和next節點都有volatile關鍵詞修飾,這個是爲了保證多線程環境下的可見性。ConcurrentHashMap的結構圖大致如下:

繪圖2

Hash算法

無論是HashMap還是ConcurrentHashMap的基礎都是hash算法,下面是它的hash算法相關的源碼:

private int hash(Object k) {
    int h = hashSeed;

    if ((0 != h) && (k instanceof String)) {
        return sun.misc.Hashing.stringHash32((String) k);
    }

    h ^= k.hashCode();

    // Spread bits to regularize both segment and index locations,
    // using variant of single-word Wang/Jenkins hash.
    h += (h <<  15) ^ 0xffffcd7d;
    h ^= (h >>> 10);
    h += (h <<   3);
    h ^= (h >>>  6);
    h += (h <<   2) + (h << 14);
    return h ^ (h >>> 16);
} 

public V put(K key, V value) {
    Segment<K,V> s;
    if (value == null)
        throw new NullPointerException();
    int hash = hash(key);
    int j = (hash >>> segmentShift) & segmentMask;
    if ((s = (Segment<K,V>)UNSAFE.getObject          // nonvolatile; recheck
         (segments, (j << SSHIFT) + SBASE)) == null) //  in ensureSegment
        s = ensureSegment(j);
    return s.put(key, hash, value, false);
}

可能有人會有疑問,爲什麼能夠直接通過Object的hashcode()方法得到hashcode的值之後還需要再單獨定義一個hash()方法進行再哈希?因爲直接計算索引受限於segments.length轉換爲2進制後的有效位數。具體看下面這個例子體會一下:

對於segments.length == 16的ConcurrentHashMap,(segments.length - 1) == 15轉換爲2進制數後爲1111,有效位數爲4位。現有4個key,hascode的值分別爲:15,31,63,127,它們轉換爲2進制數後對應的值分別爲:1111,11111,111111,111111。現在對它們求索引:

15 & 15 ==> 1111 & 1111 = 1111

31 & 15 ==> 11111 & 01111 = 01111

63 & 15 ==> 111111 & 001111 = 001111

127 & 15 ==> 1111111 & 0001111 = 0001111

最終求出索引的值轉換爲10進制數後都是15,可以看出直接用hascode的值求索引,受限於length的轉換爲2進制的有效位數,比較容易產生hash衝突。爲了解決這個問題,就需要利用key的hascode轉換爲2進制的後的有效位數的不同,進行再hash運算,最終使得進行&運算的時候有效位數不同。仍然是上面的這個例子,看一下通過Wang/Jenkins hash算法之後,爲了方便閱讀將數據轉換爲32位的2進制數據,不足位用0補齊:

0100 0111 0110 0111 1101 1010 0100 1110
1111 0111 0100 0011 0000 0001 1011 1000
0111 0111 0110 1001 0100 0110 0011 1110
1000 0011 0000 0000 1100 1000 0001 1010

顯而易見的是,通過再hash之後,對有效bit進行拆分,使得最後4位的bit不相同。Wang/Jenkins hash算法詳細計算案例請參考:http://www.goworkday.com/2010/03/19/single-word-wangjenkins-hash-concurrenthashmap/。事實上Wang/Jenkins hash算法具有很好的分佈性,它有一個特點就是雪崩性(只需要改變輸入數據的一個bit爲就會使得輸出數據差異很大)。

從源碼中可以清楚的看明白hash算法的兩步:

  1. 首先使用Wang/Jenkis hash算法確定hashcode的值
  2. 通過hashcode確定segment的索引值

計算索引的時候出現了兩個陌生的屬性segmentShift和segmentMask,這兩個屬性在構造函數中確定的,下面是部分代碼:

    public ConcurrentHashMap(int initialCapacity,
                             float loadFactor, int concurrencyLevel) {
        if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
            throw new IllegalArgumentException();
        if (concurrencyLevel > MAX_SEGMENTS)
            concurrencyLevel = MAX_SEGMENTS;
        // Find power-of-two sizes best matching arguments
        int sshift = 0;
        int ssize = 1;
        while (ssize < concurrencyLevel) {
            ++sshift;
            ssize <<= 1;
        }
        this.segmentShift = 32 - sshift;
        this.segmentMask = ssize - 1;
        if (initialCapacity > MAXIMUM_CAPACITY)
            initialCapacity = MAXIMUM_CAPACITY;
        int c = initialCapacity / ssize;
        if (c * ssize < initialCapacity)
            ++c;
        int cap = MIN_SEGMENT_TABLE_CAPACITY;
        while (cap < c)
            cap <<= 1;
        // create segments and segments[0]
        Segment<K,V> s0 =
            new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
                             (HashEntry<K,V>[])new HashEntry[cap]);
        Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
        UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
        this.segments = ss;
    }

可以看出segmentShift就相當於segment.length - 1,而segmentShift等於32-n,這個segments.length轉換爲2^n的指數。

從構造函數的源碼中,有一個concurrencyLevel重要參數傳入,它是用來確定segment數組的長度的。segment個數是大於等於concurrencyLevel的最小二次冪整數。所以,concurrencyLevel就相當於一個併發度,它的值不宜設置的太小,設置太小會產生頻繁的鎖競爭,設置太大會使得在同一segment的HashEntry分散到不同segment中,降低cpu的緩存。

延遲加載鎖

在構造函數中,只是創建了索引爲0處的segment,其他的segment採用延遲加載的策略。當key定位segment的時候,首先進行檢測,看是否存在,如果不存在就調用ensureSegment()方法進行創建。具體源碼如下:

 private Segment<K,V> ensureSegment(int k) {
        final Segment<K,V>[] ss = this.segments;
        long u = (k << SSHIFT) + SBASE; // raw offset
        Segment<K,V> seg;
        if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
            Segment<K,V> proto = ss[0]; // use segment 0 as prototype
            int cap = proto.table.length;
            float lf = proto.loadFactor;
            int threshold = (int)(cap * lf);
            HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
            if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                == null) { // recheck
                Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
                while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                       == null) {
                    // 使用CAS算法創建segment,不需要鎖住該方法,提高了創建segment的速度
                    // CAS算法的原理就是不斷比較當前內存中的對象和你指定的對象是否相等,
                    // 如果相等則接受修改,不相等則不接受,因爲內存中的對象已經不是最新的對象
                    // 如果再進行修改,則會覆蓋其他線程修改的最新的值
                    if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
                        break;
                }
            }
        }
        return seg;
    }

segment的創建使用了雙重檢查,總體來說使用UNSAFE對象的兩個方法。getObjectVolatile提供了原子語義再加上CAS算法實現無鎖創建segment。

put方法

從ConcurrentHashMap的put方法可以看出來,put方法最終被代理到了segment中,具體源碼:

final V put(K key, int hash, V value, boolean onlyIfAbsent) {
        //首先進行第一次嘗試獲得鎖,如果獲得則創建一個null的node節點
        // 如果未獲得鎖就調用scanAndLockForPut方法
        HashEntry<K,V> node = tryLock() ? null :
            scanAndLockForPut(key, hash, value);
        V oldValue;
        try {
            HashEntry<K,V>[] tab = table;
            int index = (tab.length - 1) & hash;
            HashEntry<K,V> first = entryAt(tab, index);
            for (HashEntry<K,V> e = first;;) {
                if (e != null) {
                    K k;
                    // key存在直接替換value
                    if ((k = e.key) == key ||
                        (e.hash == hash && key.equals(k))) {
                        oldValue = e.value;
                        if (!onlyIfAbsent) {
                            e.value = value;
                            ++modCount;
                        }
                        break;
                    }
                    e = e.next;
                }
                else {
                    // key不存在將新node節點設置爲頭節點
                    if (node != null)
                        node.setNext(first);
                    else
                        node = new HashEntry<K,V>(hash, key, value, first);
                    int c = count + 1;
                    //檢查是否需要擴容,注意這裏只是對segment進行擴容,而非ConcurrentHashMap
                    if (c > threshold && tab.length < MAXIMUM_CAPACITY)
                        rehash(node);
                    else    //將新鏈表的頭節點設置到數組中
                        setEntryAt(tab, index, node);
                    ++modCount;
                    count = c;
                    oldValue = null;
                    break;
                }
            }
        } finally {
            // 最終釋放鎖
            unlock();
        }
        return oldValue;
    }

private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
        HashEntry<K,V> first = entryForHash(this, hash);
        HashEntry<K,V> e = first;
        HashEntry<K,V> node = null;
        int retries = -1; // negative while locating node
        while (!tryLock()) {
            HashEntry<K,V> f; // to recheck first below
            if (retries < 0) {
                if (e == null) {
                    if (node == null) // speculatively create node
                        node = new HashEntry<K,V>(hash, key, value, null);
                    retries = 0;
                }   //存在對應的key
                else if (key.equals(e.key))
                    retries = 0;
                else    //遍歷鏈表,這個是爲了被cpu所緩存,提升put方法的時間
                    e = e.next;
            }   //超過一定的次數後直接申請鎖
            else if (++retries > MAX_SCAN_RETRIES) {
                lock();
                break;
            }       //頭節點變化了需要重新進行遍歷
            else if ((retries & 1) == 0 &&
                     (f = entryForHash(this, hash)) != first) {
                e = first = f; // re-traverse if entry changed
                retries = -1;
            }
        }
        return node;
    }

segment的put方法首先第一次嘗試獲得鎖,如果成功則設置node節點爲null,不成功則調用scanAndLockForPut創建一個節點。put方法的一般流程爲:

遍歷鏈表,遍歷的時候看key是否存在。如果存在直接替換value即可,如果不存在則用新創建的node節點替換爲原鏈表的頭節點,然後將新的頭節點設置到數組中。這裏有一處JDK1.7的優化就是,在scanAndLockForPut方法中,首先會嘗試獲得鎖,在獲得鎖的過程中會遍歷鏈表,使得鏈表被cpu所緩存提高後續put方法的時間;同時在遍歷的過程中也會檢查頭節點是否改變(put,remove等方法會改變頭節點),如果頭節點改變就需要重新進行遍歷。在嘗試獲得鎖一定次數之後,會直接調用lock()方法獲得鎖。

最後需要注意的是:

put方法中,鏈接新節點的下一個節點(HashEntry.setNext())以及將鏈表寫入到數組中(setEntryAt())都是通過Unsafe的putOrderedObject()方法來實現,這裏並未使用具有原子寫語義的putObjectVolatile()的原因是:JMM會保證獲得鎖到釋放鎖之間所有對象的狀態更新都會在鎖被釋放之後更新到主存,從而保證這些變更對其他線程是可見的。

從這裏可以看出ConcurrentHashMap的弱一致性:當一對key和value通過調用put方法存儲到ConcurrentHashMap中,但是另外一個線程想通過該key馬上調用get()方法(get方法不需要獲取鎖)獲取的value的時候,可能獲取到的是一個null。因爲它必須等到釋放鎖之後,纔會將最新的值更新到主存中

rehash方法

private void rehash(HashEntry<K,V> node) {

    HashEntry<K,V>[] oldTable = table;
    int oldCapacity = oldTable.length;
    int newCapacity = oldCapacity << 1;
    threshold = (int)(newCapacity * loadFactor);
    HashEntry<K,V>[] newTable =
        (HashEntry<K,V>[]) new HashEntry[newCapacity];
    int sizeMask = newCapacity - 1;
    for (int i = 0; i < oldCapacity ; i++) {
        HashEntry<K,V> e = oldTable[i];
        if (e != null) {
            HashEntry<K,V> next = e.next;
            int idx = e.hash & sizeMask;
            if (next == null)   //  單個節點
                newTable[idx] = e;
            else { // 鏈表
                HashEntry<K,V> lastRun = e;
                int lastIdx = idx;
                for (HashEntry<K,V> last = next;
                     last != null;
                     last = last.next) {
                    int k = last.hash & sizeMask;
                    // 尋找index不變的第一個節點
                    if (k != lastIdx) {
                        lastIdx = k;
                        lastRun = last;
                    }
                }
                //  該節點之後的節點的index都不變,直接將頭節點設置到newTable中即可
                newTable[lastIdx] = lastRun;
                // 對之前的節點進行復制和重排序
                for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
                    V v = p.value;
                    int h = p.hash;
                    int k = h & sizeMask;
                    // 使用頭插法
                    HashEntry<K,V> n = newTable[k];
                    newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
                }
            }
        }
    }
    int nodeIndex = node.hash & sizeMask; // add the new node
    node.setNext(newTable[nodeIndex]);
    newTable[nodeIndex] = node;
    table = newTable;
}

該方法首先創建一個原先2倍長的table,然後進行遍歷複製,擴容完成之後將新節點添加到newTable中。在遍歷的過程中有一處優化,它首先會找到index不變的第一個節點,直接將它作爲頭節點設置到數組中即可。而對它之前的節點,直接進行復制和重排序,然後使用頭插法設置到數值中。

get方法

public V get(Object key) {
        Segment<K,V> s; // manually integrate access methods to reduce overhead
        HashEntry<K,V>[] tab;
        int h = hash(key);
        long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
        if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
            (tab = s.table) != null) {
            for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
                     (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
                 e != null; e = e.next) {
                K k;
                if ((k = e.key) == key || (e.hash == h && key.equals(k)))
                    return e.value;
            }
        }
        return null;
    }

get方法比較的簡單,它能夠實現無鎖化操作的主要原因就是使用UNSAFE對象的getObjectVolatile()方法提供原子語義,來獲取segment和頭節點。

size方法

public int size() {
        // Try a few times to get accurate count. On failure due to
        // continuous async changes in table, resort to locking.
        final Segment<K,V>[] segments = this.segments;
        int size;
        boolean overflow; // true if size overflows 32 bits
        long sum;         // sum of modCounts
        long last = 0L;   // previous sum
        int retries = -1; // first iteration isn't retry
        try {
            for (;;) {
                if (retries++ == RETRIES_BEFORE_LOCK) {
                    for (int j = 0; j < segments.length; ++j)
                        ensureSegment(j).lock(); // force creation
                }
                sum = 0L;
                size = 0;
                overflow = false;
                for (int j = 0; j < segments.length; ++j) {
                    Segment<K,V> seg = segmentAt(segments, j);
                    if (seg != null) {
                        sum += seg.modCount;
                        int c = seg.count;
                        // 是否溢出,即超過整型最大值
                        if (c < 0 || (size += c) < 0)
                            overflow = true;
                    }
                }
                if (sum == last)
                    break;
                last = sum;
            }
        } finally {
            if (retries > RETRIES_BEFORE_LOCK) {
                for (int j = 0; j < segments.length; ++j)
                    segmentAt(segments, j).unlock();
            }
        }
        return overflow ? Integer.MAX_VALUE : size;
    }

首先進行循環遍歷每一個segment,來計算整個ConcurrentHashMap的modCount。遍歷計算ConcurrentHashMap的modCount的時候會計算多次,如果相鄰的2次,modCount加起來相等,將每一個segment的count相加,作爲最後的輸出結果,輸出的時候還需要檢查數值是否溢出。若循環超過一定次數之後,則會對每一個segment強制加鎖,如果segment不存在則直接創建處理。一般來說,應該避免在多線程環境下使用size和containsValue方法。原因有2點:

  1. 強制創建每一個segment(即使它當中不存在元素)並且進行加鎖損耗太大,容易造成不必要的開銷和線程阻塞。
  2. 即使強制加鎖計算出來的size仍然是一個粗略的值,因爲對當前segment進行加鎖,其他的線程仍然可以對之前的segment調用put方法添加元素。

參考

ConcurrentHashMap總結
深入分析ConcurrentHashMap

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章