提交 9e9c34b9 authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Return input directly with redundant shape in specify_shape

上级 837d9f9a
...@@ -556,6 +556,16 @@ def specify_shape( ...@@ -556,6 +556,16 @@ def specify_shape(
except ValueError: except ValueError:
raise ValueError("Shape vector must have fixed dimensions") raise ValueError("Shape vector must have fixed dimensions")
# If the specified shape is already encoded in the input static shape, do nothing
# This ignores Aesara constants in shape
x = at.as_tensor_variable(x)
new_shape_info = any(
s != xts for (s, xts) in zip(shape, x.type.shape) if s is not None
)
# If shape does not match x.ndim, we rely on the `Op` to raise a ValueError
if not new_shape_info and len(shape) == x.type.ndim:
return x
return _specify_shape(x, *shape) return _specify_shape(x, *shape)
......
...@@ -498,6 +498,18 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -498,6 +498,18 @@ class TestSpecifyShape(utt.InferShapeTester):
SpecifyShape, SpecifyShape,
) )
def test_direct_return(self):
"""Test that when specified shape does not provide new information, input is
returned directly."""
x = TensorType("float64", shape=(1, 2, None))("x")
assert specify_shape(x, (1, 2, None)) is x
assert specify_shape(x, (None, None, None)) is x
assert specify_shape(x, (1, 2, 3)) is not x
assert specify_shape(x, (None, None, 3)) is not x
assert specify_shape(x, (1, 3, None)) is not x
class TestRopLop(RopLopChecker): class TestRopLop(RopLopChecker):
def test_shape(self): def test_shape(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论