Unverified commit 90ef59eb authored by Thomas Wiecki, committed by GitHub

Add jaxification for linear algebra operations (#59)

Parent: 4c72bf9e
...@@ -7,6 +7,8 @@ import theano.tensor as tt ...@@ -7,6 +7,8 @@ import theano.tensor as tt
jax = pytest.importorskip("jax") jax = pytest.importorskip("jax")
from functools import partial # noqa: E402
from theano.gof.op import get_test_value # noqa: E402 from theano.gof.op import get_test_value # noqa: E402
...@@ -16,10 +18,10 @@ def set_theano_flags(): ...@@ -16,10 +18,10 @@ def set_theano_flags():
yield yield
def compare_jax_and_py(fgraph, inputs, cmp_fn=np.allclose): def compare_jax_and_py(
# jax_mode = theano.compile.Mode(linker="jax") fgraph, inputs, assert_fn=partial(np.testing.assert_allclose, rtol=1e-4)
jax_mode = "JAX" ):
theano_jax_fn = theano.function(fgraph.inputs, fgraph.outputs, mode=jax_mode) theano_jax_fn = theano.function(fgraph.inputs, fgraph.outputs, mode="JAX")
jax_res = theano_jax_fn(*inputs) jax_res = theano_jax_fn(*inputs)
if isinstance(jax_res, list): if isinstance(jax_res, list):
...@@ -31,7 +33,11 @@ def compare_jax_and_py(fgraph, inputs, cmp_fn=np.allclose): ...@@ -31,7 +33,11 @@ def compare_jax_and_py(fgraph, inputs, cmp_fn=np.allclose):
theano_py_fn = theano.function(fgraph.inputs, fgraph.outputs, mode=py_mode) theano_py_fn = theano.function(fgraph.inputs, fgraph.outputs, mode=py_mode)
py_res = theano_py_fn(*inputs) py_res = theano_py_fn(*inputs)
assert cmp_fn(jax_res, py_res) if len(fgraph.outputs) > 1:
for j, p in zip(jax_res, py_res):
assert_fn(j, p)
else:
assert_fn(jax_res, py_res)
return jax_res return jax_res
...@@ -57,49 +63,49 @@ def test_jax_Alloc(): ...@@ -57,49 +63,49 @@ def test_jax_Alloc():
(y,) = y (y,) = y
return x.shape == y.shape and x.dtype == y.dtype return x.shape == y.shape and x.dtype == y.dtype
(jax_res,) = compare_jax_and_py(x_fg, [], cmp_fn=compare_shape_dtype) compare_jax_and_py(x_fg, [], assert_fn=compare_shape_dtype)
a = tt.scalar("a") a = tt.scalar("a")
x = tt.alloc(a, 20) x = tt.alloc(a, 20)
x_fg = theano.gof.FunctionGraph([a], [x]) x_fg = theano.gof.FunctionGraph([a], [x])
(jax_res,) = compare_jax_and_py(x_fg, [10.0]) compare_jax_and_py(x_fg, [10.0])
a = tt.vector("a") a = tt.vector("a")
x = tt.alloc(a, 20, 10) x = tt.alloc(a, 20, 10)
x_fg = theano.gof.FunctionGraph([a], [x]) x_fg = theano.gof.FunctionGraph([a], [x])
(jax_res,) = compare_jax_and_py(x_fg, [np.ones(10, dtype=tt.config.floatX)]) compare_jax_and_py(x_fg, [np.ones(10, dtype=tt.config.floatX)])
def test_jax_compile_ops(): def test_jax_compile_ops():
x = theano.compile.ops.DeepCopyOp()(tt.as_tensor_variable(1.1)) x = theano.compile.ops.DeepCopyOp()(tt.as_tensor_variable(1.1))
x_fg = theano.gof.FunctionGraph([], [x]) x_fg = theano.gof.FunctionGraph([], [x])
(jax_res,) = compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [])
x_np = np.zeros((20, 3)) x_np = np.zeros((20, 3))
x = theano.compile.ops.Shape()(tt.as_tensor_variable(x_np)) x = theano.compile.ops.Shape()(tt.as_tensor_variable(x_np))
x_fg = theano.gof.FunctionGraph([], [x]) x_fg = theano.gof.FunctionGraph([], [x])
(jax_res,) = compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [])
x = theano.compile.ops.Shape_i(1)(tt.as_tensor_variable(x_np)) x = theano.compile.ops.Shape_i(1)(tt.as_tensor_variable(x_np))
x_fg = theano.gof.FunctionGraph([], [x]) x_fg = theano.gof.FunctionGraph([], [x])
(jax_res,) = compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [])
x = theano.compile.ops.SpecifyShape()(tt.as_tensor_variable(x_np), (20, 3)) x = theano.compile.ops.SpecifyShape()(tt.as_tensor_variable(x_np), (20, 3))
x_fg = theano.gof.FunctionGraph([], [x]) x_fg = theano.gof.FunctionGraph([], [x])
(jax_res,) = compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [])
with theano.change_flags(compute_test_value="off"): with theano.change_flags(compute_test_value="off"):
x = theano.compile.ops.SpecifyShape()(tt.as_tensor_variable(x_np), (2, 3)) x = theano.compile.ops.SpecifyShape()(tt.as_tensor_variable(x_np), (2, 3))
x_fg = theano.gof.FunctionGraph([], [x]) x_fg = theano.gof.FunctionGraph([], [x])
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
(jax_res,) = compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [])
x_np = np.zeros((20, 1, 1)) x_np = np.zeros((20, 1, 1))
x = theano.compile.ops.Rebroadcast((0, False), (1, True), (2, False))( x = theano.compile.ops.Rebroadcast((0, False), (1, True), (2, False))(
...@@ -107,7 +113,7 @@ def test_jax_compile_ops(): ...@@ -107,7 +113,7 @@ def test_jax_compile_ops():
) )
x_fg = theano.gof.FunctionGraph([], [x]) x_fg = theano.gof.FunctionGraph([], [x])
(jax_res,) = compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [])
with theano.change_flags(compute_test_value="off"): with theano.change_flags(compute_test_value="off"):
x = theano.compile.ops.Rebroadcast((0, True), (1, False), (2, False))( x = theano.compile.ops.Rebroadcast((0, True), (1, False), (2, False))(
...@@ -116,17 +122,18 @@ def test_jax_compile_ops(): ...@@ -116,17 +122,18 @@ def test_jax_compile_ops():
x_fg = theano.gof.FunctionGraph([], [x]) x_fg = theano.gof.FunctionGraph([], [x])
with pytest.raises(ValueError): with pytest.raises(ValueError):
(jax_res,) = compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [])
x = theano.compile.ops.ViewOp()(tt.as_tensor_variable(x_np)) x = theano.compile.ops.ViewOp()(tt.as_tensor_variable(x_np))
x_fg = theano.gof.FunctionGraph([], [x]) x_fg = theano.gof.FunctionGraph([], [x])
(jax_res,) = compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [])
def test_jax_basic(): def test_jax_basic():
x = tt.matrix("x") x = tt.matrix("x")
y = tt.matrix("y") y = tt.matrix("y")
b = tt.vector("b")
# `ScalarOp` # `ScalarOp`
z = tt.cosh(x ** 2 + y / 3.0) z = tt.cosh(x ** 2 + y / 3.0)
...@@ -153,7 +160,82 @@ def test_jax_basic(): ...@@ -153,7 +160,82 @@ def test_jax_basic():
out = tt.clip(x, y, 5) out = tt.clip(x, y, 5)
out_fg = theano.gof.FunctionGraph([x, y], [out]) out_fg = theano.gof.FunctionGraph([x, y], [out])
(jax_res,) = compare_jax_and_py(out_fg, test_input_vals) compare_jax_and_py(out_fg, test_input_vals)
out = tt.diagonal(x, 0)
out_fg = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py(
out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(tt.config.floatX)]
)
out = tt.slinalg.cholesky(x)
out_fg = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py(
out_fg, [(np.eye(10) + np.random.randn(10, 10) * 0.01).astype(tt.config.floatX)]
)
# not sure why this isn't working yet with lower=False
out = tt.slinalg.Cholesky(lower=False)(x)
out_fg = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py(
out_fg, [(np.eye(10) + np.random.randn(10, 10) * 0.01).astype(tt.config.floatX)]
)
out = tt.slinalg.solve(x, b)
out_fg = theano.gof.FunctionGraph([x, b], [out])
compare_jax_and_py(
out_fg,
[np.eye(10).astype(tt.config.floatX), np.arange(10).astype(tt.config.floatX)],
)
out = tt.nlinalg.alloc_diag(b)
out_fg = theano.gof.FunctionGraph([b], [out])
compare_jax_and_py(out_fg, [np.arange(10).astype(tt.config.floatX)])
out = tt.nlinalg.det(x)
out_fg = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py(
out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(tt.config.floatX)]
)
out = tt.nlinalg.matrix_inverse(x)
out_fg = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py(
out_fg, [(np.eye(10) + np.random.randn(10, 10) * 0.01).astype(tt.config.floatX)]
)
def test_jax_basic_multiout():
    """Check JAX vs. Python-mode evaluation of multi-output linalg `Op`s."""
    np.random.seed(213234)
    base = np.random.normal(size=(3, 3))
    # A symmetric product so eigh/cholesky-style decompositions are well-posed.
    X = base.dot(base.T).astype(tt.config.floatX)

    x = tt.matrix("x")

    def assert_fn(a, b):
        # Cast the JAX-side result to floatX so the dtypes match the
        # Python-mode output before the tolerance comparison.
        np.testing.assert_allclose(a.astype(tt.config.floatX), b, rtol=1e-3)

    # Each entry is a tuple/list of output variables from one multi-output Op.
    multi_outs = [
        tt.nlinalg.eig(x),
        tt.nlinalg.eigh(x),
        tt.nlinalg.qr(x, mode="full"),
        tt.nlinalg.qr(x, mode="reduced"),
        tt.nlinalg.svd(x),
    ]
    for outs in multi_outs:
        fgraph = theano.gof.FunctionGraph([x], outs)
        compare_jax_and_py(fgraph, [X], assert_fn=assert_fn)
