Unverified 提交 d72a2890 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

Fix failing JAX Split test (#1646)

上级 934306f2
......@@ -119,20 +119,22 @@ def jax_funcify_Split(op: Split, node, **kwargs):
def split(x, axis, splits):
if constant_axis is not None:
axis = constant_axis
if len(splits) != op.len_splits:
raise ValueError("Length of splits is not equal to n_splits")
if constant_splits is not None:
splits = constant_splits
cumsum_splits = np.cumsum(splits[:-1])
if (splits < 0).any():
raise ValueError("Split sizes cannot be negative")
else:
cumsum_splits = jnp.cumsum(splits[:-1])
if len(splits) != op.len_splits:
raise ValueError("Length of splits is not equal to n_splits")
if np.sum(splits) != x.shape[axis]:
if constant_axis is not None and constant_splits is not None:
if splits.sum() != x.shape[axis]:
raise ValueError(
f"Split sizes do not sum up to input length along axis: {x.shape[axis]}"
)
if np.any(splits < 0):
raise ValueError("Split sizes cannot be negative")
return jnp.split(x, cumsum_splits, axis=axis)
......
......@@ -182,16 +182,17 @@ class TestJaxSplit:
UserWarning, match="Split node does not have constant split positions."
):
fn = pytensor.function([a], a_splits, mode="JAX")
# It raises an informative ConcretizationTypeError, but there's an AttributeError that surpasses it
with pytest.raises(AttributeError):
# This test used to raise AttributeError in previous versions of JAX.
# Now it raises `TracerIntegerConversionError`.
# We accept both errors for backwards compatibility.
with pytest.raises((AttributeError, errors.TracerIntegerConversionError)):
fn(np.zeros((6, 4), dtype=pytensor.config.floatX))
split_axis = iscalar("split_axis")
a_splits = ptb.split(a, splits_size=[2, 4], n_splits=2, axis=split_axis)
with pytest.warns(UserWarning, match="Split node does not have constant axis."):
fn = pytensor.function([a, split_axis], a_splits, mode="JAX")
# Same as above, an AttributeError surpasses the `TracerIntegerConversionError`
# Both errors are included for backwards compatibility
# Same reasoning as above to accept both errors.
with pytest.raises((AttributeError, errors.TracerIntegerConversionError)):
fn(np.zeros((6, 6), dtype=pytensor.config.floatX), 0)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论