提交 9ea670ca authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Remove type check on JAX multi-output tests

上级 8a7f59e1
......@@ -102,12 +102,12 @@ def test_jax_compile_ops():
x = theano.compile.ops.Shape()(tt.as_tensor_variable(x_np))
x_fg = theano.gof.FunctionGraph([], [x])
compare_jax_and_py(x_fg, [])
compare_jax_and_py(x_fg, [], must_be_device_array=False)
x = theano.compile.ops.Shape_i(1)(tt.as_tensor_variable(x_np))
x_fg = theano.gof.FunctionGraph([], [x])
compare_jax_and_py(x_fg, [])
compare_jax_and_py(x_fg, [], must_be_device_array=False)
x = theano.compile.ops.SpecifyShape()(tt.as_tensor_variable(x_np), (20, 3))
x_fg = theano.gof.FunctionGraph([], [x])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论