ThreadLocal實現原理分析

大概有一年多的時間沒有更新過文章了,要想輸出一篇優質的文章需要耗費很多精力。可能是之前太過於懶惰了吧,經過一段精力的消耗,漸漸地失去了一些動力。但是寫文章雖然耗時,但是有個好處就是在複習一些知識點的時候,只需要查看之前寫的博客,在很短的時間內就能把知識點回想起來。曾經的初中老師總是嘮叨說好記性不如爛筆頭。看來是“誠不欺我呀!”。希望之後還是能保持一定的更新節奏,把對技術的思考都記錄下來。跟大家一起分享知識

概述

今天主要是記錄一下ThreadLocal的實現原理。引用下官方文檔對ThreadLocal的註釋

This class provides thread-local variables. These variables differ from their normal counterparts in that each thread that accesses one (via its get or set method) has its own, independently initialized copy of the variable. ThreadLocal instances are typically private static fields in classes that wish to associate state with a thread (e.g., a user ID or Transaction ID).

這句話翻譯過來就是 “該類提供線程局部變量。 這些變量與它們的正常對應物的不同之處在於,訪問一個變量的每個線程(通過其get或set方法)都有自己獨立初始化的變量副本。”

我們知道一個變量的作用域大概分爲"全局作用域"(static變量)、“類作用域”(成員變量)、“方法作用域”(方法內的局部變量)、“代碼塊作用域”(代碼塊的局部變量)。那麼ThreadLocal可以爲線程提供局部變量,那說明提供的變量的作用域是Thread類作用域。換句話說可以通過ThreadLocal給Thread提供拓展新的成員變量的功能,是不是有點像Kotlin的拓展新方法,新變量的功能有點類似?如果這樣描述你覺得對ThreadLocal更容易理解,那你就認爲ThreadLocal可以給線程提供新的成員變量,只不過這個成員變量是沒有名稱的。它的賦值只能通過ThreadLocal對象的set()方法,獲取該變量的值只能通過ThreadLocal的get()方法。下面通過一個簡單的例子來講解下ThreadLocal是怎麼給Thread提供局部變量的

簡單的例子

import java.util.concurrent.TimeUnit;

public class TestThreadLocal {
    //1
    ThreadLocal<String> stringThreadLocal = new ThreadLocal<>();

    public void setThreadValue(String value) {
        //2
        stringThreadLocal.set(value);
    }

    public String getThreadValue() {
        //3
        return stringThreadLocal.get();
    }

