關於要替代TensorFlow的JAX,你知道多少?

這個簡短的教程將介紹關於 JAX 的基礎知識。JAX 是一個 Python 庫,它通過函數轉換來增強 numpy 和 Python 代碼,使運行機器學習程序中常見的操作輕而易舉。具體來說,它會使得編寫標準 Python / numpy 代碼變得簡單,並且能夠立即執行

  • 通過 autograd 的後繼計算函數的導數
  • 及時編譯函數,通過 XLA 在加速器上高效運行
  • 自動矢量化函數,並執行處理“批量”數據等

在本教程中,我們將通過演示它在 AGI 的一個核心問題:使用神經網絡學習異或(XOR)函數,依次介紹這些轉換。

注意:此博客文章在此處提供交互式 Jupyter notebook:https://github.com/craffel/jax-tutorial

1 JAX 只是 numpy(大多數情況下)

從本質上講,你可以將 JAX 視爲使用執行上述轉換所需的機器來增強 numpy。JAX 增強的numpy 爲 jax.numpy。除了少數例外,可以認爲 jax.numpy 與 numpy 可直接互換。作爲一般規則,當你計劃使用 JAX 的任何轉換(如計算漸變或即時編譯代碼),或希望代碼在加速器上運行時,都應該使用 jax.numpy。當 jax.numpy 不支持你的計算時,用 numpy 就行了。

import random
import itertools

import jax
import jax.numpy as np
# Current convention is to import original numpy as "onp"
import numpy as onp

from __future__ import print_function

2 背景

如前所述,我們將使用小型神經網絡學習 XOR 功能。 XOR 函數將兩個二進制數作爲輸入並輸出二進制數,如下圖所示:

image

我們將使用具有 3 個神經元和雙曲正切非線性的單個隱藏層的神經網絡,通過隨機梯度下降訓練交叉熵損失。然後實現此模型和損失函數。請注意,代碼與你在標準 numpy 中編寫的完全一樣。

# Sigmoid nonlinearity
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

# Computes our network's output
def net(params, x):
    w1, b1, w2, b2 = params
    hidden = np.tanh(np.dot(w1, x) + b1)
    return sigmoid(np.dot(w2, hidden) + b2)

# Cross-entropy loss
def loss(params, x, y):
    out = net(params, x)
    cross_entropy = -y * np.log(out) - (1 - y)*np.log(1 - out)
    return cross_entropy

# Utility function for testing whether the net produces the correct
# output for all possible inputs
def test_all_inputs(inputs, params):
    predictions = [int(net(params, inp) > 0.5) for inp in inputs]
    for inp, out in zip(inputs, predictions):
        print(inp, '->', out)
    return (predictions == [onp.bitwise_xor(*inp) for inp in inputs])

如上所述,有些地方我們想要使用標準 numpy 而不是 jax.numpy。比如參數初始化。我們想在訓練網絡之前隨機初始化參數,這不是我們需要衍生或編譯的操作。JAX 使用自己的 jax.random 庫而不是 numpy.random,爲不同轉換的復現性(種子)提供了更好的支持。由於我們不需要以任何方式轉換參數的初始化,因此最簡單的方法就是在這裏使用標準
的 numpy.random 而不是 jax.random。

def initial_params():
    return [
        onp.random.randn(3, 2),  # w1
        onp.random.randn(3),  # b1
        onp.random.randn(3),  # w2
        onp.random.randn(),  #b2
    ]

3 jax.grad

我們將使用的第一個轉換是 jax.grad。jax.grad 接受一個函數並返回一個新函數,該函數計算原始函數的漸變。默認情況下,相對於第一個參數進行漸變;這可以通過 jgn.grad 的 argnums 參數來控制。要使用梯度下降,我們希望能夠根據神經網絡的參數計算損失函數的梯度。爲此,使用 jax.grad(loss)就可以,它將提供一個可以調用以獲得這些漸變的函數。

loss_grad = jax.grad(loss)

# Stochastic gradient descent learning rate
learning_rate = 1.
# All possible inputs
inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])

# Initialize parameters randomly
params = initial_params()

for n in itertools.count():
    # Grab a single random input
    x = inputs[onp.random.choice(inputs.shape[0])]
    # Compute the target output
    y = onp.bitwise_xor(*x)
    # Get the gradient of the loss for this input/output pair
    grads = loss_grad(params, x, y)
    # Update parameters via gradient descent
    params = [param - learning_rate * grad
              for param, grad in zip(params, grads)]
    # Every 100 iterations, check whether we've solved XOR
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break

