提交 2edc7339 authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Allow jax_funcify_Join to work without omnistaging

JAX can check for correct axis itself
上级 b0ba476b
......@@ -727,12 +727,6 @@ def jax_funcify_Join(op, **kwargs):
return tensors[view]
else:
ndim = tensors[0].ndim
if axis < -ndim:
raise IndexError(
f"Join axis {int(axis)} out of bounds [0, {int(ndim)})"
)
return jnp.concatenate(tensors, axis=axis)
return join
......
......@@ -860,10 +860,6 @@ def test_jax_Dimshuffle():
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_Join():
a = matrix("a")
b = matrix("b")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论