提交 64bad71f 作者: Brandon T. Willard 提交者: Brandon T. Willard

Fix JAX conversion for AdvancedIncSubtensor1

上级 0813ec0d
...@@ -425,40 +425,40 @@ def test_jax_Subtensors(): ...@@ -425,40 +425,40 @@ def test_jax_Subtensors():
# Basic indices # Basic indices
x_tt = tt.arange(3 * 4 * 5).reshape((3, 4, 5)) x_tt = tt.arange(3 * 4 * 5).reshape((3, 4, 5))
out_tt = x_tt[1, 2, 0] out_tt = x_tt[1, 2, 0]
assert isinstance(out_tt.owner.op, tt.subtensor.Subtensor)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
out_tt = x_tt[1:2, 1, :] out_tt = x_tt[1:2, 1, :]
assert isinstance(out_tt.owner.op, tt.subtensor.Subtensor)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
# Boolean indices # Boolean indices
out_tt = x_tt[x_tt < 0] out_tt = x_tt[x_tt < 0]
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedSubtensor)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
# Advanced indexing # Advanced indexing
out_tt = x_tt[[1, 2]] out_tt = x_tt[[1, 2]]
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedSubtensor1)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
out_tt = x_tt[[1, 2], [2, 3]] out_tt = x_tt[[1, 2], [2, 3]]
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedSubtensor)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
# Advanced and basic indexing # Advanced and basic indexing
out_tt = x_tt[[1, 2], :] out_tt = x_tt[[1, 2], :]
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedSubtensor1)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
out_tt = x_tt[[1, 2], :, [3, 4]] out_tt = x_tt[[1, 2], :, [3, 4]]
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedSubtensor)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
...@@ -470,64 +470,92 @@ def test_jax_IncSubtensor(): ...@@ -470,64 +470,92 @@ def test_jax_IncSubtensor():
# "Set" basic indices # "Set" basic indices
st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=theano.config.floatX)) st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=theano.config.floatX))
out_tt = tt.set_subtensor(x_tt[1, 2, 3], st_tt) out_tt = tt.set_subtensor(x_tt[1, 2, 3], st_tt)
assert isinstance(out_tt.owner.op, tt.subtensor.IncSubtensor)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(theano.config.floatX)) st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(theano.config.floatX))
out_tt = tt.set_subtensor(x_tt[:2, 0, 0], st_tt) out_tt = tt.set_subtensor(x_tt[:2, 0, 0], st_tt)
assert isinstance(out_tt.owner.op, tt.subtensor.IncSubtensor)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
out_tt = tt.set_subtensor(x_tt[0, 1:3, 0], st_tt) out_tt = tt.set_subtensor(x_tt[0, 1:3, 0], st_tt)
assert isinstance(out_tt.owner.op, tt.subtensor.IncSubtensor)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
# "Set" advanced indices # "Set" advanced indices
st_tt = tt.as_tensor_variable(
np.random.uniform(-1, 1, size=(2, 4, 5)).astype(theano.config.floatX)
)
out_tt = tt.set_subtensor(x_tt[np.r_[0, 2]], st_tt)
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedIncSubtensor1)
out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, [])
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(theano.config.floatX)) st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(theano.config.floatX))
out_tt = tt.set_subtensor(x_tt[[0, 2], 0, 0], st_tt) out_tt = tt.set_subtensor(x_tt[[0, 2], 0, 0], st_tt)
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedIncSubtensor)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
st_tt = tt.as_tensor_variable(x_np[[0, 2], 0, :3]) st_tt = tt.as_tensor_variable(x_np[[0, 2], 0, :3])
out_tt = tt.set_subtensor(x_tt[[0, 2], 0, :3], st_tt) out_tt = tt.set_subtensor(x_tt[[0, 2], 0, :3], st_tt)
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedIncSubtensor)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
# "Set" boolean indices # "Set" boolean indices
mask_tt = tt.as_tensor_variable(x_np) > 0 mask_tt = tt.as_tensor_variable(x_np) > 0
out_tt = tt.set_subtensor(x_tt[mask_tt], 0.0) out_tt = tt.set_subtensor(x_tt[mask_tt], 0.0)
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedIncSubtensor)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
# "Increment" basic indices # "Increment" basic indices
st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=theano.config.floatX)) st_tt = tt.as_tensor_variable(np.array(-10.0, dtype=theano.config.floatX))
out_tt = tt.inc_subtensor(x_tt[1, 2, 3], st_tt) out_tt = tt.inc_subtensor(x_tt[1, 2, 3], st_tt)
assert isinstance(out_tt.owner.op, tt.subtensor.IncSubtensor)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(theano.config.floatX)) st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(theano.config.floatX))
out_tt = tt.inc_subtensor(x_tt[:2, 0, 0], st_tt) out_tt = tt.inc_subtensor(x_tt[:2, 0, 0], st_tt)
assert isinstance(out_tt.owner.op, tt.subtensor.IncSubtensor)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
out_tt = tt.set_subtensor(x_tt[0, 1:3, 0], st_tt) out_tt = tt.set_subtensor(x_tt[0, 1:3, 0], st_tt)
assert isinstance(out_tt.owner.op, tt.subtensor.IncSubtensor)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
# "Increment" advanced indices # "Increment" advanced indices
st_tt = tt.as_tensor_variable(
np.random.uniform(-1, 1, size=(2, 4, 5)).astype(theano.config.floatX)
)
out_tt = tt.inc_subtensor(x_tt[np.r_[0, 2]], st_tt)
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedIncSubtensor1)
out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, [])
st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(theano.config.floatX)) st_tt = tt.as_tensor_variable(np.r_[-1.0, 0.0].astype(theano.config.floatX))
out_tt = tt.inc_subtensor(x_tt[[0, 2], 0, 0], st_tt) out_tt = tt.inc_subtensor(x_tt[[0, 2], 0, 0], st_tt)
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedIncSubtensor)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
st_tt = tt.as_tensor_variable(x_np[[0, 2], 0, :3]) st_tt = tt.as_tensor_variable(x_np[[0, 2], 0, :3])
out_tt = tt.inc_subtensor(x_tt[[0, 2], 0, :3], st_tt) out_tt = tt.inc_subtensor(x_tt[[0, 2], 0, :3], st_tt)
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedIncSubtensor)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
# "Increment" boolean indices # "Increment" boolean indices
mask_tt = tt.as_tensor_variable(x_np) > 0 mask_tt = tt.as_tensor_variable(x_np) > 0
out_tt = tt.set_subtensor(x_tt[mask_tt], 1.0) out_tt = tt.set_subtensor(x_tt[mask_tt], 1.0)
assert isinstance(out_tt.owner.op, tt.subtensor.AdvancedIncSubtensor)
out_fg = theano.gof.FunctionGraph([], [out_tt]) out_fg = theano.gof.FunctionGraph([], [out_tt])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
......
...@@ -623,7 +623,7 @@ _ = [jax_funcify.register(op, jax_funcify_Subtensor) for op in subtensor_ops] ...@@ -623,7 +623,7 @@ _ = [jax_funcify.register(op, jax_funcify_Subtensor) for op in subtensor_ops]
def jax_funcify_IncSubtensor(op): def jax_funcify_IncSubtensor(op):
idx_list = op.idx_list idx_list = getattr(op, "idx_list", None)
if getattr(op, "set_instead_of_inc", False): if getattr(op, "set_instead_of_inc", False):
jax_fn = jax.ops.index_update jax_fn = jax.ops.index_update
...@@ -632,7 +632,11 @@ def jax_funcify_IncSubtensor(op): ...@@ -632,7 +632,11 @@ def jax_funcify_IncSubtensor(op):
def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list): def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
_ilist = list(ilist) _ilist = list(ilist)
cdata = tuple(convert_indices(_ilist, idx) for idx in idx_list) cdata = (
tuple(convert_indices(_ilist, idx) for idx in idx_list)
if idx_list
else _ilist
)
if len(cdata) == 1: if len(cdata) == 1:
cdata = cdata[0] cdata = cdata[0]
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论