@pytest.mark.skip(reason="Not fully implemented, yet.") @pytest.mark.skip(reason="Not fully implemented, yet.")
...@@ -221,40 +303,40 @@ def test_jax_Subtensors(): ...@@ -221,40 +303,40 @@ def test_jax_Subtensors():
out_tt = x_tt[1, 2, 0] out_tt = x_tt[1, 2, 0]
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
(jax_res,) = compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
out_tt = x_tt[1:2, 1, :] out_tt = x_tt[1:2, 1, :]
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
(jax_res,) = compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
# Boolean indices # Boolean indices
out_tt = x_tt[x_tt < 0] out_tt = x_tt[x_tt < 0]
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
(jax_res,) = compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
# Advanced indexing # Advanced indexing
out_tt = x_tt[[1, 2]] out_tt = x_tt[[1, 2]]
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
(jax_res,) = compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
out_tt = x_tt[[1, 2], [2, 3]] out_tt = x_tt[[1, 2], [2, 3]]
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
(jax_res,) = compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
# Advanced and basic indexing # Advanced and basic indexing
out_tt = x_tt[[1, 2], :] out_tt = x_tt[[1, 2], :]
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
(jax_res,) = compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
out_tt = x_tt[[1, 2], :, [3, 4]] out_tt = x_tt[[1, 2], :, [3, 4]]
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
(jax_res,) = compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
def test_jax_IncSubtensor(): def test_jax_IncSubtensor():
...@@ -265,65 +347,65 @@ def test_jax_IncSubtensor(): ...@@ -265,65 +347,65 @@ def test_jax_IncSubtensor():
st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=tt.config.floatX)) st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=tt.config.floatX))
out_tt = tt.set_subtensor(x_tt[1, 2, 3], st_tt) out_tt = tt.set_subtensor(x_tt[1, 2, 3], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
(jax_res,) = compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX)) st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX))
out_tt = tt.set_subtensor(x_tt[:2, 0, 0], st_tt) out_tt = tt.set_subtensor(x_tt[:2, 0, 0], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
(jax_res,) = compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
out_tt = tt.set_subtensor(x_tt[0, 1:3, 0], st_tt) out_tt = tt.set_subtensor(x_tt[0, 1:3, 0], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
(jax_res,) = compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
# "Set" advanced indices # "Set" advanced indices
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX)) st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX))
out_tt = tt.set_subtensor(x_tt[[0, 2], 0, 0], st_tt) out_tt = tt.set_subtensor(x_tt[[0, 2], 0, 0], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
(jax_res,) = compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
st_tt = tt.as_tensor_variable(x_np[[0, 2], 0, :3]) st_tt = tt.as_tensor_variable(x_np[[0, 2], 0, :3])
out_tt = tt.set_subtensor(x_tt[[0, 2], 0, :3], st_tt) out_tt = tt.set_subtensor(x_tt[[0, 2], 0, :3], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
(jax_res,) = compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
# "Set" boolean indices # "Set" boolean indices
mask_tt = tt.as_tensor_variable(x_np) > 0 mask_tt = tt.as_tensor_variable(x_np) > 0
out_tt = tt.set_subtensor(x_tt[mask_tt], 0.0) out_tt = tt.set_subtensor(x_tt[mask_tt], 0.0)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
(jax_res,) = compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
# "Increment" basic indices # "Increment" basic indices
st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=tt.config.floatX)) st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=tt.config.floatX))
out_tt = tt.inc_subtensor(x_tt[1, 2, 3], st_tt) out_tt = tt.inc_subtensor(x_tt[1, 2, 3], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
(jax_res,) = compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX)) st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX))
out_tt = tt.inc_subtensor(x_tt[:2, 0, 0], st_tt) out_tt = tt.inc_subtensor(x_tt[:2, 0, 0], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
(jax_res,) = compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
out_tt = tt.set_subtensor(x_tt[0, 1:3, 0], st_tt) out_tt = tt.set_subtensor(x_tt[0, 1:3, 0], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
(jax_res,) = compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
# "Increment" advanced indices # "Increment" advanced indices
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX)) st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(tt.config.floatX))
out_tt = tt.inc_subtensor(x_tt[[0, 2], 0, 0], st_tt) out_tt = tt.inc_subtensor(x_tt[[0, 2], 0, 0], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
(jax_res,) = compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
st_tt = tt.as_tensor_variable(x_np[[0, 2], 0, :3]) st_tt = tt.as_tensor_variable(x_np[[0, 2], 0, :3])
out_tt = tt.inc_subtensor(x_tt[[0, 2], 0, :3], st_tt) out_tt = tt.inc_subtensor(x_tt[[0, 2], 0, :3], st_tt)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
(jax_res,) = compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
# "Increment" boolean indices # "Increment" boolean indices
mask_tt = tt.as_tensor_variable(x_np) > 0 mask_tt = tt.as_tensor_variable(x_np) > 0
out_tt = tt.set_subtensor(x_tt[mask_tt], 1.0) out_tt = tt.set_subtensor(x_tt[mask_tt], 1.0)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
(jax_res,) = compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
def test_jax_ifelse(): def test_jax_ifelse():
...@@ -334,12 +416,12 @@ def test_jax_ifelse(): ...@@ -334,12 +416,12 @@ def test_jax_ifelse():
x = theano.ifelse.ifelse(np.array(True), true_vals, false_vals) x = theano.ifelse.ifelse(np.array(True), true_vals, false_vals)
x_fg = theano.gof.FunctionGraph([], [x]) x_fg = theano.gof.FunctionGraph([], [x])
(jax_res,) = compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [])
x = theano.ifelse.ifelse(np.array(False), true_vals, false_vals) x = theano.ifelse.ifelse(np.array(False), true_vals, false_vals)
x_fg = theano.gof.FunctionGraph([], [x]) x_fg = theano.gof.FunctionGraph([], [x])
(jax_res,) = compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [])
def test_jax_CAReduce(): def test_jax_CAReduce():
...@@ -349,7 +431,7 @@ def test_jax_CAReduce(): ...@@ -349,7 +431,7 @@ def test_jax_CAReduce():
x = tt.sum(a_tt, axis=None) x = tt.sum(a_tt, axis=None)
x_fg = theano.gof.FunctionGraph([a_tt], [x]) x_fg = theano.gof.FunctionGraph([a_tt], [x])
_ = compare_jax_and_py(x_fg, [np.r_[1, 2, 3].astype(tt.config.floatX)]) compare_jax_and_py(x_fg, [np.r_[1, 2, 3].astype(tt.config.floatX)])
a_tt = tt.matrix("a") a_tt = tt.matrix("a")
a_tt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX) a_tt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)
...@@ -357,12 +439,12 @@ def test_jax_CAReduce(): ...@@ -357,12 +439,12 @@ def test_jax_CAReduce():
x = tt.sum(a_tt, axis=0) x = tt.sum(a_tt, axis=0)
x_fg = theano.gof.FunctionGraph([a_tt], [x]) x_fg = theano.gof.FunctionGraph([a_tt], [x])
_ = compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)]) compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)])
x = tt.sum(a_tt, axis=1) x = tt.sum(a_tt, axis=1)
x_fg = theano.gof.FunctionGraph([a_tt], [x]) x_fg = theano.gof.FunctionGraph([a_tt], [x])
_ = compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)]) compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)])
a_tt = tt.matrix("a") a_tt = tt.matrix("a")
a_tt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX) a_tt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)
...@@ -370,19 +452,19 @@ def test_jax_CAReduce(): ...@@ -370,19 +452,19 @@ def test_jax_CAReduce():
x = tt.prod(a_tt, axis=0) x = tt.prod(a_tt, axis=0)
x_fg = theano.gof.FunctionGraph([a_tt], [x]) x_fg = theano.gof.FunctionGraph([a_tt], [x])
_ = compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)]) compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)])
x = tt.all(a_tt) x = tt.all(a_tt)
x_fg = theano.gof.FunctionGraph([a_tt], [x]) x_fg = theano.gof.FunctionGraph([a_tt], [x])
_ = compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)]) compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(tt.config.floatX)])
def test_jax_MakeVector(): def test_jax_MakeVector():
x = tt.opt.make_vector(1, 2, 3) x = tt.opt.make_vector(1, 2, 3)
x_fg = theano.gof.FunctionGraph([], [x]) x_fg = theano.gof.FunctionGraph([], [x])
_ = compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [])
def test_jax_Reshape(): def test_jax_Reshape():
...@@ -390,9 +472,7 @@ def test_jax_Reshape(): ...@@ -390,9 +472,7 @@ def test_jax_Reshape():
x = tt.basic.reshape(a_tt, (2, 2)) x = tt.basic.reshape(a_tt, (2, 2))
x_fg = theano.gof.FunctionGraph([a_tt], [x]) x_fg = theano.gof.FunctionGraph([a_tt], [x])
_ = compare_jax_and_py( compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(theano.config.floatX)])
x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(theano.config.floatX)]
)
def test_jax_Reshape_omnistaging(): def test_jax_Reshape_omnistaging():
...@@ -402,7 +482,7 @@ def test_jax_Reshape_omnistaging(): ...@@ -402,7 +482,7 @@ def test_jax_Reshape_omnistaging():
x = tt.basic.reshape(a_tt, (a_tt.shape[0] // 2, a_tt.shape[0] // 3)) x = tt.basic.reshape(a_tt, (a_tt.shape[0] // 2, a_tt.shape[0] // 3))
x_fg = theano.gof.FunctionGraph([a_tt], [x]) x_fg = theano.gof.FunctionGraph([a_tt], [x])
_ = compare_jax_and_py(x_fg, [np.empty((6,)).astype(theano.config.floatX)]) compare_jax_and_py(x_fg, [np.empty((6,)).astype(theano.config.floatX)])
def test_jax_Dimshuffle(): def test_jax_Dimshuffle():
...@@ -410,25 +490,21 @@ def test_jax_Dimshuffle(): ...@@ -410,25 +490,21 @@ def test_jax_Dimshuffle():
x = a_tt.T x = a_tt.T
x_fg = theano.gof.FunctionGraph([a_tt], [x]) x_fg = theano.gof.FunctionGraph([a_tt], [x])
_ = compare_jax_and_py( compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(tt.config.floatX)])
x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(tt.config.floatX)]
)
x = a_tt.dimshuffle([0, 1, "x"]) x = a_tt.dimshuffle([0, 1, "x"])
x_fg = theano.gof.FunctionGraph([a_tt], [x]) x_fg = theano.gof.FunctionGraph([a_tt], [x])
_ = compare_jax_and_py( compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(tt.config.floatX)])
x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(tt.config.floatX)]
)
a_tt = tt.tensor(dtype=tt.config.floatX, broadcastable=[False, True]) a_tt = tt.tensor(dtype=tt.config.floatX, broadcastable=[False, True])
x = a_tt.dimshuffle((0,)) x = a_tt.dimshuffle((0,))
x_fg = theano.gof.FunctionGraph([a_tt], [x]) x_fg = theano.gof.FunctionGraph([a_tt], [x])
_ = compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(tt.config.floatX)]) compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(tt.config.floatX)])
a_tt = tt.tensor(dtype=tt.config.floatX, broadcastable=[False, True]) a_tt = tt.tensor(dtype=tt.config.floatX, broadcastable=[False, True])
x = tt.elemwise.DimShuffle([False, True], (0,), inplace=True)(a_tt) x = tt.elemwise.DimShuffle([False, True], (0,), inplace=True)(a_tt)
x_fg = theano.gof.FunctionGraph([a_tt], [x]) x_fg = theano.gof.FunctionGraph([a_tt], [x])
_ = compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(tt.config.floatX)]) compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(tt.config.floatX)])
def test_jax_variadic_Scalar(): def test_jax_variadic_Scalar():
...@@ -441,13 +517,13 @@ def test_jax_variadic_Scalar(): ...@@ -441,13 +517,13 @@ def test_jax_variadic_Scalar():
fgraph = theano.gof.FunctionGraph([mu, tau], [res]) fgraph = theano.gof.FunctionGraph([mu, tau], [res])
_ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
res = -tau * (tau - mu) ** 2 res = -tau * (tau - mu) ** 2
fgraph = theano.gof.FunctionGraph([mu, tau], [res]) fgraph = theano.gof.FunctionGraph([mu, tau], [res])
_ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_jax_logp(): def test_jax_logp():
...@@ -468,7 +544,7 @@ def test_jax_logp(): ...@@ -468,7 +544,7 @@ def test_jax_logp():
fgraph = theano.gof.FunctionGraph([mu, tau, sigma, value], [normal_logp]) fgraph = theano.gof.FunctionGraph([mu, tau, sigma, value], [normal_logp])
_ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_jax_multioutput(): def test_jax_multioutput():
...@@ -482,7 +558,7 @@ def test_jax_multioutput(): ...@@ -482,7 +558,7 @@ def test_jax_multioutput():
fgraph = theano.gof.FunctionGraph([x, y], [w, v]) fgraph = theano.gof.FunctionGraph([x, y], [w, v])
_ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_nnet(): def test_nnet():
...@@ -491,15 +567,15 @@ def test_nnet(): ...@@ -491,15 +567,15 @@ def test_nnet():
out = tt.nnet.sigmoid(x) out = tt.nnet.sigmoid(x)
fgraph = theano.gof.FunctionGraph([x], [out]) fgraph = theano.gof.FunctionGraph([x], [out])
_ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = tt.nnet.ultra_fast_sigmoid(x) out = tt.nnet.ultra_fast_sigmoid(x)
fgraph = theano.gof.FunctionGraph([x], [out]) fgraph = theano.gof.FunctionGraph([x], [out])
_ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = tt.nnet.softplus(x) out = tt.nnet.softplus(x)
fgraph = theano.gof.FunctionGraph([x], [out]) fgraph = theano.gof.FunctionGraph([x], [out])
_ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_tensor_basics(): def test_tensor_basics():
...@@ -519,15 +595,15 @@ def test_tensor_basics(): ...@@ -519,15 +595,15 @@ def test_tensor_basics():
# leave the expression alone. # leave the expression alone.
out = y.dot(alpha * A).dot(x) + beta * y out = y.dot(alpha * A).dot(x) + beta * y
fgraph = theano.gof.FunctionGraph([y, x, A, alpha, beta], [out]) fgraph = theano.gof.FunctionGraph([y, x, A, alpha, beta], [out])
_ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = tt.maximum(y, x) out = tt.maximum(y, x)
fgraph = theano.gof.FunctionGraph([y, x], [out]) fgraph = theano.gof.FunctionGraph([y, x], [out])
_ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = tt.max(y) out = tt.max(y)
fgraph = theano.gof.FunctionGraph([y], [out]) fgraph = theano.gof.FunctionGraph([y], [out])
_ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs") @pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
...@@ -537,7 +613,7 @@ def test_arange(): ...@@ -537,7 +613,7 @@ def test_arange():
out = tt.arange(a) out = tt.arange(a)
fgraph = theano.gof.FunctionGraph([a], [out]) fgraph = theano.gof.FunctionGraph([a], [out])
_ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_identity(): def test_identity():
...@@ -546,4 +622,4 @@ def test_identity(): ...@@ -546,4 +622,4 @@ def test_identity():
out = theano.scalar.basic.identity(a) out = theano.scalar.basic.identity(a)
fgraph = theano.gof.FunctionGraph([a], [out]) fgraph = theano.gof.FunctionGraph([a], [out])
_ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
...@@ -74,15 +74,12 @@ class JAXLinker(PerformLinker): ...@@ -74,15 +74,12 @@ class JAXLinker(PerformLinker):
thunk_outputs = [storage_map[n] for n in node.outputs] thunk_outputs = [storage_map[n] for n in node.outputs]
# JIT-compile the functions if not isinstance(jax_funcs, Sequence):
if len(node.outputs) > 1: jax_funcs = [jax_funcs]
assert len(jax_funcs) == len(node.ouptputs)
jax_impl_jits = [ jax_impl_jits = [
jax.jit(jax_func, static_argnums) for jax_func in jax_funcs jax.jit(jax_func, static_argnums) for jax_func in jax_funcs
] ]
else:
assert not isinstance(jax_funcs, Sequence)
jax_impl_jits = [jax.jit(jax_funcs, static_argnums)]
def thunk( def thunk(
node=node, jax_impl_jits=jax_impl_jits, thunk_outputs=thunk_outputs node=node, jax_impl_jits=jax_impl_jits, thunk_outputs=thunk_outputs
...@@ -92,6 +89,14 @@ class JAXLinker(PerformLinker): ...@@ -92,6 +89,14 @@ class JAXLinker(PerformLinker):
for jax_impl_jit in jax_impl_jits for jax_impl_jit in jax_impl_jits
] ]
if len(jax_impl_jits) < len(node.outputs):
# In this case, the JAX function will output a single
# output that contains the other outputs.
# This happens for multi-output `Op`s that directly
# correspond to multi-output JAX functions (e.g. `SVD` and
# `jax.numpy.linalg.svd`).
outputs = outputs[0]
for o_node, o_storage, o_val in zip( for o_node, o_storage, o_val in zip(
node.outputs, thunk_outputs, outputs node.outputs, thunk_outputs, outputs
): ):
......
...@@ -2,6 +2,7 @@ import theano ...@@ -2,6 +2,7 @@ import theano
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import jax.scipy as jsp
from warnings import warn from warnings import warn
from functools import update_wrapper, reduce from functools import update_wrapper, reduce
...@@ -49,12 +50,36 @@ from theano.tensor.opt import MakeVector ...@@ -49,12 +50,36 @@ from theano.tensor.opt import MakeVector
from theano.tensor.nnet.sigm import ScalarSoftplus from theano.tensor.nnet.sigm import ScalarSoftplus
from theano.tensor.nlinalg import (
Det,
Eig,
Eigh,
MatrixInverse,
QRFull,
QRIncomplete,
SVD,
ExtractDiag,
AllocDiag,
)
from theano.tensor.slinalg import (
Cholesky,
Solve,
)
if theano.config.floatX == "float64":
jax.config.update("jax_enable_x64", True)
else:
jax.config.update("jax_enable_x64", False)
# XXX: Enabling this will break some shape-based functionality, and severely # XXX: Enabling this will break some shape-based functionality, and severely
# limit the types of graphs that can be converted. # limit the types of graphs that can be converted.
# See https://github.com/google/jax/blob/4d556837cc9003492f674c012689efc3d68fdf5f/design_notes/omnistaging.md # See https://github.com/google/jax/blob/4d556837cc9003492f674c012689efc3d68fdf5f/design_notes/omnistaging.md
jax.config.disable_omnistaging() # Older versions < 0.2.0 do not have this flag so we don't need to set it.
jax.config.update("jax_enable_x64", True) try:
jax.config.disable_omnistaging()
except AttributeError:
pass
subtensor_ops = (Subtensor, AdvancedSubtensor1, BaseAdvancedSubtensor) subtensor_ops = (Subtensor, AdvancedSubtensor1, BaseAdvancedSubtensor)
incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1, BaseAdvancedIncSubtensor) incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1, BaseAdvancedIncSubtensor)
...@@ -629,3 +654,112 @@ def jax_funcify_Join(op): ...@@ -629,3 +654,112 @@ def jax_funcify_Join(op):
return jnp.concatenate(tensors, axis=axis) return jnp.concatenate(tensors, axis=axis)
return join return join
@jax_funcify.register(ExtractDiag)
def jax_funcify_ExtractDiag(op):
    """Return a JAX implementation of Theano's `ExtractDiag` `Op`."""

    # Bind the Op's configuration as argument defaults so the closure does
    # not need to touch `op` at call time.
    def extract_diag(x, offset=op.offset, axis1=op.axis1, axis2=op.axis2):
        return jnp.diagonal(x, offset=offset, axis1=axis1, axis2=axis2)

    return extract_diag
