提交 e5373112 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Brandon T. Willard

Fix type inference in Elemwise when inputs have 0 shape

上级 84e69fc8
...@@ -419,30 +419,34 @@ class Elemwise(OpenMPOp): ...@@ -419,30 +419,34 @@ class Elemwise(OpenMPOp):
# broadcastable bit in turn. # broadcastable bit in turn.
def get_most_specialized_shape(shapes): def get_most_specialized_shape(shapes):
if None not in shapes: shapes = set(shapes)
# We could check if shapes are valid under broadcasting # All shapes are the same
# len(set(dims).discard(1)) <= 1 if len(shapes) == 1:
return max(shapes) return tuple(shapes)[0]
known_shapes = [shape for shape in shapes if shape is not None] # Only valid indeterminate case
if known_shapes: if shapes == {None, 1}:
largest_known_shape = max(known_shapes) return None
# If largest known shape is 1, and there is an unknown shape, we don't
# know the final shape, because this could be broadcasted shapes.discard(1)
if largest_known_shape > 1: shapes.discard(None)
# Again, we could check that known shapes are valid under broacasting if len(shapes) > 1:
return largest_known_shape raise ValueError
return tuple(shapes)[0]
return None
# it is multiplied by nout because Elemwise supports multiple outputs # it is multiplied by nout because Elemwise supports multiple outputs
# (nout of them) # (nout of them)
out_shapes = [ try:
[ out_shapes = [
get_most_specialized_shape(shape) [
for shape in zip(*[input.type.shape for input in inputs]) get_most_specialized_shape(shape)
] for shape in zip(*[inp.type.shape for inp in inputs])
] * shadow.nout ]
] * shadow.nout
except ValueError:
raise ValueError(
f"Incompatible Elemwise input shapes {[inp.type.shape for inp in inputs]}"
)
# inplace_pattern maps output idx -> input idx # inplace_pattern maps output idx -> input idx
inplace_pattern = self.inplace_pattern inplace_pattern = self.inplace_pattern
......
...@@ -856,8 +856,8 @@ class TestElemwise(unittest_tools.InferShapeTester): ...@@ -856,8 +856,8 @@ class TestElemwise(unittest_tools.InferShapeTester):
assert all(isinstance(v.type, TensorType) for v in out_shape) assert all(isinstance(v.type, TensorType) for v in out_shape)
def test_static_shape_unary(self): def test_static_shape_unary(self):
x = tensor("float64", shape=(None, 1, 5)) x = tensor("float64", shape=(None, 0, 1, 5))
exp(x).type.shape == (None, 1, 5) exp(x).type.shape == (None, 0, 1, 5)
def test_static_shape_binary(self): def test_static_shape_binary(self):
x = tensor("float64", shape=(None, 5)) x = tensor("float64", shape=(None, 5))
...@@ -876,6 +876,19 @@ class TestElemwise(unittest_tools.InferShapeTester): ...@@ -876,6 +876,19 @@ class TestElemwise(unittest_tools.InferShapeTester):
y = tensor("float64", shape=(1, 1)) y = tensor("float64", shape=(1, 1))
assert (x + y).type.shape == (None, 1) assert (x + y).type.shape == (None, 1)
x = tensor("float64", shape=(0, 0, 0))
y = tensor("float64", shape=(0, 1, None))
assert (x + y).type.shape == (0, 0, 0)
def test_invalid_static_shape(self):
x = tensor("float64", shape=(2,))
y = tensor("float64", shape=(3,))
with pytest.raises(
ValueError,
match=re.escape("Incompatible Elemwise input shapes [(2,), (3,)]"),
):
x + y
def test_not_implemented_elemwise_grad(): def test_not_implemented_elemwise_grad():
# Regression test for unimplemented gradient in an Elemwise Op. # Regression test for unimplemented gradient in an Elemwise Op.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论