提交 ca26e81a authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Test works on more recent versions of JAX

上级 995b6cbc
import numpy as np import numpy as np
import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.compile.ops import DeepCopyOp, ViewOp from pytensor.compile.ops import DeepCopyOp, ViewOp
...@@ -53,9 +52,6 @@ def test_jax_Reshape_concrete_shape(): ...@@ -53,9 +52,6 @@ def test_jax_Reshape_concrete_shape():
compare_jax_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) compare_jax_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
@pytest.mark.xfail(
reason="`shape_pt` should be specified as a static argument", strict=True
)
def test_jax_Reshape_shape_graph_input(): def test_jax_Reshape_shape_graph_input():
a = vector("a") a = vector("a")
shape_pt = iscalar("b") shape_pt = iscalar("b")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论