Unverified 提交 0c203e99 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: GitHub

Merge pull request #169 from junpenglao/jax_scan

Implement a JAX conversion for the Scan Op
......@@ -14,7 +14,7 @@ from theano.gof.op import get_test_value # noqa: E402
@pytest.fixture(scope="module", autouse=True)
def set_theano_flags():
with theano.change_flags(cxx="", compute_test_value="warn"):
with theano.change_flags(cxx="", compute_test_value="ignore"):
yield
......@@ -111,7 +111,7 @@ def test_jax_Alloc():
x = tt.alloc(a, 20, 10)
x_fg = theano.gof.FunctionGraph([a], [x])
compare_jax_and_py(x_fg, [np.ones(10, dtype=tt.config.floatX)])
compare_jax_and_py(x_fg, [np.ones(10, dtype=theano.config.floatX)])
def test_jax_compile_ops():
......@@ -182,8 +182,8 @@ def test_jax_basic():
out_fg = theano.gof.FunctionGraph([x, y], [out])
test_input_vals = [
np.tile(np.arange(10), (10, 1)).astype(tt.config.floatX),
np.tile(np.arange(10, 20), (10, 1)).astype(tt.config.floatX),
np.tile(np.arange(10), (10, 1)).astype(theano.config.floatX),
np.tile(np.arange(10, 20), (10, 1)).astype(theano.config.floatX),
]
(jax_res,) = compare_jax_and_py(out_fg, test_input_vals)
......@@ -201,43 +201,49 @@ def test_jax_basic():
out = tt.diagonal(x, 0)
out_fg = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py(
out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(tt.config.floatX)]
out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(theano.config.floatX)]
)
out = tt.slinalg.cholesky(x)
out_fg = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py(
out_fg, [(np.eye(10) + np.random.randn(10, 10) * 0.01).astype(tt.config.floatX)]
out_fg,
[(np.eye(10) + np.random.randn(10, 10) * 0.01).astype(theano.config.floatX)],
)
# not sure why this isn't working yet with lower=False
out = tt.slinalg.Cholesky(lower=False)(x)
out_fg = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py(
out_fg, [(np.eye(10) + np.random.randn(10, 10) * 0.01).astype(tt.config.floatX)]
out_fg,
[(np.eye(10) + np.random.randn(10, 10) * 0.01).astype(theano.config.floatX)],
)
out = tt.slinalg.solve(x, b)
out_fg = theano.gof.FunctionGraph([x, b], [out])
compare_jax_and_py(
out_fg,
[np.eye(10).astype(tt.config.floatX), np.arange(10).astype(tt.config.floatX)],
[
np.eye(10).astype(theano.config.floatX),
np.arange(10).astype(theano.config.floatX),
],
)
out = tt.nlinalg.alloc_diag(b)
out_fg = theano.gof.FunctionGraph([b], [out])
compare_jax_and_py(out_fg, [np.arange(10).astype(tt.config.floatX)])
compare_jax_and_py(out_fg, [np.arange(10).astype(theano.config.floatX)])
out = tt.nlinalg.det(x)
out_fg = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py(
out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(tt.config.floatX)]
out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(theano.config.floatX)]
)
out = tt.nlinalg.matrix_inverse(x)
out_fg = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py(
out_fg, [(np.eye(10) + np.random.randn(10, 10) * 0.01).astype(tt.config.floatX)]
out_fg,
[(np.eye(10) + np.random.randn(10, 10) * 0.01).astype(theano.config.floatX)],
)
......@@ -261,25 +267,25 @@ def test_jax_basic_multiout():
out_fg = theano.gof.FunctionGraph([x], outs)
def assert_fn(x, y):
np.testing.assert_allclose(x.astype(tt.config.floatX), y, rtol=1e-3)
np.testing.assert_allclose(x.astype(theano.config.floatX), y, rtol=1e-3)
compare_jax_and_py(out_fg, [X.astype(tt.config.floatX)], assert_fn=assert_fn)
compare_jax_and_py(out_fg, [X.astype(theano.config.floatX)], assert_fn=assert_fn)
outs = tt.nlinalg.eigh(x)
out_fg = theano.gof.FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(tt.config.floatX)], assert_fn=assert_fn)
compare_jax_and_py(out_fg, [X.astype(theano.config.floatX)], assert_fn=assert_fn)
outs = tt.nlinalg.qr(x, mode="full")
out_fg = theano.gof.FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(tt.config.floatX)], assert_fn=assert_fn)
compare_jax_and_py(out_fg, [X.astype(theano.config.floatX)], assert_fn=assert_fn)
outs = tt.nlinalg.qr(x, mode="reduced")
out_fg = theano.gof.FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(tt.config.floatX)], assert_fn=assert_fn)
compare_jax_and_py(out_fg, [X.astype(theano.config.floatX)], assert_fn=assert_fn)
outs = tt.nlinalg.svd(x)
out_fg = theano.gof.FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(tt.config.floatX)], assert_fn=assert_fn)
compare_jax_and_py(out_fg, [X.astype(theano.config.floatX)], assert_fn=assert_fn)
# Test that a single output of a multi-output `Op` can be used as input to
# another `Op`
......@@ -290,18 +296,105 @@ def test_jax_basic_multiout():
compare_jax_and_py(out_fg, [np.r_[1, 2]])
@pytest.mark.skip(reason="Not fully implemented, yet.")
def test_jax_scan():
def test_jax_scan_multiple_output():
    """Check a multi-output `Scan` built from a SEIR model.

    SEIR model definition:

        S[t+1] = S[t] - B[t]
        E[t+1] = E[t] + B[t] - C[t]
        I[t+1] = I[t] + C[t] - D[t]

        B[t] ~ Binom(S[t], beta)
        C[t] ~ Binom(E[t], gamma)
        D[t] ~ Binom(I[t], delta)

    In this test ``B[t]`` is not sampled; it is replaced by the deterministic
    value ``S[t] * beta`` cast to the dtype of ``S[t]`` (see the TODO below).
    The scan carries five recurrent outputs (``S``, ``E``, ``I`` and the two
    binomial log-probabilities), and the resulting graph is handed to
    `compare_jax_and_py` together with concrete input values.
    """

    def binom_log_prob(n, p, value):
        # Binomial log-probability; the first three terms are the log of the
        # binomial coefficient expressed through `gammaln`.
        log_coeff = (
            tt.gammaln(n + 1) - tt.gammaln(value + 1) - tt.gammaln(n - value + 1)
        )
        return log_coeff + value * tt.log(p) + (n - value) * tt.log(1 - p)

    # Sequences iterated over by the scan
    c_seq = tt.ivector("C_t")
    d_seq = tt.ivector("D_t")

    # Initial conditions (`outputs_info`)
    s_init = tt.lscalar("s_t0")
    e_init = tt.lscalar("e_t0")
    i_init = tt.lscalar("i_t0")
    logp_c_init = tt.scalar("logp_c")
    logp_d_init = tt.scalar("logp_d")

    # Non-sequences (model parameters)
    beta_tt = tt.scalar("beta")
    gamma_tt = tt.scalar("gamma")
    delta_tt = tt.scalar("delta")

    # TODO: Use random streams when their JAX conversions are implemented.
    # trng = tt.shared_randomstreams.RandomStreams(1234)

    def seir_one_step(
        c_t, d_t, s_t, e_t, i_t, logp_c_prev, logp_d_prev, beta, gamma, delta
    ):
        # b_t = trng.binomial(n=s_t, p=beta)
        b_t = (s_t * beta).astype(s_t.dtype)

        # Keep the log-probabilities in the dtype of the values they replace
        logp_c_next = binom_log_prob(e_t, gamma, c_t).astype(logp_c_prev.dtype)
        logp_d_next = binom_log_prob(i_t, delta, d_t).astype(logp_d_prev.dtype)

        s_next = s_t - b_t
        e_next = e_t + b_t - c_t
        i_next = i_t + c_t - d_t
        return s_next, e_next, i_next, logp_c_next, logp_d_next

    scan_results, _ = theano.scan(
        fn=seir_one_step,
        sequences=[c_seq, d_seq],
        outputs_info=[s_init, e_init, i_init, logp_c_init, logp_d_init],
        non_sequences=[beta_tt, gamma_tt, delta_tt],
    )
    s_all, e_all, i_all, logp_c_all, logp_d_all = scan_results

    labelled_outputs = (
        (s_all, "S_t"),
        (e_all, "E_t"),
        (i_all, "I_t"),
        (logp_c_all, "C_t_logp"),
        (logp_d_all, "D_t_logp"),
    )
    for out_var, label in labelled_outputs:
        out_var.name = label

    out_fg = theano.gof.FunctionGraph(
        [
            c_seq,
            d_seq,
            s_init,
            e_init,
            i_init,
            logp_c_init,
            logp_d_init,
            beta_tt,
            gamma_tt,
            delta_tt,
        ],
        [s_all, e_all, i_all, logp_c_all, logp_d_all],
    )

    float_dtype = theano.config.floatX

    s0 = 100
    e0 = 50
    i0 = 25
    logp_c0 = np.array(0.0, dtype=float_dtype)
    logp_d0 = np.array(0.0, dtype=float_dtype)
    beta_val = np.array(0.277792, dtype=float_dtype)
    gamma_val = np.array(0.135330, dtype=float_dtype)
    delta_val = np.array(0.108753, dtype=float_dtype)

    C = np.array([3, 5, 8, 13, 21, 26, 10, 3], dtype=np.int32)
    D = np.array([1, 2, 3, 7, 9, 11, 5, 1], dtype=np.int32)

    # The order must match the `FunctionGraph` inputs above
    test_input_vals = [
        C,
        D,
        s0,
        e0,
        i0,
        logp_c0,
        logp_d0,
        beta_val,
        gamma_val,
        delta_val,
    ]
    compare_jax_and_py(out_fg, test_input_vals)
