提交 e5373112 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Brandon T. Willard

Fix type inference in Elemwise when inputs have 0 shape

上级 84e69fc8
......@@ -419,30 +419,34 @@ class Elemwise(OpenMPOp):
# broadcastable bit in turn.
def get_most_specialized_shape(shapes):
if None not in shapes:
# We could check if shapes are valid under broadcasting
# len(set(dims).discard(1)) <= 1
return max(shapes)
known_shapes = [shape for shape in shapes if shape is not None]
if known_shapes:
largest_known_shape = max(known_shapes)
# If largest known shape is 1, and there is an unknown shape, we don't
# know the final shape, because this could be broadcasted
if largest_known_shape > 1:
# Again, we could check that known shapes are valid under broacasting
return largest_known_shape
return None
shapes = set(shapes)
# All shapes are the same
if len(shapes) == 1:
return tuple(shapes)[0]
# Only valid indeterminate case
if shapes == {None, 1}:
return None
shapes.discard(1)
shapes.discard(None)
if len(shapes) > 1:
raise ValueError
return tuple(shapes)[0]
# it is multiplied by nout because Elemwise supports multiple outputs
# (nout of them)
out_shapes = [
[
get_most_specialized_shape(shape)
for shape in zip(*[input.type.shape for input in inputs])
]
] * shadow.nout
try:
out_shapes = [
[
get_most_specialized_shape(shape)
for shape in zip(*[inp.type.shape for inp in inputs])
]
] * shadow.nout
except ValueError:
raise ValueError(
f"Incompatible Elemwise input shapes {[inp.type.shape for inp in inputs]}"
)
# inplace_pattern maps output idx -> input idx
inplace_pattern = self.inplace_pattern
......
......@@ -856,8 +856,8 @@ class TestElemwise(unittest_tools.InferShapeTester):
assert all(isinstance(v.type, TensorType) for v in out_shape)
def test_static_shape_unary(self):
x = tensor("float64", shape=(None, 1, 5))
exp(x).type.shape == (None, 1, 5)
x = tensor("float64", shape=(None, 0, 1, 5))
exp(x).type.shape == (None, 0, 1, 5)
def test_static_shape_binary(self):
x = tensor("float64", shape=(None, 5))
......@@ -876,6 +876,19 @@ class TestElemwise(unittest_tools.InferShapeTester):
y = tensor("float64", shape=(1, 1))
assert (x + y).type.shape == (None, 1)
x = tensor("float64", shape=(0, 0, 0))
y = tensor("float64", shape=(0, 1, None))
assert (x + y).type.shape == (0, 0, 0)
def test_invalid_static_shape(self):
x = tensor("float64", shape=(2,))
y = tensor("float64", shape=(3,))
with pytest.raises(
ValueError,
match=re.escape("Incompatible Elemwise input shapes [(2,), (3,)]"),
):
x + y
def test_not_implemented_elemwise_grad():
# Regression test for unimplemented gradient in an Elemwise Op.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论