提交 eb506b5b authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add optimization removing Rebroadcast when the pattern is not actually changed

上级 50924b23
......@@ -692,6 +692,25 @@ def local_inplace_setsubtensor(node):
compile.optdb.register('inplace_setsubtensor', TopoOptimizer(local_inplace_setsubtensor,
failure_callback=TopoOptimizer.warn_inplace), 60, 'fast_run', 'inplace') #DEBUG
####################
# Rebroadcast opts #
####################
@register_canonicalize
@register_specialize
@gof.local_optimizer([T.Rebroadcast])
def local_useless_rebroadcast(node):
"""
Remove Rebroadcast if id does not actually change the broadcasting pattern
"""
if isinstance(node.op, T.Rebroadcast):
x = node.inputs[0]
if numpy.all(x.broadcastable == node.outputs[0].broadcastable):
return [x]
##################
# Reshape opts #
##################
......
......@@ -1126,6 +1126,15 @@ def test_local_mul_specialize():
assert nodes == [T.mul]
def test_local_useless_rebroadcast():
v1 = T.vector()
v2 = T.vector()
j = T.join(0, v1, v2)
f = theano.function([v1, v2], j)
f([1,2], [3,4,5])
e = f.maker.env.toposort()
assert len([n for n in e if isinstance(n.op, T.Rebroadcast)]) == 0
if __name__ == '__main__':
# unittest.main()
test_fusion().tes_memory_leak()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论