提交 7367e8d0 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug in Dimshuffles created by Elemwise

上级 a86efe55
......@@ -130,6 +130,10 @@ class DimShuffle(ExternalCOp):
super().__init__([self.c_func_file], self.c_func_name)
self.input_broadcastable = tuple(input_broadcastable)
if not all(isinstance(bs, (bool, np.bool_)) for bs in self.input_broadcastable):
raise ValueError(
f"input_broadcastable must be boolean, {self.input_broadcastable}"
)
self.new_order = tuple(new_order)
self.inplace = True
......@@ -411,10 +415,9 @@ class Elemwise(OpenMPOp):
if not difference:
args.append(input)
else:
# TODO: use LComplete instead
args.append(
dim_shuffle(
tuple(1 if s == 1 else None for s in input.type.shape),
input.type.broadcastable,
["x"] * difference + list(range(length)),
)(input)
)
......
......@@ -188,6 +188,12 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
y = x.dimshuffle([0, 1, "x"])
assert y.type.shape == (1, 2, 1)
def test_valid_input_broadcastable(self):
assert DimShuffle([True, False], (1, 0)).input_broadcastable == (True, False)
with pytest.raises(ValueError, match="input_broadcastable must be boolean"):
DimShuffle([None, None], (1, 0))
class TestBroadcast:
# this is to allow other types to reuse this class to test their ops
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论