提交 39e7bcad authored 作者: Pascal Lamblin's avatar Pascal Lamblin

In Rebroadcast, keep only the axis flags actually changing the broadcasting pattern

上级 fc29c78b
...@@ -709,7 +709,20 @@ def local_useless_rebroadcast(node): ...@@ -709,7 +709,20 @@ def local_useless_rebroadcast(node):
if isinstance(node.op, T.Rebroadcast): if isinstance(node.op, T.Rebroadcast):
x = node.inputs[0] x = node.inputs[0]
if numpy.all(x.broadcastable == node.outputs[0].broadcastable): if numpy.all(x.broadcastable == node.outputs[0].broadcastable):
# No broadcastable flag was modified
return [x] return [x]
else:
# Keep the flags that modify something
new_axis = {}
for dim, bc in node.op.axis.items():
if x.broadcastable[dim] != bc:
new_axis[dim] = bc
if new_axis == node.op.axis:
# All flags are useful
return
else:
return [T.Rebroadcast(*new_axis.items())(x)]
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
......
...@@ -1146,7 +1146,9 @@ class T_Rebroadcast(unittest.TestCase): ...@@ -1146,7 +1146,9 @@ class T_Rebroadcast(unittest.TestCase):
f = theano.function([m], v, mode=mode) f = theano.function([m], v, mode=mode)
f([[76]]) f([[76]])
e = f.maker.env.toposort() e = f.maker.env.toposort()
assert len([n for n in e if isinstance(n.op, T.Rebroadcast)]) == 1 rebroadcast_nodes = [n for n in e if isinstance(n.op, T.Rebroadcast)]
assert len(rebroadcast_nodes) == 1
assert rebroadcast_nodes[0].op.axis == {0: True}
if __name__ == '__main__': if __name__ == '__main__':
# unittest.main() # unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论