@jax_funcify.register(Cholesky)
def jax_funcify_Cholesky(op):
    """Return a JAX implementation of Theano's `Cholesky` `Op`."""
    is_lower = op.lower

    def cholesky(a, lower=is_lower):
        factor = jsp.linalg.cholesky(a, lower=lower)
        # Cast back to the input dtype — presumably guards against an
        # upcast by the JAX routine; TODO confirm against the linker.
        return factor.astype(a.dtype)

    return cholesky
@jax_funcify.register(AllocDiag)
def jax_funcify_AllocDiag(op):
    """Return a JAX implementation of Theano's `AllocDiag` `Op`."""

    def alloc_diag(x):
        # Build a square matrix carrying `x` on its main diagonal.
        return jnp.diag(x)

    return alloc_diag
@jax_funcify.register(Solve)
def jax_funcify_Solve(op):
    """Return a JAX implementation of Theano's `Solve` `Op`.

    When the `Op` declares `A` lower triangular, dispatch to
    `solve_triangular`: the generic `jsp.linalg.solve` does not honor a
    `lower` hint for a general system (in SciPy's API `lower` only applies
    to symmetric/positive-definite solves), so passing `lower=lower` there
    had no effect.  `solve_triangular` both respects and exploits the
    declared structure; for a genuinely lower-triangular `A` the result is
    the same as the generic solve.
    """
    if op.A_structure == "lower_triangular":

        def solve(a, b):
            return jsp.linalg.solve_triangular(a, b, lower=True)

    else:

        def solve(a, b):
            return jsp.linalg.solve(a, b)

    return solve