image

4 jax.jit

雖然我們精心編寫的 numpy 代碼運行起來效果還行,但對於現代機器學習來說,我們希望這些代碼運行得儘可能快。這一般通過在 GPU 或 TPU 等不同的“加速器”上運行代碼來實現。JAX提供了一個 JIT(即時)編譯器,它採用標準的 Python / numpy 函數,經編譯可以在加速器上高效運行。編譯函數還可以避免 Python 解釋器的開銷,這決定了你是否使用加速器。總的來說,jax.jit 可以顯著加速代碼運行,且基本上沒有編碼開銷,你需要做的就是讓 JAX 爲你編譯函數。使用 jax.jit 時,即使是微小的神經網絡也可以實現相當驚人的加速度:

# Time the original gradient function
%timeit loss_grad(params, x, y)
loss_grad = jax.jit(jax.grad(loss))
# Run once to trigger JIT compilation
loss_grad(params, x, y)
%timeit loss_grad(params, x, y)

10 loops, best of 3: 13.1 ms per loop

1000 loops, best of 3: 862 µs per loop

請注意,JAX 允許我們將變換鏈接在一起。首先,我們使用 jax.grad 獲取了丟失的梯度,然後使用 jax.jit 立即進行編譯。這是使 JAX 更強大的一個因素——除了鏈接 jax.jit 和 jax.grad之外,我們還可以多次應用 jax.grad 以獲得更高階的導數等。爲了確保訓練神經網絡經過編譯後仍然有效,我們再次對它進行訓練。請注意,訓練代碼沒有任何變化。

params = initial_params()

for n in itertools.count():
    x = inputs[onp.random.choice(inputs.shape[0])]
    y = onp.bitwise_xor(*x)
    grads = loss_grad(params, x, y)
    params = [param - learning_rate * grad
              for param, grad in zip(params, grads)]
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break

image

5 jax.vmap

精明的讀者可能已經注意到,我們一直在一個例子上訓練我們的神經網絡。這是“真正的”隨機梯度下降;在實踐中,當訓練現代機器學習模型時,我們執行“小批量”梯度下降,在梯度下降的每個步驟中,我們對一小批示例中的損失梯度求平均值。JAX 提供了 jax.vmap,這是一個自動“矢量化”函數的轉換。這意味着它允許你在輸入的某個軸上並行計算函數的輸出。對我們來說,這意味着我們可以應用 jax.vmap 函數轉換並立即獲得損失函數漸變的版本,該版本適用於小批量示例。

jax.vmap 還可接受其他參數:

  • in_axes 是一個元組或整數,它告訴 JAX 函數參數應該對哪些軸並行化。元組應該與 vmap’d 函數的參數數量相同,或者只有一個參數時爲整數。示例中,我們將使用(None,0,0),指“不在第一個參數(params)上並行化,並在第二個和第三個參數(x和y)的第一個(第零個)維度上並行化”。

  • out_axes 類似於 in_axes,除了它指定了函數輸出的哪些軸並行化。我們在例子中使用0,表示在函數唯一輸出的第一個(第零個)維度上進行並行化(損失梯度)。

請注意,我們必須稍微修改一下訓練代碼——我們需要一次抓取一批數據而不是單個示例,並在應用它們來更新參數之前對批處理中的漸變求平均。

loss_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0), out_axes=0))

params = initial_params()

batch_size = 100

for n in itertools.count():
    # Generate a batch of inputs
    x = inputs[onp.random.choice(inputs.shape[0], size=batch_size)]
    y = onp.bitwise_xor(x[:, 0], x[:, 1])
    # The call to loss_grad remains the same!
    grads = loss_grad(params, x, y)
    # Note that we now need to average gradients over the batch
    params = [param - learning_rate * np.mean(grad, axis=0)
              for param, grad in zip(params, grads)]
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break

image

6 指南

這就是我們將在這個簡短的教程中介紹的內容,但這實際上涵蓋了大量的 JAX 知識。由於JAX主要是 numpy 和 Python,因此你可以利用現有知識,而不必學習基本的新框架或範例。

有關其他資源,請查看JAX GitHub:

https://github.com/google/jax 上的notebook和示例目錄。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章