提交 5d4e9e07 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug in JAX test

上级 3496e688
...@@ -29,7 +29,7 @@ def test_jax_Alloc(): ...@@ -29,7 +29,7 @@ def test_jax_Alloc():
x = ptb.AllocEmpty("float32")(2, 3) x = ptb.AllocEmpty("float32")(2, 3)
def compare_shape_dtype(x, y): def compare_shape_dtype(x, y):
np.testing.assert_array_equal(x, y, strict=True) assert x.shape == y.shape and x.dtype == y.dtype
compare_jax_and_py([], [x], [], assert_fn=compare_shape_dtype) compare_jax_and_py([], [x], [], assert_fn=compare_shape_dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论