    public static void main(String[] args) {
        TestThreadLocal testThreadLocal = new TestThreadLocal();
        //4
        Thread t1 = new Thread(new Runnable() {
            @Override
            public void run() {
                testThreadLocal.setThreadValue("t1");
                try {
                    TimeUnit.SECONDS.sleep(1);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                System.out.println(testThreadLocal.getThreadValue() + " in thread1");
            }
        });
        //5
        Thread t2 = new Thread(new Runnable() {
            @Override
            public void run() {
                testThreadLocal.setThreadValue("t2");
                try {
                    TimeUnit.SECONDS.sleep(1);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                System.out.println(testThreadLocal.getThreadValue() + " in thread2");
            }
        });
        t1.start();
        t2.start();
        try {
            t1.join();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        try {
            t2.join();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
}

程序運行的結果是

t1 in thread1
t2 in thread2

註釋//1處定義了一個ThreadLocal變量,爲線程提供String類型的線程局部變量

註釋//2處通過ThreadLocal的set方法爲線程局部變量賦值

註釋//3處通過ThreadLocal的get方法獲取線程局部變量的值

註釋//4 //5分別創建線程t1和t2。並且通過ThreadLocal爲t1 t2賦值以及獲取值

源碼分析

在分析源碼前我們先來看下ThreadLocal相關的類圖

在這裏插入圖片描述

  1. ThreadLocal類定義了set和get方法
  2. Thread有個名爲threadLocals 類型是ThreadLocal.ThreadLocalMap的成員變量。ThreadLocalMap是什麼呢。顧名思義肯定是存儲鍵值對的。類似HashMap<K,V>。這裏有兩個疑問,第一既然是Map接口那麼存儲的鍵值對是什麼?第二爲什麼要用ThreadLocalMap而不是HashMap呢?
  3. ThreadLocal.ThreadLocalMap類是Map結構,第一key存儲的是ThreadLocal<?>對象,value存儲的是Thread的局部變量的值,第二map的key-value對應的Entry是一個WeakReference,而且該WeakReference引用的是Key。也就是說當ThreadLocal對象被GC回收了之後,Map對應的Entry也會被回收掉,同時給Thread提供的局部變量也會被回收掉,這樣設計不會造成ThreadLocal的內存泄漏,比HashMap<ThreadLocal,T>好

接下來從源碼角度分析

//ThreadLocal.set(T value)
public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }
    

//ThreadLocal.getMap(Thread t)
ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
}


//ThreadLocal.ThreadLocalMap.set(ThreadLocal<?> key,Object value)

private void set(ThreadLocal<?> key, Object value) {

            // We don't use a fast path as with get() because it is at
            // least as common to use set() to create new entries as
            // it is to replace existing ones, in which case, a fast
            // path would fail more often than not.

            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);

            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();

                if (k == key) {
                    e.value = value;
                    return;
                }

                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }

            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

ThreadLocal.set(T value)流程如下

  1. 獲取到當前線程
  2. 獲取當前線程的ThreadLocalMap對象
  3. 如果ThreadLocalMap對象不爲空,把鍵值對set到map中,如果爲空創建map並初始化

ThreadLocal.get()方法流程"可猜而知"

  1. 獲取當前線程
  2. 獲取當前線程的ThreadLocalMap對象
  3. 如果Map對象爲空,返回默認值,如果不爲空則根據key獲取map中的value

那麼看下源碼來印證下猜測吧

//ThreadLocal.get()
 public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }

上述源碼,剛好印證了猜想是對的。

最後留下一個問題

既然ThreadLocal爲線程提供局部的變量,那麼該變量只能在當前線程中賦值和訪問。那麼真的沒有辦法在t1中訪問和修改t2中的局部變量嗎?當然有了通過反射。

import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.concurrent.TimeUnit;

public class TestThreadLocal {
    ThreadLocal<String> stringThreadLocal = new ThreadLocal<>();

    public void setThreadValue(String value) {
        stringThreadLocal.set(value);
    }

    public String getThreadValue() {
        return stringThreadLocal.get();
    }

    public static void main(String[] args) {
        TestThreadLocal testThreadLocal = new TestThreadLocal();
        final Thread t1 = new Thread(new Runnable() {
            @Override
            public void run() {
                testThreadLocal.setThreadValue("t1");
                try {
                    TimeUnit.SECONDS.sleep(1);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                System.out.println(testThreadLocal.getThreadValue() + " in thread1");
                while (true){

                }
            }
        });
        Thread t2 = new Thread(new Runnable() {
            @Override
            public void run() {
                testThreadLocal.setThreadValue("t2");
                try {
                    TimeUnit.SECONDS.sleep(1);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                System.out.println(testThreadLocal.getThreadValue() + " in thread2");
                try {
                  Field field =  Thread.class.getDeclaredField("threadLocals");
                  field.setAccessible(true);
                  Object map = field.get(t1);
                  Class clazz = Class.forName("java.lang.ThreadLocal$ThreadLocalMap");
                  Method method  = clazz.getDeclaredMethod("getEntry",ThreadLocal.class);
                  method.setAccessible(true);
                  Object entry  = method.invoke(map,testThreadLocal.stringThreadLocal);
                  Class clazz2 = Class.forName("java.lang.ThreadLocal$ThreadLocalMap$Entry");
                    Field field1 =clazz2.getDeclaredField("value");
                    field1.setAccessible(true);
                    System.out.println(field1.get(entry)+" in thread 2");

                } catch (NoSuchFieldException e) {
                    e.printStackTrace();
                } catch (IllegalAccessException e) {
                    e.printStackTrace();
                } catch (ClassNotFoundException e) {
                    e.printStackTrace();
                } catch (NoSuchMethodException e) {
                    e.printStackTrace();
                } catch (InvocationTargetException e) {
                    e.printStackTrace();
                }
            }
        });
        t1.start();
        try {
            TimeUnit.SECONDS.sleep(2);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        t2.start();
        try {
            t1.join();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        try {
            t2.join();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章