提交 ca26e81a authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Test works on more recent versions of JAX

上级 995b6cbc
import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor.compile.ops import DeepCopyOp, ViewOp
......@@ -53,9 +52,6 @@ def test_jax_Reshape_concrete_shape():
compare_jax_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
@pytest.mark.xfail(
reason="`shape_pt` should be specified as a static argument", strict=True
)
def test_jax_Reshape_shape_graph_input():
a = vector("a")
shape_pt = iscalar("b")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论