提交 eb506b5b authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add optimization removing Rebroadcast when the pattern is not actually changed

上级 50924b23
...@@ -692,6 +692,25 @@ def local_inplace_setsubtensor(node): ...@@ -692,6 +692,25 @@ def local_inplace_setsubtensor(node):
compile.optdb.register('inplace_setsubtensor', TopoOptimizer(local_inplace_setsubtensor, compile.optdb.register('inplace_setsubtensor', TopoOptimizer(local_inplace_setsubtensor,
failure_callback=TopoOptimizer.warn_inplace), 60, 'fast_run', 'inplace') #DEBUG failure_callback=TopoOptimizer.warn_inplace), 60, 'fast_run', 'inplace') #DEBUG
####################
# Rebroadcast opts #
####################
@register_canonicalize
@register_specialize
@gof.local_optimizer([T.Rebroadcast])
def local_useless_rebroadcast(node):
"""
Remove Rebroadcast if id does not actually change the broadcasting pattern
"""
if isinstance(node.op, T.Rebroadcast):
x = node.inputs[0]
if numpy.all(x.broadcastable == node.outputs[0].broadcastable):
return [x]
################## ##################
# Reshape opts # # Reshape opts #
################## ##################
......
...@@ -1126,6 +1126,15 @@ def test_local_mul_specialize(): ...@@ -1126,6 +1126,15 @@ def test_local_mul_specialize():
assert nodes == [T.mul] assert nodes == [T.mul]
def test_local_useless_rebroadcast():
v1 = T.vector()
v2 = T.vector()
j = T.join(0, v1, v2)
f = theano.function([v1, v2], j)
f([1,2], [3,4,5])
e = f.maker.env.toposort()
assert len([n for n in e if isinstance(n.op, T.Rebroadcast)]) == 0
if __name__ == '__main__': if __name__ == '__main__':
# unittest.main() # unittest.main()
test_fusion().tes_memory_leak() test_fusion().tes_memory_leak()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论