提交 2edc7339 authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Allow jax_funcify_Join to work without omnistaging

JAX can check for correct axis itself
上级 b0ba476b
...@@ -727,12 +727,6 @@ def jax_funcify_Join(op, **kwargs): ...@@ -727,12 +727,6 @@ def jax_funcify_Join(op, **kwargs):
return tensors[view] return tensors[view]
else: else:
ndim = tensors[0].ndim
if axis < -ndim:
raise IndexError(
f"Join axis {int(axis)} out of bounds [0, {int(ndim)})"
)
return jnp.concatenate(tensors, axis=axis) return jnp.concatenate(tensors, axis=axis)
return join return join
......
...@@ -860,10 +860,6 @@ def test_jax_Dimshuffle(): ...@@ -860,10 +860,6 @@ def test_jax_Dimshuffle():
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]) compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_Join(): def test_jax_Join():
a = matrix("a") a = matrix("a")
b = matrix("b") b = matrix("b")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论