提交 0ea34350 authored 作者: junpenglao's avatar junpenglao 提交者: Brandon T. Willard

Implement JAX conversion for Scan Op

上级 d222fa16
......@@ -290,18 +290,104 @@ def test_jax_basic_multiout():
compare_jax_and_py(out_fg, [np.r_[1, 2]])
@pytest.mark.skip(reason="Not fully implemented, yet.")
def test_jax_scan():
def test_jax_scan_multiple_output():
"""Test a scan implementation of a SEIR model.
theano.config.compute_test_value = "raise"
SEIR model definition:
S[t+1] = S[t] - B[t]
E[t+1] = E[t] + B[t] - C[t]
I[t+1] = I[t] + C[t] - D[t]
B[t] ~ Binom(S[t], beta)
C[t] ~ Binom(E[t], gamma)
D[t] ~ Binom(I[t], delta)
"""
def binomln(n, k):
return tt.gammaln(n + 1) - tt.gammaln(k + 1) - tt.gammaln(n - k + 1)
def binom_log_prob(n, p, value):
return binomln(n, value) + value * tt.log(p) + (n - value) * tt.log(1 - p)
# sequences
tt_C = tt.ivector("C_t")
tt_D = tt.ivector("D_t")
# outputs_info (initial conditions)
st0 = tt.lscalar("s_t0")
et0 = tt.lscalar("e_t0")
it0 = tt.lscalar("i_t0")
logp_c = tt.scalar("logp_c")
logp_d = tt.scalar("logp_d")
# non_sequences
beta = tt.scalar("beta")
gamma = tt.scalar("gamma")
delta = tt.scalar("delta")
# TODO: Use random streams when their JAX conversions are implemented.
# trng = tt.shared_randomstreams.RandomStreams(1234)
def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta):
# bt0 = trng.binomial(n=st0, p=beta)
bt0 = st0 * beta
bt0 = bt0.astype(st0.dtype)
logp_c1 = binom_log_prob(et0, gamma, ct0).astype(logp_c.dtype)
logp_d1 = binom_log_prob(it0, delta, dt0).astype(logp_d.dtype)
st1 = st0 - bt0
et1 = et0 + bt0 - ct0
it1 = it0 + ct0 - dt0
return st1, et1, it1, logp_c1, logp_d1
(st, et, it, logp_c_all, logp_d_all), _ = theano.scan(
fn=seir_one_step,
sequences=[tt_C, tt_D],
outputs_info=[st0, et0, it0, logp_c, logp_d],
non_sequences=[beta, gamma, delta],
)
st.name = "S_t"
et.name = "E_t"
it.name = "I_t"
logp_c_all.name = "C_t_logp"
logp_d_all.name = "D_t_logp"
out_fg = theano.gof.FunctionGraph(
[tt_C, tt_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta],
[st, et, it, logp_c_all, logp_d_all],
)
s0, e0, i0 = 100, 50, 25
logp_c0 = np.array(0.0).astype(tt.config.floatX)
logp_d0 = np.array(0.0).astype(tt.config.floatX)
beta_val, gamma_val, delta_val = [
np.array(val).astype(tt.config.floatX) for val in [0.277792, 0.135330, 0.108753]
]
C = np.array([3, 5, 8, 13, 21, 26, 10, 3], dtype=np.int32)
D = np.array([1, 2, 3, 7, 9, 11, 5, 1], dtype=np.int32)
test_input_vals = [
C,
D,
s0,
e0,
i0,
logp_c0,
logp_d0,
beta_val,
gamma_val,
delta_val,
]
compare_jax_and_py(out_fg, test_input_vals)
def test_jax_scan_tap_output():
a_tt = tt.scalar("a")
a_tt.tag.test_value = 3.0
def input_step_fn(y_tm1, y_tm2, a):
def input_step_fn(y_tm1, y_tm3, a):
y_tm1.name = "y_tm1"
y_tm2.name = "y_tm2"
res = (y_tm1 + y_tm2) * a
y_tm3.name = "y_tm3"
res = (y_tm1 + y_tm3) * a
res.name = "y_t"
return res
......@@ -310,9 +396,9 @@ def test_jax_scan():
outputs_info=[
{
"initial": tt.as_tensor_variable(
np.r_[-1.0, 0.0].astype(tt.config.floatX)
np.r_[-1.0, 1.3, 0.0].astype(tt.config.floatX)
),
"taps": [-1, -2],
"taps": [-1, -3],
},
],
non_sequences=[a_tt],
......@@ -322,31 +408,10 @@ def test_jax_scan():
y_scan_tt.name = "y"
y_scan_tt.owner.inputs[0].name = "y_all"
theano_scan_fn = theano.function([], y_scan_tt, givens={a_tt: 3.0})
theano_res = theano_scan_fn()
#
# The equivalent JAX `scan`:
#
import jax
import jax.numpy as jnp
def jax_inner_scan(carry, x):
(y_tm1, y_tm2), a = carry
res = (y_tm1 + y_tm2) * a
return [jnp.array([res, y_tm1]), a], res
init_carry = [np.r_[0.0, -1.0].astype(tt.config.floatX), 3.0]
tmp, jax_res = jax.lax.scan(jax_inner_scan, init_carry, None, length=10)
assert np.allclose(jax_res, theano_res)
out_fg = theano.gof.FunctionGraph([a_tt], [y_scan_tt])
test_input_vals = [np.array(10.0).astype(tt.config.floatX)]
(jax_res,) = compare_jax_and_py(out_fg, test_input_vals)
raise AssertionError()
compare_jax_and_py(out_fg, test_input_vals)
def test_jax_Subtensors():
......
......@@ -177,6 +177,7 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
def jax_func(*inputs):
func_args = [fn(*inputs) for fn in input_funcs]
# func_args = jax.tree_map(lambda fn: fn(*inputs), input_funcs)
return return_func(*func_args)
jax_funcs.append(update_wrapper(jax_func, return_func))
......@@ -420,7 +421,7 @@ def jax_funcify_Scan(op):
def scan(*outer_inputs):
scan_args = ScanArgs(
outer_inputs, [None] * op.n_outs, op.inputs, op.outputs, op.info
list(outer_inputs), [None] * op.n_outs, op.inputs, op.outputs, op.info
)
# `outer_inputs` is a list with the following composite form:
......@@ -435,9 +436,9 @@ def jax_funcify_Scan(op):
n_steps = scan_args.n_steps
seqs = scan_args.outer_in_seqs
n_non_seqs = len(scan_args.outer_in_non_seqs)
# TODO: mit_mots
mit_mot_in_slices = []
# TODO: sit_sots
mit_sot_in_slices = []
for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot):
neg_taps = [abs(t) for t in tap if t < 0]
......@@ -447,7 +448,15 @@ def jax_funcify_Scan(op):
init_slice = seq[: max_neg + max_pos]
mit_sot_in_slices.append(init_slice)
init_carry = [mit_sot_in_slices, scan_args.outer_in_non_seqs]
sit_sot_in_slices = [seq[0] for seq in scan_args.outer_in_sit_sot]
init_carry = (
mit_mot_in_slices,
mit_sot_in_slices,
sit_sot_in_slices,
scan_args.outer_in_shared,
scan_args.outer_in_non_seqs,
)
def jax_args_to_inner_scan(op, carry, x):
# `carry` contains all inner-output taps, non_seqs, and shared
......@@ -470,15 +479,22 @@ def jax_funcify_Scan(op):
# + inner_in_sit_sot
# + inner_in_shared
# + inner_in_non_seqs
inner_scan_inputs = [
inner_in_seqs,
inner_in_mit_mot,
inner_in_mit_sot,
inner_in_sit_sot,
inner_in_non_seqs,
]
inner_in_mit_sot_flatten = []
for array, index in zip(inner_in_mit_sot, scan_args.mit_sot_in_slices):
inner_in_mit_sot_flatten.extend(array[index])
inner_scan_inputs = sum(
[
inner_in_seqs,
inner_in_mit_mot,
inner_in_mit_sot_flatten,
inner_in_sit_sot,
inner_in_shared,
inner_in_non_seqs,
],
[],
)
raise NotImplementedError()
return inner_scan_inputs
def inner_scan_outs_to_jax_outs(
......@@ -486,47 +502,66 @@ def jax_funcify_Scan(op):
old_carry,
inner_scan_outs,
):
# `inner_scan_outs` is a list with the following
# composite form:
# outer_out_mit_mot
# + outer_out_mit_sot
# + outer_out_sit_sot
# + outer_out_nit_sot
# + outer_out_shared
# + cond
(
outer_out_mit_mot,
outer_out_mit_sot,
outer_out_sit_sot,
outer_out_nit_sot,
outer_out_shared,
cond,
) = inner_scan_outs
outer_out_non_seqs = old_carry[:-n_non_seqs]
inner_in_mit_mot,
inner_in_mit_sot,
inner_in_sit_sot,
inner_in_shared,
inner_in_non_seqs,
) = old_carry
def update_mit_sot(mit_sot, new_val):
return jnp.concatenate([mit_sot[1:], new_val[None, ...]], axis=0)
inner_out_mit_sot = [
update_mit_sot(mit_sot, new_val)
for mit_sot, new_val in zip(inner_in_mit_sot, inner_scan_outs)
]
# This should contain all inner-output taps, non_seqs, and shared
# terms
carry = [
outer_out_mit_mot,
outer_out_mit_sot,
outer_out_sit_sot,
outer_out_shared,
outer_out_non_seqs,
]
# This should contain all inner-outputs that produce
# outer-outputs
y = []
if not inner_in_sit_sot:
inner_out_sit_sot = []
else:
inner_out_sit_sot = inner_scan_outs
new_carry = (
inner_in_mit_mot,
inner_out_mit_sot,
inner_out_sit_sot,
inner_in_shared,
inner_in_non_seqs,
)
raise NotImplementedError()
return (carry, y)
return new_carry
def jax_inner_func(carry, x):
inner_args = jax_args_to_inner_scan(op, carry, x)
inner_scan_outs = jax_tt_inner_func(*inner_args)
new_carry, y = inner_scan_outs_to_jax_outs(op, inner_scan_outs)
return new_carry, y
inner_scan_outs = [fn(*inner_args) for fn in jax_tt_inner_func]
new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs)
return new_carry, inner_scan_outs
_, scan_out = jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)
# We need to prepend the initial values so that the JAX output will
# match the raw `Scan` `Op` output and, thus, work with a downstream
# `Subtensor` `Op` introduced by the `scan` helper function.
def append_scan_out(scan_in_part, scan_out_part):
return jnp.concatenate([scan_in_part[:-n_steps], scan_out_part], axis=0)
if scan_args.outer_in_mit_sot:
scan_out_final = [
append_scan_out(init, out)
for init, out in zip(scan_args.outer_in_mit_sot, scan_out)
]
elif scan_args.outer_in_sit_sot:
scan_out_final = [
append_scan_out(init, out)
for init, out in zip(scan_args.outer_in_sit_sot, scan_out)
]
return jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)
if len(scan_out_final) == 1:
scan_out_final = scan_out_final[0]
return scan_out_final
return scan
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论