提交 7367e8d0 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug in Dimshuffles created by Elemwise

上级 a86efe55
...@@ -130,6 +130,10 @@ class DimShuffle(ExternalCOp): ...@@ -130,6 +130,10 @@ class DimShuffle(ExternalCOp):
super().__init__([self.c_func_file], self.c_func_name) super().__init__([self.c_func_file], self.c_func_name)
self.input_broadcastable = tuple(input_broadcastable) self.input_broadcastable = tuple(input_broadcastable)
if not all(isinstance(bs, (bool, np.bool_)) for bs in self.input_broadcastable):
raise ValueError(
f"input_broadcastable must be boolean, {self.input_broadcastable}"
)
self.new_order = tuple(new_order) self.new_order = tuple(new_order)
self.inplace = True self.inplace = True
...@@ -411,10 +415,9 @@ class Elemwise(OpenMPOp): ...@@ -411,10 +415,9 @@ class Elemwise(OpenMPOp):
if not difference: if not difference:
args.append(input) args.append(input)
else: else:
# TODO: use LComplete instead
args.append( args.append(
dim_shuffle( dim_shuffle(
tuple(1 if s == 1 else None for s in input.type.shape), input.type.broadcastable,
["x"] * difference + list(range(length)), ["x"] * difference + list(range(length)),
)(input) )(input)
) )
......
...@@ -188,6 +188,12 @@ class TestDimShuffle(unittest_tools.InferShapeTester): ...@@ -188,6 +188,12 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
y = x.dimshuffle([0, 1, "x"]) y = x.dimshuffle([0, 1, "x"])
assert y.type.shape == (1, 2, 1) assert y.type.shape == (1, 2, 1)
def test_valid_input_broadcastable(self):
assert DimShuffle([True, False], (1, 0)).input_broadcastable == (True, False)
with pytest.raises(ValueError, match="input_broadcastable must be boolean"):
DimShuffle([None, None], (1, 0))
class TestBroadcast: class TestBroadcast:
# this is to allow other types to reuse this class to test their ops # this is to allow other types to reuse this class to test their ops
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论