提交 a266b4c1 authored 作者: Frederic Bastien's avatar Frederic Bastien

Apply the rebroadcast optimization during graph construction in the functions unbroadcast and addbroadcast.

This is done to avoid inserting Rebroadcast ops that do nothing once those optimizations have been applied. This happens with GpuJoin.
上级 3a3a18f8
...@@ -2682,15 +2682,20 @@ class Rebroadcast(Op): ...@@ -2682,15 +2682,20 @@ class Rebroadcast(Op):
def addbroadcast(x, *axes):
    """
    Make the input broadcastable in the specified axes.

    We apply the optimization here so as not to pollute the graph,
    especially during the gpu optimization.

    :param x: a Variable
    :param axes: the axes (integers) to mark as broadcastable
    :return: a Variable broadcastable in the given axes; may be ``x``
        itself if the rebroadcast simplified away entirely.
    """
    rval = Rebroadcast(*[(axis, True) for axis in axes])(x)
    # Simplify immediately so a no-op Rebroadcast never enters the graph.
    return theano.tensor.opt.apply_rebroadcast_opt(rval)
def unbroadcast(x, *axes):
    """
    Make the input impossible to broadcast in the specified axes.

    We apply the optimization here so as not to pollute the graph,
    especially during the gpu optimization.

    :param x: a Variable
    :param axes: the axes (integers) to mark as non-broadcastable
    :return: a Variable non-broadcastable in the given axes; may be ``x``
        itself if the rebroadcast simplified away entirely.
    """
    rval = Rebroadcast(*[(axis, False) for axis in axes])(x)
    # Simplify immediately so a no-op Rebroadcast never enters the graph.
    return theano.tensor.opt.apply_rebroadcast_opt(rval)
class Join(Op): class Join(Op):
......
...@@ -806,6 +806,30 @@ def local_rebroadcast_lift(node): ...@@ -806,6 +806,30 @@ def local_rebroadcast_lift(node):
rval = [T.Rebroadcast(*axis.items())(iinput)] rval = [T.Rebroadcast(*axis.items())(iinput)]
return rval return rval
def apply_rebroadcast_opt(rval):
    """
    Apply, as many times as required, the optimizations
    local_useless_rebroadcast and local_rebroadcast_lift.

    :param rval: a Variable
    :return: a Variable. The same one if no optimization could be applied.
    """
    # Repeat until a full pass makes no change; a lift can expose a new
    # useless rebroadcast and vice versa.
    changed = True
    while changed and rval.owner:
        changed = False
        rval2 = theano.tensor.opt.local_useless_rebroadcast.transform(
            rval.owner)
        if rval2:
            # Local optimizers return a one-element list of replacements.
            assert len(rval2) == 1
            rval = rval2[0]
            changed = True
        if rval.owner:
            rval2 = theano.tensor.opt.local_rebroadcast_lift.transform(
                rval.owner)
            if rval2:
                assert len(rval2) == 1
                rval = rval2[0]
                changed = True
    return rval
################## ##################
# Reshape opts # # Reshape opts #
......
...@@ -2507,6 +2507,48 @@ def test_autocast(): ...@@ -2507,6 +2507,48 @@ def test_autocast():
finally: finally:
ac.__exit__() ac.__exit__()
def test_unbroadcast_addbroadcast():
    """
    Check that unbroadcast/addbroadcast do not insert Rebroadcast ops
    that are not needed, and that consecutive Rebroadcast ops are fused.
    """
    x = matrix()
    # A matrix is already non-broadcastable everywhere: no-op.
    assert unbroadcast(x, 0) is x
    assert unbroadcast(x, 1) is x
    assert unbroadcast(x, 1, 0) is x
    assert unbroadcast(x, 0, 1) is x
    # Adding broadcastability must insert an op.
    assert addbroadcast(x, 0) is not x
    assert addbroadcast(x, 1) is not x
    assert addbroadcast(x, 1, 0).owner.inputs[0] is x
    # Opposite ops on the same axis cancel out.
    assert unbroadcast(addbroadcast(x, 0), 0) is x
    assert addbroadcast(unbroadcast(x, 0), 0) is not x

    x = row()
    # A row is broadcastable on axis 0 only.
    assert unbroadcast(x, 0) is not x
    assert unbroadcast(x, 1) is x
    assert unbroadcast(x, 1, 0) is not x
    assert unbroadcast(x, 0, 1) is not x
    assert addbroadcast(x, 0) is x
    assert addbroadcast(x, 1).owner.inputs[0] is x
    assert addbroadcast(x, 1, 0).owner.inputs[0] is x
    assert addbroadcast(x, 0, 1).owner.inputs[0] is x
    assert unbroadcast(addbroadcast(x, 1), 1) is x
    assert addbroadcast(unbroadcast(x, 1), 1) is not x
    # The inner call already removed the broadcastable flag, so the
    # outer call must not add a second op.
    assert unbroadcast(unbroadcast(x, 0), 0).owner.inputs[0] is x

    # Consecutive Rebroadcast ops must be fused into one.
    x = TensorType(dtype='float64', broadcastable=(True, True))()
    assert unbroadcast(unbroadcast(x, 1), 0).owner.inputs[0] is x
    assert addbroadcast(unbroadcast(x, 1), 0).owner.inputs[0] is x
    assert addbroadcast(unbroadcast(x, 0), 0) is x
# Script entry point: run this module's unit tests.
# The original wrapped the call in a redundant `if 1:`; removed.
if __name__ == '__main__':
    unittest.main()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论