提交 54fba943 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Make non-strict zip strict in tensor/shape.py

上级 e8db7169
......@@ -578,11 +578,15 @@ def specify_shape(
x = ptb.as_tensor_variable(x) # type: ignore[arg-type,unused-ignore]
# The above is a type error in Python 3.9 but not 3.12.
# Thus we need to ignore unused-ignore on 3.12.
new_shape_info = any(
s != xts for (s, xts) in zip(shape, x.type.shape, strict=False) if s is not None
)
# If shape does not match x.ndim, we rely on the `Op` to raise a ValueError
if not new_shape_info and len(shape) == x.type.ndim:
if len(shape) != x.type.ndim:
return _specify_shape(x, *shape)
new_shape_matches = all(
s == xts for (s, xts) in zip(shape, x.type.shape, strict=True) if s is not None
)
if new_shape_matches:
return x
return _specify_shape(x, *shape)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论