Commit 97797975 authored by Ricardo Vieira, committed by Ricardo Vieira

Benchmark scan in JAX backend

Parent 545e58f7
@@ -4,7 +4,7 @@ import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor import function, shared
from pytensor import function, ifelse, shared
from pytensor.compile import get_mode
from pytensor.configdefaults import config
from pytensor.scan import until
@@ -12,7 +12,7 @@ from pytensor.scan.basic import scan
from pytensor.scan.op import Scan
from pytensor.tensor import random
from pytensor.tensor.math import gammaln, log
from pytensor.tensor.type import dmatrix, dvector, lscalar, matrix, scalar, vector
from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, vector
from tests.link.jax.test_basic import compare_jax_and_py
@@ -189,96 +189,6 @@ def test_scan_while():
compare_jax_and_py([], [xs], [])
def test_scan_SEIR():
"""Test a scan implementation of a SEIR model.
SEIR model definition:
S[t+1] = S[t] - B[t]
    E[t+1] = E[t] + B[t] - C[t]
    I[t+1] = I[t] + C[t] - D[t]
B[t] ~ Binom(S[t], beta)
C[t] ~ Binom(E[t], gamma)
D[t] ~ Binom(I[t], delta)
"""
def binomln(n, k):
return gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)
def binom_log_prob(n, p, value):
return binomln(n, value) + value * log(p) + (n - value) * log(1 - p)
# sequences
at_C = vector("C_t", dtype="int32", shape=(8,))
at_D = vector("D_t", dtype="int32", shape=(8,))
# outputs_info (initial conditions)
st0 = lscalar("s_t0")
et0 = lscalar("e_t0")
it0 = lscalar("i_t0")
logp_c = scalar("logp_c")
logp_d = scalar("logp_d")
# non_sequences
beta = scalar("beta")
gamma = scalar("gamma")
delta = scalar("delta")
# TODO: Use random streams when their JAX conversions are implemented.
# trng = pytensor.tensor.random.RandomStream(1234)
def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta):
# bt0 = trng.binomial(n=st0, p=beta)
bt0 = st0 * beta
bt0 = bt0.astype(st0.dtype)
logp_c1 = binom_log_prob(et0, gamma, ct0).astype(logp_c.dtype)
logp_d1 = binom_log_prob(it0, delta, dt0).astype(logp_d.dtype)
st1 = st0 - bt0
et1 = et0 + bt0 - ct0
it1 = it0 + ct0 - dt0
return st1, et1, it1, logp_c1, logp_d1
(st, et, it, logp_c_all, logp_d_all), _ = scan(
fn=seir_one_step,
sequences=[at_C, at_D],
outputs_info=[st0, et0, it0, logp_c, logp_d],
non_sequences=[beta, gamma, delta],
)
st.name = "S_t"
et.name = "E_t"
it.name = "I_t"
logp_c_all.name = "C_t_logp"
logp_d_all.name = "D_t_logp"
s0, e0, i0 = 100, 50, 25
logp_c0 = np.array(0.0, dtype=config.floatX)
logp_d0 = np.array(0.0, dtype=config.floatX)
beta_val, gamma_val, delta_val = (
np.array(val, dtype=config.floatX) for val in [0.277792, 0.135330, 0.108753]
)
C = np.array([3, 5, 8, 13, 21, 26, 10, 3], dtype=np.int32)
D = np.array([1, 2, 3, 7, 9, 11, 5, 1], dtype=np.int32)
test_input_vals = [
C,
D,
s0,
e0,
i0,
logp_c0,
logp_d0,
beta_val,
gamma_val,
delta_val,
]
compare_jax_and_py(
[at_C, at_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta],
[st, et, it, logp_c_all, logp_d_all],
test_input_vals,
jax_mode="JAX",
)
def test_scan_mitsot_with_nonseq():
a_pt = scalar("a")
@@ -420,3 +330,240 @@ def test_dynamic_sequence_length():
assert sum(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes) == 1
np.testing.assert_allclose(f([]), [])
np.testing.assert_allclose(f([1, 2, 3]), np.array([2, 3, 4]))
def SEIR_model_logp():
"""Setup a Scan implementation of a SEIR model.
SEIR model definition:
S[t+1] = S[t] - B[t]
    E[t+1] = E[t] + B[t] - C[t]
    I[t+1] = I[t] + C[t] - D[t]
B[t] ~ Binom(S[t], beta)
C[t] ~ Binom(E[t], gamma)
D[t] ~ Binom(I[t], delta)
"""
def binomln(n, k):
return gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)
def binom_log_prob(n, p, value):
return binomln(n, value) + value * log(p) + (n - value) * log(1 - p)
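    # Illustrative cross-check (not exercised by the benchmark; scipy assumed
    # available): binom_log_prob(n, p, k) is the Binomial log-pmf
    # log C(n, k) + k*log(p) + (n - k)*log(1 - p), so for example:
    #   from scipy.stats import binom
    #   np.isclose(binom.logpmf(3, 10, 0.3), binom_log_prob(10, 0.3, 3).eval())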
# sequences
C_t = vector("C_t", dtype="int32", shape=(1200,))
D_t = vector("D_t", dtype="int32", shape=(1200,))
# outputs_info (initial conditions)
st0 = scalar("s_t0")
et0 = scalar("e_t0")
it0 = scalar("i_t0")
# non_sequences
beta = scalar("beta")
gamma = scalar("gamma")
delta = scalar("delta")
def seir_one_step(ct0, dt0, st0, et0, it0, beta, gamma, delta):
# bt0 = trng.binomial(n=st0, p=beta)
bt0 = st0 * beta
bt0 = bt0.astype(st0.dtype)
logp_c1 = binom_log_prob(et0, gamma, ct0)
logp_d1 = binom_log_prob(it0, delta, dt0)
st1 = st0 - bt0
et1 = et0 + bt0 - ct0
it1 = it0 + ct0 - dt0
return st1, et1, it1, logp_c1, logp_d1
(st, et, it, logp_c_all, logp_d_all), _ = scan(
fn=seir_one_step,
sequences=[C_t, D_t],
outputs_info=[st0, et0, it0, None, None],
non_sequences=[beta, gamma, delta],
)
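    # The two None entries in outputs_info mark the logp outputs as collect-only:
    # they are gathered at every step but never fed back into the next iteration,
    # unlike the S/E/I states.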
st.name = "S_t"
et.name = "E_t"
it.name = "I_t"
logp_c_all.name = "C_t_logp"
logp_d_all.name = "D_t_logp"
st0_val, et0_val, it0_val = np.array(100.0), np.array(50.0), np.array(25.0)
beta_val, gamma_val, delta_val = (
np.array(0.277792),
np.array(0.135330),
np.array(0.108753),
)
C_t_val = np.array([3, 5, 8, 13, 21, 26, 10, 3] * 150, dtype=np.int32)
D_t_val = np.array([1, 2, 3, 7, 9, 11, 5, 1] * 150, dtype=np.int32)
assert C_t_val.shape == D_t_val.shape == C_t.type.shape == D_t.type.shape
test_input_vals = [
C_t_val,
D_t_val,
st0_val,
et0_val,
it0_val,
beta_val,
gamma_val,
delta_val,
]
loss_graph = logp_c_all.sum() + logp_d_all.sum()
return dict(
graph_inputs=[C_t, D_t, st0, et0, it0, beta, gamma, delta],
differentiable_vars=[st0, et0, it0, beta, gamma, delta],
test_input_vals=test_input_vals,
loss_graph=loss_graph,
)
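# Usage sketch (illustrative; mirrors what test_scan_benchmark does below):
#   model = SEIR_model_logp()
#   fn = function(model["graph_inputs"], model["loss_graph"], mode="JAX")
#   loss = fn(*model["test_input_vals"])  # summed Binomial log-likelihood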
def cyclical_reduction():
"""Setup a Scan implementation of the cyclical reduction algorithm.
This solves the matrix equation A @ X @ X + B @ X + C = 0 for X
Adapted from https://github.com/jessegrabowski/gEconpy/blob/da495b22ac383cb6cb5dec15f305506aebef7302/gEconpy/solvers/cycle_reduction.py#L187
"""
def stabilize(x, jitter=1e-16):
return x + jitter * pt.eye(x.shape[0])
def step(A0, A1, A2, A1_hat, norm, step_num, tol):
def cycle_step(A0, A1, A2, A1_hat, _norm, step_num):
tmp = pt.dot(
pt.vertical_stack(A0, A2),
pt.linalg.solve(
stabilize(A1),
pt.horizontal_stack(A0, A2),
assume_a="gen",
check_finite=False,
),
)
n = A0.shape[0]
idx_0 = pt.arange(n)
idx_1 = idx_0 + n
A1 = A1 - tmp[idx_0, :][:, idx_1] - tmp[idx_1, :][:, idx_0]
A0 = -tmp[idx_0, :][:, idx_0]
A2 = -tmp[idx_1, :][:, idx_1]
A1_hat = A1_hat - tmp[idx_1, :][:, idx_0]
A0_L1_norm = pt.linalg.norm(A0, ord=1)
return A0, A1, A2, A1_hat, A0_L1_norm, step_num + 1
return ifelse(
norm < tol,
(A0, A1, A2, A1_hat, norm, step_num),
cycle_step(A0, A1, A2, A1_hat, norm, step_num),
)
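    # JAX requires a statically known number of scan steps, so rather than a
    # while loop the graph always runs max_iter iterations; once norm < tol the
    # ifelse above turns each remaining iteration into a pass-through no-op.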
A = pt.matrix("A", shape=(20, 20))
B = pt.matrix("B", shape=(20, 20))
C = pt.matrix("C", shape=(20, 20))
norm = np.array(1e9, dtype="float64")
step_num = pt.zeros((), dtype="int32")
max_iter = 100
tol = 1e-7
(*_, A1_hat, norm, _n_steps), _ = scan(
step,
outputs_info=[A, B, C, B, norm, step_num],
non_sequences=[tol],
n_steps=max_iter,
)
A1_hat = A1_hat[-1]
T = -pt.linalg.solve(stabilize(A1_hat), A, assume_a="gen", check_finite=False)
rng = np.random.default_rng(sum(map(ord, "cycle_reduction")))
n = A.type.shape[0]
A_test = rng.standard_normal(size=(n, n))
C_test = rng.standard_normal(size=(n, n))
# B must be invertible, so we make it symmetric positive-definite
B_rand = rng.standard_normal(size=(n, n))
B_test = B_rand @ B_rand.T + np.eye(n) * 1e-3
return dict(
graph_inputs=[A, B, C],
differentiable_vars=[A, B, C],
test_input_vals=[A_test, B_test, C_test],
loss_graph=pt.sum(T),
)
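# Verification sketch (illustrative; uses the A, B, C, T variables defined inside
# cyclical_reduction and assumes the docstring's A @ X @ X + B @ X + C = 0 ordering,
# with convergence reached within max_iter):
#   X_val = function([A, B, C], T)(A_test, B_test, C_test)
#   assert np.allclose(A_test @ X_val @ X_val + B_test @ X_val + C_test, 0, atol=1e-6)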
@pytest.mark.parametrize("gradient_backend", ["PYTENSOR", "JAX"])
@pytest.mark.parametrize("mode", ("0forward", "1backward", "2both"))
@pytest.mark.parametrize("model", [cyclical_reduction, SEIR_model_logp])
def test_scan_benchmark(model, mode, gradient_backend, benchmark):
if gradient_backend == "PYTENSOR" and mode in ("1backward", "2both"):
pytest.skip("PYTENSOR backend does not support backward mode yet")
model_dict = model()
graph_inputs = model_dict["graph_inputs"]
differentiable_vars = model_dict["differentiable_vars"]
loss_graph = model_dict["loss_graph"]
test_input_vals = model_dict["test_input_vals"]
if gradient_backend == "PYTENSOR":
backward_loss = pt.grad(
loss_graph,
wrt=differentiable_vars,
)
match mode:
# TODO: Restore original test separately
case "0forward":
graph_outputs = [loss_graph]
case "1backward":
graph_outputs = backward_loss
case "2both":
graph_outputs = [loss_graph, *backward_loss]
case _:
raise ValueError(f"Unknown mode: {mode}")
jax_fn, _ = compare_jax_and_py(
graph_inputs,
graph_outputs,
test_input_vals,
jax_mode="JAX",
)
jax_fn.trust_input = True
else: # gradient_backend == "JAX"
import jax
loss_fn_tuple = function(graph_inputs, loss_graph, mode="JAX").vm.jit_fn
def loss_fn(*args):
return loss_fn_tuple(*args)[0]
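        # `vm.jit_fn` returns a tuple of outputs even for a single output, and
        # `jax.grad` requires a scalar return value, hence the unwrapping wrapper.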
match mode:
case "0forward":
jax_fn = jax.jit(loss_fn_tuple)
case "1backward":
jax_fn = jax.jit(
jax.grad(loss_fn, argnums=tuple(range(len(graph_inputs))[2:]))
)
case "2both":
value_and_grad_fn = jax.value_and_grad(
loss_fn, argnums=tuple(range(len(graph_inputs))[2:])
)
@jax.jit
def jax_fn(*args):
loss, grads = value_and_grad_fn(*args)
return loss, *grads
case _:
raise ValueError(f"Unknown mode: {mode}")
def block_until_ready(*inputs, jax_fn=jax_fn):
return [o.block_until_ready() for o in jax_fn(*inputs)]
block_until_ready(*test_input_vals) # Warmup
benchmark.pedantic(block_until_ready, test_input_vals, rounds=200, iterations=1)
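# These benchmarks rely on the pytest-benchmark plugin (the `benchmark` fixture and
# its `pedantic` API). Assuming this file lives at tests/link/jax/test_scan.py (as
# the imports above suggest), they can be run in isolation with e.g.:
#   pytest tests/link/jax/test_scan.py -k test_scan_benchmark --benchmark-only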