Unverified 提交 0c203e99 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: GitHub

Merge pull request #169 from junpenglao/jax_scan

Implement a JAX conversion for the Scan Op
...@@ -14,7 +14,7 @@ from theano.gof.op import get_test_value # noqa: E402 ...@@ -14,7 +14,7 @@ from theano.gof.op import get_test_value # noqa: E402
@pytest.fixture(scope="module", autouse=True) @pytest.fixture(scope="module", autouse=True)
def set_theano_flags(): def set_theano_flags():
with theano.change_flags(cxx="", compute_test_value="warn"): with theano.change_flags(cxx="", compute_test_value="ignore"):
yield yield
...@@ -111,7 +111,7 @@ def test_jax_Alloc(): ...@@ -111,7 +111,7 @@ def test_jax_Alloc():
x = tt.alloc(a, 20, 10) x = tt.alloc(a, 20, 10)
x_fg = theano.gof.FunctionGraph([a], [x]) x_fg = theano.gof.FunctionGraph([a], [x])
compare_jax_and_py(x_fg, [np.ones(10, dtype=tt.config.floatX)]) compare_jax_and_py(x_fg, [np.ones(10, dtype=theano.config.floatX)])
def test_jax_compile_ops(): def test_jax_compile_ops():
...@@ -182,8 +182,8 @@ def test_jax_basic(): ...@@ -182,8 +182,8 @@ def test_jax_basic():
out_fg = theano.gof.FunctionGraph([x, y], [out]) out_fg = theano.gof.FunctionGraph([x, y], [out])
test_input_vals = [ test_input_vals = [
np.tile(np.arange(10), (10, 1)).astype(tt.config.floatX), np.tile(np.arange(10), (10, 1)).astype(theano.config.floatX),
np.tile(np.arange(10, 20), (10, 1)).astype(tt.config.floatX), np.tile(np.arange(10, 20), (10, 1)).astype(theano.config.floatX),
] ]
(jax_res,) = compare_jax_and_py(out_fg, test_input_vals) (jax_res,) = compare_jax_and_py(out_fg, test_input_vals)
...@@ -201,43 +201,49 @@ def test_jax_basic(): ...@@ -201,43 +201,49 @@ def test_jax_basic():
out = tt.diagonal(x, 0) out = tt.diagonal(x, 0)
out_fg = theano.gof.FunctionGraph([x], [out]) out_fg = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py( compare_jax_and_py(
out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(tt.config.floatX)] out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(theano.config.floatX)]
) )
out = tt.slinalg.cholesky(x) out = tt.slinalg.cholesky(x)
out_fg = theano.gof.FunctionGraph([x], [out]) out_fg = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py( compare_jax_and_py(
out_fg, [(np.eye(10) + np.random.randn(10, 10) * 0.01).astype(tt.config.floatX)] out_fg,
[(np.eye(10) + np.random.randn(10, 10) * 0.01).astype(theano.config.floatX)],
) )
# not sure why this isn't working yet with lower=False # not sure why this isn't working yet with lower=False
out = tt.slinalg.Cholesky(lower=False)(x) out = tt.slinalg.Cholesky(lower=False)(x)
out_fg = theano.gof.FunctionGraph([x], [out]) out_fg = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py( compare_jax_and_py(
out_fg, [(np.eye(10) + np.random.randn(10, 10) * 0.01).astype(tt.config.floatX)] out_fg,
[(np.eye(10) + np.random.randn(10, 10) * 0.01).astype(theano.config.floatX)],
) )
out = tt.slinalg.solve(x, b) out = tt.slinalg.solve(x, b)
out_fg = theano.gof.FunctionGraph([x, b], [out]) out_fg = theano.gof.FunctionGraph([x, b], [out])
compare_jax_and_py( compare_jax_and_py(
out_fg, out_fg,
[np.eye(10).astype(tt.config.floatX), np.arange(10).astype(tt.config.floatX)], [
np.eye(10).astype(theano.config.floatX),
np.arange(10).astype(theano.config.floatX),
],
) )
out = tt.nlinalg.alloc_diag(b) out = tt.nlinalg.alloc_diag(b)
out_fg = theano.gof.FunctionGraph([b], [out]) out_fg = theano.gof.FunctionGraph([b], [out])
compare_jax_and_py(out_fg, [np.arange(10).astype(tt.config.floatX)]) compare_jax_and_py(out_fg, [np.arange(10).astype(theano.config.floatX)])
out = tt.nlinalg.det(x) out = tt.nlinalg.det(x)
out_fg = theano.gof.FunctionGraph([x], [out]) out_fg = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py( compare_jax_and_py(
out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(tt.config.floatX)] out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(theano.config.floatX)]
) )
out = tt.nlinalg.matrix_inverse(x) out = tt.nlinalg.matrix_inverse(x)
out_fg = theano.gof.FunctionGraph([x], [out]) out_fg = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py( compare_jax_and_py(
out_fg, [(np.eye(10) + np.random.randn(10, 10) * 0.01).astype(tt.config.floatX)] out_fg,
[(np.eye(10) + np.random.randn(10, 10) * 0.01).astype(theano.config.floatX)],
) )
...@@ -261,25 +267,25 @@ def test_jax_basic_multiout(): ...@@ -261,25 +267,25 @@ def test_jax_basic_multiout():
out_fg = theano.gof.FunctionGraph([x], outs) out_fg = theano.gof.FunctionGraph([x], outs)
def assert_fn(x, y): def assert_fn(x, y):
np.testing.assert_allclose(x.astype(tt.config.floatX), y, rtol=1e-3) np.testing.assert_allclose(x.astype(theano.config.floatX), y, rtol=1e-3)
compare_jax_and_py(out_fg, [X.astype(tt.config.floatX)], assert_fn=assert_fn) compare_jax_and_py(out_fg, [X.astype(theano.config.floatX)], assert_fn=assert_fn)
outs = tt.nlinalg.eigh(x) outs = tt.nlinalg.eigh(x)
out_fg = theano.gof.FunctionGraph([x], outs) out_fg = theano.gof.FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(tt.config.floatX)], assert_fn=assert_fn) compare_jax_and_py(out_fg, [X.astype(theano.config.floatX)], assert_fn=assert_fn)
outs = tt.nlinalg.qr(x, mode="full") outs = tt.nlinalg.qr(x, mode="full")
out_fg = theano.gof.FunctionGraph([x], outs) out_fg = theano.gof.FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(tt.config.floatX)], assert_fn=assert_fn) compare_jax_and_py(out_fg, [X.astype(theano.config.floatX)], assert_fn=assert_fn)
outs = tt.nlinalg.qr(x, mode="reduced") outs = tt.nlinalg.qr(x, mode="reduced")
out_fg = theano.gof.FunctionGraph([x], outs) out_fg = theano.gof.FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(tt.config.floatX)], assert_fn=assert_fn) compare_jax_and_py(out_fg, [X.astype(theano.config.floatX)], assert_fn=assert_fn)
outs = tt.nlinalg.svd(x) outs = tt.nlinalg.svd(x)
out_fg = theano.gof.FunctionGraph([x], outs) out_fg = theano.gof.FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(tt.config.floatX)], assert_fn=assert_fn) compare_jax_and_py(out_fg, [X.astype(theano.config.floatX)], assert_fn=assert_fn)
# Test that a single output of a multi-output `Op` can be used as input to # Test that a single output of a multi-output `Op` can be used as input to
# another `Op` # another `Op`
...@@ -290,18 +296,105 @@ def test_jax_basic_multiout(): ...@@ -290,18 +296,105 @@ def test_jax_basic_multiout():
compare_jax_and_py(out_fg, [np.r_[1, 2]]) compare_jax_and_py(out_fg, [np.r_[1, 2]])
@pytest.mark.skip(reason="Not fully implemented, yet.") def test_jax_scan_multiple_output():
def test_jax_scan(): """Test a scan implementation of a SEIR model.
SEIR model definition:
S[t+1] = S[t] - B[t]
E[t+1] = E[t] + B[t] - C[t]
I[t+1] = I[t] + C[t] - D[t]
B[t] ~ Binom(S[t], beta)
C[t] ~ Binom(E[t], gamma)
D[t] ~ Binom(I[t], delta)
"""
def binomln(n, k):
return tt.gammaln(n + 1) - tt.gammaln(k + 1) - tt.gammaln(n - k + 1)
def binom_log_prob(n, p, value):
return binomln(n, value) + value * tt.log(p) + (n - value) * tt.log(1 - p)
# sequences
tt_C = tt.ivector("C_t")
tt_D = tt.ivector("D_t")
# outputs_info (initial conditions)
st0 = tt.lscalar("s_t0")
et0 = tt.lscalar("e_t0")
it0 = tt.lscalar("i_t0")
logp_c = tt.scalar("logp_c")
logp_d = tt.scalar("logp_d")
# non_sequences
beta = tt.scalar("beta")
gamma = tt.scalar("gamma")
delta = tt.scalar("delta")
# TODO: Use random streams when their JAX conversions are implemented.
# trng = tt.shared_randomstreams.RandomStreams(1234)
def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta):
# bt0 = trng.binomial(n=st0, p=beta)
bt0 = st0 * beta
bt0 = bt0.astype(st0.dtype)
logp_c1 = binom_log_prob(et0, gamma, ct0).astype(logp_c.dtype)
logp_d1 = binom_log_prob(it0, delta, dt0).astype(logp_d.dtype)
st1 = st0 - bt0
et1 = et0 + bt0 - ct0
it1 = it0 + ct0 - dt0
return st1, et1, it1, logp_c1, logp_d1
(st, et, it, logp_c_all, logp_d_all), _ = theano.scan(
fn=seir_one_step,
sequences=[tt_C, tt_D],
outputs_info=[st0, et0, it0, logp_c, logp_d],
non_sequences=[beta, gamma, delta],
)
st.name = "S_t"
et.name = "E_t"
it.name = "I_t"
logp_c_all.name = "C_t_logp"
logp_d_all.name = "D_t_logp"
out_fg = theano.gof.FunctionGraph(
[tt_C, tt_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta],
[st, et, it, logp_c_all, logp_d_all],
)
s0, e0, i0 = 100, 50, 25
logp_c0 = np.array(0.0, dtype=theano.config.floatX)
logp_d0 = np.array(0.0, dtype=theano.config.floatX)
beta_val, gamma_val, delta_val = [
np.array(val, dtype=theano.config.floatX)
for val in [0.277792, 0.135330, 0.108753]
]
C = np.array([3, 5, 8, 13, 21, 26, 10, 3], dtype=np.int32)
D = np.array([1, 2, 3, 7, 9, 11, 5, 1], dtype=np.int32)
test_input_vals = [
C,
D,
s0,
e0,
i0,
logp_c0,
logp_d0,
beta_val,
gamma_val,
delta_val,
]
compare_jax_and_py(out_fg, test_input_vals)
theano.config.compute_test_value = "raise" def test_jax_scan_tap_output():
a_tt = tt.scalar("a") a_tt = tt.scalar("a")
a_tt.tag.test_value = 3.0
def input_step_fn(y_tm1, y_tm2, a): def input_step_fn(y_tm1, y_tm3, a):
y_tm1.name = "y_tm1" y_tm1.name = "y_tm1"
y_tm2.name = "y_tm2" y_tm3.name = "y_tm3"
res = (y_tm1 + y_tm2) * a res = (y_tm1 + y_tm3) * a
res.name = "y_t" res.name = "y_t"
return res return res
...@@ -310,9 +403,9 @@ def test_jax_scan(): ...@@ -310,9 +403,9 @@ def test_jax_scan():
outputs_info=[ outputs_info=[
{ {
"initial": tt.as_tensor_variable( "initial": tt.as_tensor_variable(
np.r_[-1.0, 0.0].astype(tt.config.floatX) np.r_[-1.0, 1.3, 0.0].astype(theano.config.floatX)
), ),
"taps": [-1, -2], "taps": [-1, -3],
}, },
], ],
non_sequences=[a_tt], non_sequences=[a_tt],
...@@ -322,31 +415,10 @@ def test_jax_scan(): ...@@ -322,31 +415,10 @@ def test_jax_scan():
y_scan_tt.name = "y" y_scan_tt.name = "y"
y_scan_tt.owner.inputs[0].name = "y_all" y_scan_tt.owner.inputs[0].name = "y_all"
theano_scan_fn = theano.function([], y_scan_tt, givens={a_tt: 3.0})
theano_res = theano_scan_fn()
#
# The equivalent JAX `scan`:
#
import jax
import jax.numpy as jnp
def jax_inner_scan(carry, x):
(y_tm1, y_tm2), a = carry
res = (y_tm1 + y_tm2) * a
return [jnp.array([res, y_tm1]), a], res
init_carry = [np.r_[0.0, -1.0].astype(tt.config.floatX), 3.0]
tmp, jax_res = jax.lax.scan(jax_inner_scan, init_carry, None, length=10)
assert np.allclose(jax_res, theano_res)
out_fg = theano.gof.FunctionGraph([a_tt], [y_scan_tt]) out_fg = theano.gof.FunctionGraph([a_tt], [y_scan_tt])
test_input_vals = [np.array(10.0).astype(tt.config.floatX)] test_input_vals = [np.array(10.0).astype(theano.config.floatX)]
(jax_res,) = compare_jax_and_py(out_fg, test_input_vals) compare_jax_and_py(out_fg, test_input_vals)
raise AssertionError()
def test_jax_Subtensors(): def test_jax_Subtensors():
...@@ -392,16 +464,16 @@ def test_jax_Subtensors(): ...@@ -392,16 +464,16 @@ def test_jax_Subtensors():
def test_jax_IncSubtensor(): def test_jax_IncSubtensor():
x_np = np.random.uniform(-1, 1, size=(3, 4, 5)).astype(tt.config.floatX) x_np = np.random.uniform(-1, 1, size=(3, 4, 5)).astype(theano.config.floatX)
x_tt = tt.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(tt.config.floatX) x_tt = tt.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(theano.config.floatX)
# "Set" basic indices # "Set" basic indices
st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=tt.config.floatX)) st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=theano.config.floatX))
out_tt = tt.set_subtensor(x_tt[1, 2, 3], st_tt) out_tt = tt.set_subtensor(x_tt[1, 2, 3], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX)) st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(theano.config.floatX))
out_tt = tt.set_subtensor(x_tt[:2, 0, 0], st_tt) out_tt = tt.set_subtensor(x_tt[:2, 0, 0], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
...@@ -411,7 +483,7 @@ def test_jax_IncSubtensor(): ...@@ -411,7 +483,7 @@ def test_jax_IncSubtensor():
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
# "Set" advanced indices # "Set" advanced indices
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX)) st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(theano.config.floatX))
out_tt = tt.set_subtensor(x_tt[[0, 2], 0, 0], st_tt) out_tt = tt.set_subtensor(x_tt[[0, 2], 0, 0], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
...@@ -428,12 +500,12 @@ def test_jax_IncSubtensor(): ...@@ -428,12 +500,12 @@ def test_jax_IncSubtensor():
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
# "Increment" basic indices # "Increment" basic indices
st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=tt.config.floatX)) st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=theano.config.floatX))
out_tt = tt.inc_subtensor(x_tt[1, 2, 3], st_tt) out_tt = tt.inc_subtensor(x_tt[1, 2, 3], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX)) st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(theano.config.floatX))
out_tt = tt.inc_subtensor(x_tt[:2, 0, 0], st_tt) out_tt = tt.inc_subtensor(x_tt[:2, 0, 0], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
...@@ -443,7 +515,7 @@ def test_jax_IncSubtensor(): ...@@ -443,7 +515,7 @@ def test_jax_IncSubtensor():
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
# "Increment" advanced indices # "Increment" advanced indices
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX)) st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(theano.config.floatX))
out_tt = tt.inc_subtensor(x_tt[[0, 2], 0, 0], st_tt) out_tt = tt.inc_subtensor(x_tt[[0, 2], 0, 0], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
...@@ -480,38 +552,38 @@ def test_jax_ifelse(): ...@@ -480,38 +552,38 @@ def test_jax_ifelse():
def test_jax_CAReduce(): def test_jax_CAReduce():
a_tt = tt.vector("a") a_tt = tt.vector("a")
a_tt.tag.test_value = np.r_[1, 2, 3].astype(tt.config.floatX) a_tt.tag.test_value = np.r_[1, 2, 3].astype(theano.config.floatX)
x = tt.sum(a_tt, axis=None) x = tt.sum(a_tt, axis=None)
x_fg = theano.gof.FunctionGraph([a_tt], [x]) x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.r_[1, 2, 3].astype(tt.config.floatX)]) compare_jax_and_py(x_fg, [np.r_[1, 2, 3].astype(theano.config.floatX)])
a_tt = tt.matrix("a") a_tt = tt.matrix("a")
a_tt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX) a_tt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(theano.config.floatX)
x = tt.sum(a_tt, axis=0) x = tt.sum(a_tt, axis=0)
x_fg = theano.gof.FunctionGraph([a_tt], [x]) x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)]) compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(theano.config.floatX)])
x = tt.sum(a_tt, axis=1) x = tt.sum(a_tt, axis=1)
x_fg = theano.gof.FunctionGraph([a_tt], [x]) x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)]) compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(theano.config.floatX)])
a_tt = tt.matrix("a") a_tt = tt.matrix("a")
a_tt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX) a_tt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(theano.config.floatX)
x = tt.prod(a_tt, axis=0) x = tt.prod(a_tt, axis=0)
x_fg = theano.gof.FunctionGraph([a_tt], [x]) x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)]) compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(theano.config.floatX)])
x = tt.all(a_tt) x = tt.all(a_tt)
x_fg = theano.gof.FunctionGraph([a_tt], [x]) x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)]) compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(theano.config.floatX)])
def test_jax_MakeVector(): def test_jax_MakeVector():
...@@ -550,28 +622,32 @@ def test_jax_Dimshuffle(): ...@@ -550,28 +622,32 @@ def test_jax_Dimshuffle():
x = a_tt.T x = a_tt.T
x_fg = theano.gof.FunctionGraph([a_tt], [x]) x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(tt.config.floatX)]) compare_jax_and_py(
x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(theano.config.floatX)]
)
x = a_tt.dimshuffle([0, 1, "x"]) x = a_tt.dimshuffle([0, 1, "x"])
x_fg = theano.gof.FunctionGraph([a_tt], [x]) x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(tt.config.floatX)]) compare_jax_and_py(
x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(theano.config.floatX)]
)
a_tt = tt.tensor(dtype=tt.config.floatX, broadcastable=[False, True]) a_tt = tt.tensor(dtype=theano.config.floatX, broadcastable=[False, True])
x = a_tt.dimshuffle((0,)) x = a_tt.dimshuffle((0,))
x_fg = theano.gof.FunctionGraph([a_tt], [x]) x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(tt.config.floatX)]) compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(theano.config.floatX)])
a_tt = tt.tensor(dtype=tt.config.floatX, broadcastable=[False, True]) a_tt = tt.tensor(dtype=theano.config.floatX, broadcastable=[False, True])
x = tt.elemwise.DimShuffle([False, True], (0,), inplace=True)(a_tt) x = tt.elemwise.DimShuffle([False, True], (0,), inplace=True)(a_tt)
x_fg = theano.gof.FunctionGraph([a_tt], [x]) x_fg = theano.gof.FunctionGraph([a_tt], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(tt.config.floatX)]) compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(theano.config.floatX)])
def test_jax_variadic_Scalar(): def test_jax_variadic_Scalar():
mu = tt.vector("mu", dtype=tt.config.floatX) mu = tt.vector("mu", dtype=theano.config.floatX)
mu.tag.test_value = np.r_[0.1, 1.1].astype(tt.config.floatX) mu.tag.test_value = np.r_[0.1, 1.1].astype(theano.config.floatX)
tau = tt.vector("tau", dtype=tt.config.floatX) tau = tt.vector("tau", dtype=theano.config.floatX)
tau.tag.test_value = np.r_[1.0, 2.0].astype(tt.config.floatX) tau.tag.test_value = np.r_[1.0, 2.0].astype(theano.config.floatX)
res = -tau * mu res = -tau * mu
...@@ -589,13 +665,13 @@ def test_jax_variadic_Scalar(): ...@@ -589,13 +665,13 @@ def test_jax_variadic_Scalar():
def test_jax_logp(): def test_jax_logp():
mu = tt.vector("mu") mu = tt.vector("mu")
mu.tag.test_value = np.r_[0.0, 0.0].astype(tt.config.floatX) mu.tag.test_value = np.r_[0.0, 0.0].astype(theano.config.floatX)
tau = tt.vector("tau") tau = tt.vector("tau")
tau.tag.test_value = np.r_[1.0, 1.0].astype(tt.config.floatX) tau.tag.test_value = np.r_[1.0, 1.0].astype(theano.config.floatX)
sigma = tt.vector("sigma") sigma = tt.vector("sigma")
sigma.tag.test_value = (1.0 / get_test_value(tau)).astype(tt.config.floatX) sigma.tag.test_value = (1.0 / get_test_value(tau)).astype(theano.config.floatX)
value = tt.vector("value") value = tt.vector("value")
value.tag.test_value = np.r_[0.1, -10].astype(tt.config.floatX) value.tag.test_value = np.r_[0.1, -10].astype(theano.config.floatX)
logp = (-tau * (value - mu) ** 2 + tt.log(tau / np.pi / 2.0)) / 2.0 logp = (-tau * (value - mu) ** 2 + tt.log(tau / np.pi / 2.0)) / 2.0
conditions = [sigma > 0] conditions = [sigma > 0]
...@@ -609,9 +685,9 @@ def test_jax_logp(): ...@@ -609,9 +685,9 @@ def test_jax_logp():
def test_jax_multioutput(): def test_jax_multioutput():
x = tt.vector("x") x = tt.vector("x")
x.tag.test_value = np.r_[1.0, 2.0].astype(tt.config.floatX) x.tag.test_value = np.r_[1.0, 2.0].astype(theano.config.floatX)
y = tt.vector("y") y = tt.vector("y")
y.tag.test_value = np.r_[3.0, 4.0].astype(tt.config.floatX) y.tag.test_value = np.r_[3.0, 4.0].astype(theano.config.floatX)
w = tt.cosh(x ** 2 + y / 3.0) w = tt.cosh(x ** 2 + y / 3.0)
v = tt.cosh(x / 3.0 + y ** 2) v = tt.cosh(x / 3.0 + y ** 2)
...@@ -623,7 +699,7 @@ def test_jax_multioutput(): ...@@ -623,7 +699,7 @@ def test_jax_multioutput():
def test_nnet(): def test_nnet():
x = tt.vector("x") x = tt.vector("x")
x.tag.test_value = np.r_[1.0, 2.0].astype(tt.config.floatX) x.tag.test_value = np.r_[1.0, 2.0].astype(theano.config.floatX)
out = tt.nnet.sigmoid(x) out = tt.nnet.sigmoid(x)
fgraph = theano.gof.FunctionGraph([x], [out]) fgraph = theano.gof.FunctionGraph([x], [out])
......
...@@ -177,6 +177,7 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None): ...@@ -177,6 +177,7 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
def jax_func(*inputs): def jax_func(*inputs):
func_args = [fn(*inputs) for fn in input_funcs] func_args = [fn(*inputs) for fn in input_funcs]
# func_args = jax.tree_map(lambda fn: fn(*inputs), input_funcs)
return return_func(*func_args) return return_func(*func_args)
jax_funcs.append(update_wrapper(jax_func, return_func)) jax_funcs.append(update_wrapper(jax_func, return_func))
...@@ -420,7 +421,7 @@ def jax_funcify_Scan(op): ...@@ -420,7 +421,7 @@ def jax_funcify_Scan(op):
def scan(*outer_inputs): def scan(*outer_inputs):
scan_args = ScanArgs( scan_args = ScanArgs(
outer_inputs, [None] * op.n_outs, op.inputs, op.outputs, op.info list(outer_inputs), [None] * op.n_outs, op.inputs, op.outputs, op.info
) )
# `outer_inputs` is a list with the following composite form: # `outer_inputs` is a list with the following composite form:
...@@ -435,9 +436,9 @@ def jax_funcify_Scan(op): ...@@ -435,9 +436,9 @@ def jax_funcify_Scan(op):
n_steps = scan_args.n_steps n_steps = scan_args.n_steps
seqs = scan_args.outer_in_seqs seqs = scan_args.outer_in_seqs
n_non_seqs = len(scan_args.outer_in_non_seqs) # TODO: mit_mots
mit_mot_in_slices = []
# TODO: sit_sots
mit_sot_in_slices = [] mit_sot_in_slices = []
for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot): for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot):
neg_taps = [abs(t) for t in tap if t < 0] neg_taps = [abs(t) for t in tap if t < 0]
...@@ -447,7 +448,15 @@ def jax_funcify_Scan(op): ...@@ -447,7 +448,15 @@ def jax_funcify_Scan(op):
init_slice = seq[: max_neg + max_pos] init_slice = seq[: max_neg + max_pos]
mit_sot_in_slices.append(init_slice) mit_sot_in_slices.append(init_slice)
init_carry = [mit_sot_in_slices, scan_args.outer_in_non_seqs] sit_sot_in_slices = [seq[0] for seq in scan_args.outer_in_sit_sot]
init_carry = (
mit_mot_in_slices,
mit_sot_in_slices,
sit_sot_in_slices,
scan_args.outer_in_shared,
scan_args.outer_in_non_seqs,
)
def jax_args_to_inner_scan(op, carry, x): def jax_args_to_inner_scan(op, carry, x):
# `carry` contains all inner-output taps, non_seqs, and shared # `carry` contains all inner-output taps, non_seqs, and shared
...@@ -470,15 +479,22 @@ def jax_funcify_Scan(op): ...@@ -470,15 +479,22 @@ def jax_funcify_Scan(op):
# + inner_in_sit_sot # + inner_in_sit_sot
# + inner_in_shared # + inner_in_shared
# + inner_in_non_seqs # + inner_in_non_seqs
inner_scan_inputs = [ inner_in_mit_sot_flatten = []
inner_in_seqs, for array, index in zip(inner_in_mit_sot, scan_args.mit_sot_in_slices):
inner_in_mit_mot, inner_in_mit_sot_flatten.extend(array[index])
inner_in_mit_sot,
inner_in_sit_sot, inner_scan_inputs = sum(
inner_in_non_seqs, [
] inner_in_seqs,
inner_in_mit_mot,
inner_in_mit_sot_flatten,
inner_in_sit_sot,
inner_in_shared,
inner_in_non_seqs,
],
[],
)
raise NotImplementedError()
return inner_scan_inputs return inner_scan_inputs
def inner_scan_outs_to_jax_outs( def inner_scan_outs_to_jax_outs(
...@@ -486,47 +502,66 @@ def jax_funcify_Scan(op): ...@@ -486,47 +502,66 @@ def jax_funcify_Scan(op):
old_carry, old_carry,
inner_scan_outs, inner_scan_outs,
): ):
# `inner_scan_outs` is a list with the following
# composite form:
# outer_out_mit_mot
# + outer_out_mit_sot
# + outer_out_sit_sot
# + outer_out_nit_sot
# + outer_out_shared
# + cond
( (
outer_out_mit_mot, inner_in_mit_mot,
outer_out_mit_sot, inner_in_mit_sot,
outer_out_sit_sot, inner_in_sit_sot,
outer_out_nit_sot, inner_in_shared,
outer_out_shared, inner_in_non_seqs,
cond, ) = old_carry
) = inner_scan_outs
outer_out_non_seqs = old_carry[:-n_non_seqs] def update_mit_sot(mit_sot, new_val):
return jnp.concatenate([mit_sot[1:], new_val[None, ...]], axis=0)
inner_out_mit_sot = [
update_mit_sot(mit_sot, new_val)
for mit_sot, new_val in zip(inner_in_mit_sot, inner_scan_outs)
]
# This should contain all inner-output taps, non_seqs, and shared # This should contain all inner-output taps, non_seqs, and shared
# terms # terms
carry = [ if not inner_in_sit_sot:
outer_out_mit_mot, inner_out_sit_sot = []
outer_out_mit_sot, else:
outer_out_sit_sot, inner_out_sit_sot = inner_scan_outs
outer_out_shared, new_carry = (
outer_out_non_seqs, inner_in_mit_mot,
] inner_out_mit_sot,
# This should contain all inner-outputs that produce inner_out_sit_sot,
# outer-outputs inner_in_shared,
y = [] inner_in_non_seqs,
)
raise NotImplementedError() return new_carry
return (carry, y)
def jax_inner_func(carry, x): def jax_inner_func(carry, x):
inner_args = jax_args_to_inner_scan(op, carry, x) inner_args = jax_args_to_inner_scan(op, carry, x)
inner_scan_outs = jax_tt_inner_func(*inner_args) inner_scan_outs = [fn(*inner_args) for fn in jax_tt_inner_func]
new_carry, y = inner_scan_outs_to_jax_outs(op, inner_scan_outs) new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs)
return new_carry, y return new_carry, inner_scan_outs
_, scan_out = jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)
# We need to prepend the initial values so that the JAX output will
# match the raw `Scan` `Op` output and, thus, work with a downstream
# `Subtensor` `Op` introduced by the `scan` helper function.
def append_scan_out(scan_in_part, scan_out_part):
return jnp.concatenate([scan_in_part[:-n_steps], scan_out_part], axis=0)
if scan_args.outer_in_mit_sot:
scan_out_final = [
append_scan_out(init, out)
for init, out in zip(scan_args.outer_in_mit_sot, scan_out)
]
elif scan_args.outer_in_sit_sot:
scan_out_final = [
append_scan_out(init, out)
for init, out in zip(scan_args.outer_in_sit_sot, scan_out)
]
return jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps) if len(scan_out_final) == 1:
scan_out_final = scan_out_final[0]
return scan_out_final
return scan return scan
......
...@@ -1075,8 +1075,9 @@ class scan_args: ...@@ -1075,8 +1075,9 @@ class scan_args:
if k in info: if k in info:
self.other_info[k] = info[k] self.other_info[k] = info[k]
inner_inputs = property( @property
lambda self: ( def inner_inputs(self):
return (
self.inner_in_seqs self.inner_in_seqs
+ sum(self.inner_in_mit_mot, []) + sum(self.inner_in_mit_mot, [])
+ sum(self.inner_in_mit_sot, []) + sum(self.inner_in_mit_sot, [])
...@@ -1084,10 +1085,10 @@ class scan_args: ...@@ -1084,10 +1085,10 @@ class scan_args:
+ self.inner_in_shared + self.inner_in_shared
+ self.inner_in_non_seqs + self.inner_in_non_seqs
) )
)
outer_inputs = property( @property
lambda self: ( def outer_inputs(self):
return (
[self.n_steps] [self.n_steps]
+ self.outer_in_seqs + self.outer_in_seqs
+ self.outer_in_mit_mot + self.outer_in_mit_mot
...@@ -1097,10 +1098,10 @@ class scan_args: ...@@ -1097,10 +1098,10 @@ class scan_args:
+ self.outer_in_nit_sot + self.outer_in_nit_sot
+ self.outer_in_non_seqs + self.outer_in_non_seqs
) )
)
inner_outputs = property( @property
lambda self: ( def inner_outputs(self):
return (
sum(self.inner_out_mit_mot, []) sum(self.inner_out_mit_mot, [])
+ self.inner_out_mit_sot + self.inner_out_mit_sot
+ self.inner_out_sit_sot + self.inner_out_sit_sot
...@@ -1108,20 +1109,20 @@ class scan_args: ...@@ -1108,20 +1109,20 @@ class scan_args:
+ self.inner_out_shared + self.inner_out_shared
+ self.cond + self.cond
) )
)
outer_outputs = property( @property
lambda self: ( def outer_outputs(self):
return (
self.outer_out_mit_mot self.outer_out_mit_mot
+ self.outer_out_mit_sot + self.outer_out_mit_sot
+ self.outer_out_sit_sot + self.outer_out_sit_sot
+ self.outer_out_nit_sot + self.outer_out_nit_sot
+ self.outer_out_shared + self.outer_out_shared
) )
)
info = property( @property
lambda self: OrderedDict( def info(self):
return OrderedDict(
n_seqs=len(self.outer_in_seqs), n_seqs=len(self.outer_in_seqs),
n_mit_mot=len(self.outer_in_mit_mot), n_mit_mot=len(self.outer_in_mit_mot),
n_mit_sot=len(self.outer_in_mit_sot), n_mit_sot=len(self.outer_in_mit_sot),
...@@ -1137,7 +1138,6 @@ class scan_args: ...@@ -1137,7 +1138,6 @@ class scan_args:
mit_mot_out_slices=self.mit_mot_out_slices, mit_mot_out_slices=self.mit_mot_out_slices,
**self.other_info, **self.other_info,
) )
)
def __copy__(self): def __copy__(self):
res = object.__new__(type(self)) res = object.__new__(type(self))
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论