提交 057afede authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Thomas Wiecki

Fix failing test due to change in JAX behavior

上级 ce0b503c
......@@ -176,7 +176,7 @@ class TestJaxSplit:
UserWarning, match="Split node does not have constant split positions."
):
fn = pytensor.function([a], a_splits, mode="JAX")
# It raises an informative ConcretizationTypeError, but there's an AttributeError that surpsasses it
# It raises an informative ConcretizationTypeError, but there's an AttributeError that surpasses it
with pytest.raises(AttributeError):
fn(np.zeros((6, 4), dtype=pytensor.config.floatX))
......@@ -184,7 +184,9 @@ class TestJaxSplit:
a_splits = at.split(a, splits_size=[2, 4], n_splits=2, axis=split_axis)
with pytest.warns(UserWarning, match="Split node does not have constant axis."):
fn = pytensor.function([a, split_axis], a_splits, mode="JAX")
with pytest.raises(jax.errors.TracerIntegerConversionError):
# Same as above, an AttributeError surpasses the `TracerIntegerConversionError`
# Both errors are included for backwards compatibility
with pytest.raises((AttributeError, jax.errors.TracerIntegerConversionError)):
fn(np.zeros((6, 6), dtype=pytensor.config.floatX), 0)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论