這個簡短的教程將介紹關於 JAX 的基礎知識。JAX 是一個 Python 庫,它通過函數轉換來增強 numpy 和 Python 代碼,使運行機器學習程序中常見的操作輕而易舉。具體來說,它會使得編寫標準 Python / numpy 代碼變得簡單,並且能夠立即執行
- 通過 autograd 的後繼計算函數的導數
- 及時編譯函數,通過 XLA 在加速器上高效運行
- 自動矢量化函數,並執行處理“批量”數據等
在本教程中,我們將通過演示它在 AGI 的一個核心問題:使用神經網絡學習異或(XOR)函數,依次介紹這些轉換。
注意:此博客文章在此處提供交互式 Jupyter notebook:https://github.com/craffel/jax-tutorial
1 JAX 只是 numpy(大多數情況下)
從本質上講,你可以將 JAX 視爲使用執行上述轉換所需的機器來增強 numpy。JAX 增強的numpy 爲 jax.numpy。除了少數例外,可以認爲 jax.numpy 與 numpy 可直接互換。作爲一般規則,當你計劃使用 JAX 的任何轉換(如計算漸變或即時編譯代碼),或希望代碼在加速器上運行時,都應該使用 jax.numpy。當 jax.numpy 不支持你的計算時,用 numpy 就行了。
import random
import itertools
import jax
import jax.numpy as np
# Current convention is to import original numpy as "onp"
import numpy as onp
from __future__ import print_function
2 背景
如前所述,我們將使用小型神經網絡學習 XOR 功能。 XOR 函數將兩個二進制數作爲輸入並輸出二進制數,如下圖所示:
我們將使用具有 3 個神經元和雙曲正切非線性的單個隱藏層的神經網絡,通過隨機梯度下降訓練交叉熵損失。然後實現此模型和損失函數。請注意,代碼與你在標準 numpy 中編寫的完全一樣。
# Sigmoid nonlinearity
def sigmoid(x):
return 1 / (1 + np.exp(-x))
# Computes our network's output
def net(params, x):
w1, b1, w2, b2 = params
hidden = np.tanh(np.dot(w1, x) + b1)
return sigmoid(np.dot(w2, hidden) + b2)
# Cross-entropy loss
def loss(params, x, y):
out = net(params, x)
cross_entropy = -y * np.log(out) - (1 - y)*np.log(1 - out)
return cross_entropy
# Utility function for testing whether the net produces the correct
# output for all possible inputs
def test_all_inputs(inputs, params):
predictions = [int(net(params, inp) > 0.5) for inp in inputs]
for inp, out in zip(inputs, predictions):
print(inp, '->', out)
return (predictions == [onp.bitwise_xor(*inp) for inp in inputs])
如上所述,有些地方我們想要使用標準 numpy 而不是 jax.numpy。比如參數初始化。我們想在訓練網絡之前隨機初始化參數,這不是我們需要衍生或編譯的操作。JAX 使用自己的 jax.random 庫而不是 numpy.random,爲不同轉換的復現性(種子)提供了更好的支持。由於我們不需要以任何方式轉換參數的初始化,因此最簡單的方法就是在這裏使用標準
的 numpy.random 而不是 jax.random。
def initial_params():
return [
onp.random.randn(3, 2), # w1
onp.random.randn(3), # b1
onp.random.randn(3), # w2
onp.random.randn(), #b2
]
3 jax.grad
我們將使用的第一個轉換是 jax.grad。jax.grad 接受一個函數並返回一個新函數,該函數計算原始函數的漸變。默認情況下,相對於第一個參數進行漸變;這可以通過 jgn.grad 的 argnums 參數來控制。要使用梯度下降,我們希望能夠根據神經網絡的參數計算損失函數的梯度。爲此,使用 jax.grad(loss)就可以,它將提供一個可以調用以獲得這些漸變的函數。
loss_grad = jax.grad(loss)
# Stochastic gradient descent learning rate
learning_rate = 1.
# All possible inputs
inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
# Initialize parameters randomly
params = initial_params()
for n in itertools.count():
# Grab a single random input
x = inputs[onp.random.choice(inputs.shape[0])]
# Compute the target output
y = onp.bitwise_xor(*x)
# Get the gradient of the loss for this input/output pair
grads = loss_grad(params, x, y)
# Update parameters via gradient descent
params = [param - learning_rate * grad
for param, grad in zip(params, grads)]
# Every 100 iterations, check whether we've solved XOR
if not n % 100:
print('Iteration {}'.format(n))
if test_all_inputs(inputs, params):
break
4 jax.jit
雖然我們精心編寫的 numpy 代碼運行起來效果還行,但對於現代機器學習來說,我們希望這些代碼運行得儘可能快。這一般通過在 GPU 或 TPU 等不同的“加速器”上運行代碼來實現。JAX提供了一個 JIT(即時)編譯器,它採用標準的 Python / numpy 函數,經編譯可以在加速器上高效運行。編譯函數還可以避免 Python 解釋器的開銷,這決定了你是否使用加速器。總的來說,jax.jit 可以顯著加速代碼運行,且基本上沒有編碼開銷,你需要做的就是讓 JAX 爲你編譯函數。使用 jax.jit 時,即使是微小的神經網絡也可以實現相當驚人的加速度:
# Time the original gradient function
%timeit loss_grad(params, x, y)
loss_grad = jax.jit(jax.grad(loss))
# Run once to trigger JIT compilation
loss_grad(params, x, y)
%timeit loss_grad(params, x, y)
10 loops, best of 3: 13.1 ms per loop
1000 loops, best of 3: 862 µs per loop
請注意,JAX 允許我們將變換鏈接在一起。首先,我們使用 jax.grad 獲取了丟失的梯度,然後使用 jax.jit 立即進行編譯。這是使 JAX 更強大的一個因素——除了鏈接 jax.jit 和 jax.grad之外,我們還可以多次應用 jax.grad 以獲得更高階的導數等。爲了確保訓練神經網絡經過編譯後仍然有效,我們再次對它進行訓練。請注意,訓練代碼沒有任何變化。
params = initial_params()
for n in itertools.count():
x = inputs[onp.random.choice(inputs.shape[0])]
y = onp.bitwise_xor(*x)
grads = loss_grad(params, x, y)
params = [param - learning_rate * grad
for param, grad in zip(params, grads)]
if not n % 100:
print('Iteration {}'.format(n))
if test_all_inputs(inputs, params):
break
5 jax.vmap
精明的讀者可能已經注意到,我們一直在一個例子上訓練我們的神經網絡。這是“真正的”隨機梯度下降;在實踐中,當訓練現代機器學習模型時,我們執行“小批量”梯度下降,在梯度下降的每個步驟中,我們對一小批示例中的損失梯度求平均值。JAX 提供了 jax.vmap,這是一個自動“矢量化”函數的轉換。這意味着它允許你在輸入的某個軸上並行計算函數的輸出。對我們來說,這意味着我們可以應用 jax.vmap 函數轉換並立即獲得損失函數漸變的版本,該版本適用於小批量示例。
jax.vmap 還可接受其他參數:
-
in_axes 是一個元組或整數,它告訴 JAX 函數參數應該對哪些軸並行化。元組應該與 vmap’d 函數的參數數量相同,或者只有一個參數時爲整數。示例中,我們將使用(None,0,0),指“不在第一個參數(params)上並行化,並在第二個和第三個參數(x和y)的第一個(第零個)維度上並行化”。
-
out_axes 類似於 in_axes,除了它指定了函數輸出的哪些軸並行化。我們在例子中使用0,表示在函數唯一輸出的第一個(第零個)維度上進行並行化(損失梯度)。
請注意,我們必須稍微修改一下訓練代碼——我們需要一次抓取一批數據而不是單個示例,並在應用它們來更新參數之前對批處理中的漸變求平均。
loss_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0), out_axes=0))
params = initial_params()
batch_size = 100
for n in itertools.count():
# Generate a batch of inputs
x = inputs[onp.random.choice(inputs.shape[0], size=batch_size)]
y = onp.bitwise_xor(x[:, 0], x[:, 1])
# The call to loss_grad remains the same!
grads = loss_grad(params, x, y)
# Note that we now need to average gradients over the batch
params = [param - learning_rate * np.mean(grad, axis=0)
for param, grad in zip(params, grads)]
if not n % 100:
print('Iteration {}'.format(n))
if test_all_inputs(inputs, params):
break
6 指南
這就是我們將在這個簡短的教程中介紹的內容,但這實際上涵蓋了大量的 JAX 知識。由於JAX主要是 numpy 和 Python,因此你可以利用現有知識,而不必學習基本的新框架或範例。
有關其他資源,請查看JAX GitHub:
https://github.com/google/jax 上的notebook和示例目錄。