@jax_funcify.register(Det)
def jax_funcify_Det(op):
    """Return a JAX implementation of Theano's `Det` `Op`."""
    _det = jnp.linalg.det

    def det(x):
        return _det(x)

    return det
@jax_funcify.register(Eig)
def jax_funcify_Eig(op):
    """Return a JAX implementation of Theano's `Eig` `Op`."""

    def eig(x):
        # General (non-symmetric) eigendecomposition.
        return jnp.linalg.eig(x)

    return eig
@jax_funcify.register(Eigh)
def jax_funcify_Eigh(op):
    """Return a JAX implementation of Theano's `Eigh` `Op`."""

    # Forward the Op's UPLO setting ('L'/'U' triangle selection) to JAX.
    def eigh(x, uplo=op.UPLO):
        return jnp.linalg.eigh(x, UPLO=uplo)

    return eigh
@jax_funcify.register(MatrixInverse)
def jax_funcify_MatrixInverse(op):
    """Return a JAX implementation of Theano's `MatrixInverse` `Op`."""
    _inv = jnp.linalg.inv

    def matrix_inverse(x):
        return _inv(x)

    return matrix_inverse
@jax_funcify.register(QRFull)
def jax_funcify_QRFull(op):
    """Return a JAX implementation of Theano's `QRFull` `Op`."""

    # The Op's `mode` (e.g. "full"/"reduced") is bound as a default.
    def qr_full(x, mode=op.mode):
        return jnp.linalg.qr(x, mode=mode)

    return qr_full
@jax_funcify.register(QRIncomplete)
def jax_funcify_QRIncomplete(op):
    """Return a JAX implementation of Theano's `QRIncomplete` `Op`.

    Same translation as `QRFull` — both delegate to `jnp.linalg.qr` with
    the Op's own `mode`.
    """

    def qr_incomplete(x, mode=op.mode):
        return jnp.linalg.qr(x, mode=mode)

    return qr_incomplete
@jax_funcify.register(SVD)
def jax_funcify_SVD(op):
    """Return a JAX implementation of Theano's `SVD` `Op`."""
    want_full, want_uv = op.full_matrices, op.compute_uv

    def svd(x, full_matrices=want_full, compute_uv=want_uv):
        return jnp.linalg.svd(
            x, full_matrices=full_matrices, compute_uv=compute_uv
        )

    return svd
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this comment first!
Register or sign in to comment