theano.config.compute_test_value = "raise"
def test_jax_scan_tap_output():
a_tt = tt.scalar("a")
a_tt.tag.test_value = 3.0
def input_step_fn(y_tm1, y_tm2, a):
def input_step_fn(y_tm1, y_tm3, a):
y_tm1.name = "y_tm1"
y_tm2.name = "y_tm2"
res = (y_tm1 + y_tm2) * a
y_tm3.name = "y_tm3"
res = (y_tm1 + y_tm3) * a
res.name = "y_t"
return res
......@@ -310,9 +403,9 @@ def test_jax_scan():
outputs_info=[
{
"initial": tt.as_tensor_variable(
np.r_[-1.0, 0.0].astype(tt.config.floatX)
np.r_[-1.0, 1.3, 0.0].astype(theano.config.floatX)
),
"taps": [-1, -2],
"taps": [-1, -3],
},
],
non_sequences=[a_tt],
......@@ -322,31 +415,10 @@ def test_jax_scan():
y_scan_tt.name = "y"
y_scan_tt.owner.inputs[0].name = "y_all"
theano_scan_fn = theano.function([], y_scan_tt, givens={a_tt: 3.0})
theano_res = theano_scan_fn()
#
# The equivalent JAX `scan`:
#
import jax
import jax.numpy as jnp
def jax_inner_scan(carry, x):
(y_tm1, y_tm2), a = carry
res = (y_tm1 + y_tm2) * a
return [jnp.array([res, y_tm1]), a], res
init_carry = [np.r_[0.0, -1.0].astype(tt.config.floatX), 3.0]
tmp, jax_res = jax.lax.scan(jax_inner_scan, init_carry, None, length=10)
assert np.allclose(jax_res, theano_res)
out_fg = theano.gof.FunctionGraph([a_tt], [y_scan_tt])
test_input_vals = [np.array(10.0).astype(tt.config.floatX)]
(jax_res,) = compare_jax_and_py(out_fg, test_input_vals)
raise AssertionError()
test_input_vals = [np.array(10.0).astype(theano.config.floatX)]
compare_jax_and_py(out_fg, test_input_vals)
def test_jax_Subtensors():
......@@ -392,16 +464,16 @@ def test_jax_Subtensors():
def test_jax_IncSubtensor():
x_np = np.random.uniform(-1, 1, size=(3, 4, 5)).astype(tt.config.floatX)
x_tt = tt.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(tt.config.floatX)
x_np = np.random.uniform(-1, 1, size=(3, 4, 5)).astype(theano.config.floatX)
x_tt = tt.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(theano.config.floatX)
# "Set" basic indices
st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=tt.config.floatX))
st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=theano.config.floatX))
out_tt = tt.set_subtensor(x_tt[1, 2, 3], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, [])
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX))
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(theano.config.floatX))
out_tt = tt.set_subtensor(x_tt[:2, 0, 0], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, [])
......@@ -411,7 +483,7 @@ def test_jax_IncSubtensor():
compare_jax_and_py(out_fg, [])
# "Set" advanced indices
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX))
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(theano.config.floatX))
out_tt = tt.set_subtensor(x_tt[[0, 2], 0, 0], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, [])
......@@ -428,12 +500,12 @@ def test_jax_IncSubtensor():
compare_jax_and_py(out_fg, [])
# "Increment" basic indices
st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=tt.config.floatX))
st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=theano.config.floatX))
out_tt = tt.inc_subtensor(x_tt[1, 2, 3], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, [])
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX))
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(theano.config.floatX))
out_tt = tt.inc_subtensor(x_tt[:2, 0, 0], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, [])
......@@ -443,7 +515,7 @@ def test_jax_IncSubtensor():
compare_jax_and_py(out_fg, [])
# "Increment" advanced indices
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX))
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(theano.config.floatX))
out_tt = tt.inc_subtensor(x_tt[[0, 2], 0, 0], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, [])
......@@ -480,38 +552,38 @@ def test_jax_ifelse():
def test_jax_CAReduce():
a_tt = tt.vector("a")
a_tt.tag.test_value = np.r_[1, 2, 3].astype(tt.config.floatX)
a_tt.tag.test_value = np.r_[1, 2, 3].astype(theano.config.floatX)
x = tt.sum(a_tt, axis=None)
x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.r_[1, 2, 3].astype(tt.config.floatX)])
compare_jax_and_py(x_fg, [np.r_[1, 2, 3].astype(theano.config.floatX)])
a_tt = tt.matrix("a")
a_tt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)
a_tt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(theano.config.floatX)
x = tt.sum(a_tt, axis=0)
x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(theano.config.floatX)])
x = tt.sum(a_tt, axis=1)
x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(theano.config.floatX)])
a_tt = tt.matrix("a")
a_tt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)
a_tt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(theano.config.floatX)
x = tt.prod(a_tt, axis=0)
x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(theano.config.floatX)])
x = tt.all(a_tt)
x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(theano.config.floatX)])
def test_jax_MakeVector():
......@@ -550,28 +622,32 @@ def test_jax_Dimshuffle():
x = a_tt.T
x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(tt.config.floatX)])
compare_jax_and_py(
x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(theano.config.floatX)]
)
x = a_tt.dimshuffle([0, 1, "x"])
x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(tt.config.floatX)])
compare_jax_and_py(
x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(theano.config.floatX)]
)
a_tt = tt.tensor(dtype=tt.config.floatX, broadcastable=[False, True])
a_tt = tt.tensor(dtype=theano.config.floatX, broadcastable=[False, True])
x = a_tt.dimshuffle((0,))
x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(tt.config.floatX)])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(theano.config.floatX)])
a_tt = tt.tensor(dtype=tt.config.floatX, broadcastable=[False, True])
a_tt = tt.tensor(dtype=theano.config.floatX, broadcastable=[False, True])
x = tt.elemwise.DimShuffle([False, True], (0,), inplace=True)(a_tt)
x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(tt.config.floatX)])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(theano.config.floatX)])
def test_jax_variadic_Scalar():
mu = tt.vector("mu", dtype=tt.config.floatX)
mu.tag.test_value = np.r_[0.1, 1.1].astype(tt.config.floatX)
tau = tt.vector("tau", dtype=tt.config.floatX)
tau.tag.test_value = np.r_[1.0, 2.0].astype(tt.config.floatX)
mu = tt.vector("mu", dtype=theano.config.floatX)
mu.tag.test_value = np.r_[0.1, 1.1].astype(theano.config.floatX)
tau = tt.vector("tau", dtype=theano.config.floatX)
tau.tag.test_value = np.r_[1.0, 2.0].astype(theano.config.floatX)
res = -tau * mu
......@@ -589,13 +665,13 @@ def test_jax_variadic_Scalar():
def test_jax_logp():
mu = tt.vector("mu")
mu.tag.test_value = np.r_[0.0, 0.0].astype(tt.config.floatX)
mu.tag.test_value = np.r_[0.0, 0.0].astype(theano.config.floatX)
tau = tt.vector("tau")
tau.tag.test_value = np.r_[1.0, 1.0].astype(tt.config.floatX)
tau.tag.test_value = np.r_[1.0, 1.0].astype(theano.config.floatX)
sigma = tt.vector("sigma")
sigma.tag.test_value = (1.0 / get_test_value(tau)).astype(tt.config.floatX)
sigma.tag.test_value = (1.0 / get_test_value(tau)).astype(theano.config.floatX)
value = tt.vector("value")
value.tag.test_value = np.r_[0.1, -10].astype(tt.config.floatX)
value.tag.test_value = np.r_[0.1, -10].astype(theano.config.floatX)
logp = (-tau * (value - mu) ** 2 + tt.log(tau / np.pi / 2.0)) / 2.0
conditions = [sigma > 0]
......@@ -609,9 +685,9 @@ def test_jax_logp():
def test_jax_multioutput():
x = tt.vector("x")
x.tag.test_value = np.r_[1.0, 2.0].astype(tt.config.floatX)
x.tag.test_value = np.r_[1.0, 2.0].astype(theano.config.floatX)
y = tt.vector("y")
y.tag.test_value = np.r_[3.0, 4.0].astype(tt.config.floatX)
y.tag.test_value = np.r_[3.0, 4.0].astype(theano.config.floatX)
w = tt.cosh(x ** 2 + y / 3.0)
v = tt.cosh(x / 3.0 + y ** 2)
......@@ -623,7 +699,7 @@ def test_jax_multioutput():
def test_nnet():
x = tt.vector("x")
x.tag.test_value = np.r_[1.0, 2.0].astype(tt.config.floatX)
x.tag.test_value = np.r_[1.0, 2.0].astype(theano.config.floatX)
out = tt.nnet.sigmoid(x)
fgraph = theano.gof.FunctionGraph([x], [out])
......
......@@ -177,6 +177,7 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
def jax_func(*inputs):
func_args = [fn(*inputs) for fn in input_funcs]
# func_args = jax.tree_map(lambda fn: fn(*inputs), input_funcs)
return return_func(*func_args)
jax_funcs.append(update_wrapper(jax_func, return_func))
......@@ -420,7 +421,7 @@ def jax_funcify_Scan(op):
def scan(*outer_inputs):
scan_args = ScanArgs(
outer_inputs, [None] * op.n_outs, op.inputs, op.outputs, op.info
list(outer_inputs), [None] * op.n_outs, op.inputs, op.outputs, op.info
)
# `outer_inputs` is a list with the following composite form:
......@@ -435,9 +436,9 @@ def jax_funcify_Scan(op):
n_steps = scan_args.n_steps
seqs = scan_args.outer_in_seqs
n_non_seqs = len(scan_args.outer_in_non_seqs)
# TODO: mit_mots
mit_mot_in_slices = []
# TODO: sit_sots
mit_sot_in_slices = []
for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot):
neg_taps = [abs(t) for t in tap if t < 0]
......@@ -447,7 +448,15 @@ def jax_funcify_Scan(op):
init_slice = seq[: max_neg + max_pos]
mit_sot_in_slices.append(init_slice)
init_carry = [mit_sot_in_slices, scan_args.outer_in_non_seqs]
sit_sot_in_slices = [seq[0] for seq in scan_args.outer_in_sit_sot]
init_carry = (
mit_mot_in_slices,
mit_sot_in_slices,
sit_sot_in_slices,
scan_args.outer_in_shared,
scan_args.outer_in_non_seqs,
)
def jax_args_to_inner_scan(op, carry, x):
# `carry` contains all inner-output taps, non_seqs, and shared
......@@ -470,15 +479,22 @@ def jax_funcify_Scan(op):
# + inner_in_sit_sot
# + inner_in_shared
# + inner_in_non_seqs
inner_scan_inputs = [
inner_in_seqs,
inner_in_mit_mot,
inner_in_mit_sot,
inner_in_sit_sot,
inner_in_non_seqs,
]
inner_in_mit_sot_flatten = []
for array, index in zip(inner_in_mit_sot, scan_args.mit_sot_in_slices):
inner_in_mit_sot_flatten.extend(array[index])
inner_scan_inputs = sum(
[
inner_in_seqs,
inner_in_mit_mot,
inner_in_mit_sot_flatten,
inner_in_sit_sot,
inner_in_shared,
inner_in_non_seqs,
],
[],
)
raise NotImplementedError()
return inner_scan_inputs
def inner_scan_outs_to_jax_outs(
......@@ -486,47 +502,66 @@ def jax_funcify_Scan(op):
old_carry,
inner_scan_outs,
):
# `inner_scan_outs` is a list with the following
# composite form:
# outer_out_mit_mot
# + outer_out_mit_sot
# + outer_out_sit_sot
# + outer_out_nit_sot
# + outer_out_shared
# + cond
(
outer_out_mit_mot,
outer_out_mit_sot,
outer_out_sit_sot,
outer_out_nit_sot,
outer_out_shared,
cond,
) = inner_scan_outs
outer_out_non_seqs = old_carry[:-n_non_seqs]
inner_in_mit_mot,
inner_in_mit_sot,
inner_in_sit_sot,
inner_in_shared,
inner_in_non_seqs,
) = old_carry
def update_mit_sot(mit_sot, new_val):
return jnp.concatenate([mit_sot[1:], new_val[None, ...]], axis=0)
inner_out_mit_sot = [
update_mit_sot(mit_sot, new_val)
for mit_sot, new_val in zip(inner_in_mit_sot, inner_scan_outs)
]
# This should contain all inner-output taps, non_seqs, and shared
# terms
carry = [
outer_out_mit_mot,
outer_out_mit_sot,
outer_out_sit_sot,
outer_out_shared,
outer_out_non_seqs,
]
# This should contain all inner-outputs that produce
# outer-outputs
y = []
if not inner_in_sit_sot:
inner_out_sit_sot = []
else:
inner_out_sit_sot = inner_scan_outs
new_carry = (
inner_in_mit_mot,
inner_out_mit_sot,
inner_out_sit_sot,
inner_in_shared,
inner_in_non_seqs,
)
raise NotImplementedError()
return (carry, y)
return new_carry
def jax_inner_func(carry, x):
inner_args = jax_args_to_inner_scan(op, carry, x)
inner_scan_outs = jax_tt_inner_func(*inner_args)
new_carry, y = inner_scan_outs_to_jax_outs(op, inner_scan_outs)
return new_carry, y
inner_scan_outs = [fn(*inner_args) for fn in jax_tt_inner_func]
new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs)
return new_carry, inner_scan_outs
_, scan_out = jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)
# We need to prepend the initial values so that the JAX output will
# match the raw `Scan` `Op` output and, thus, work with a downstream
# `Subtensor` `Op` introduced by the `scan` helper function.
def append_scan_out(scan_in_part, scan_out_part):
return jnp.concatenate([scan_in_part[:-n_steps], scan_out_part], axis=0)
if scan_args.outer_in_mit_sot:
scan_out_final = [
append_scan_out(init, out)
for init, out in zip(scan_args.outer_in_mit_sot, scan_out)
]
elif scan_args.outer_in_sit_sot:
scan_out_final = [
append_scan_out(init, out)
for init, out in zip(scan_args.outer_in_sit_sot, scan_out)
]
return jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)
if len(scan_out_final) == 1:
scan_out_final = scan_out_final[0]
return scan_out_final
return scan
......
......@@ -1075,8 +1075,9 @@ class scan_args:
if k in info:
self.other_info[k] = info[k]
inner_inputs = property(
lambda self: (
@property
def inner_inputs(self):
return (
self.inner_in_seqs
+ sum(self.inner_in_mit_mot, [])
+ sum(self.inner_in_mit_sot, [])
......@@ -1084,10 +1085,10 @@ class scan_args:
+ self.inner_in_shared
+ self.inner_in_non_seqs
)
)
outer_inputs = property(
lambda self: (
@property
def outer_inputs(self):
return (
[self.n_steps]
+ self.outer_in_seqs
+ self.outer_in_mit_mot
......@@ -1097,10 +1098,10 @@ class scan_args:
+ self.outer_in_nit_sot
+ self.outer_in_non_seqs
)
)
inner_outputs = property(
lambda self: (
@property
def inner_outputs(self):
return (
sum(self.inner_out_mit_mot, [])
+ self.inner_out_mit_sot
+ self.inner_out_sit_sot
......@@ -1108,20 +1109,20 @@ class scan_args:
+ self.inner_out_shared
+ self.cond
)
)
outer_outputs = property(
lambda self: (
@property
def outer_outputs(self):
return (
self.outer_out_mit_mot
+ self.outer_out_mit_sot
+ self.outer_out_sit_sot
+ self.outer_out_nit_sot
+ self.outer_out_shared
)
)
info = property(
lambda self: OrderedDict(
@property
def info(self):
return OrderedDict(
n_seqs=len(self.outer_in_seqs),
n_mit_mot=len(self.outer_in_mit_mot),
n_mit_sot=len(self.outer_in_mit_sot),
......@@ -1137,7 +1138,6 @@ class scan_args:
mit_mot_out_slices=self.mit_mot_out_slices,
**self.other_info,
)
)
def __copy__(self):
res = object.__new__(type(self))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论