提交 acbe8d39 authored 作者: Frederic Bastien's avatar Frederic Bastien

Don't merge elemwise that cause duplicate computation of broadcasted element.

上级 a9177e2b
...@@ -4593,7 +4593,11 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024): ...@@ -4593,7 +4593,11 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
# we still want to fusion. So we take the set. # we still want to fusion. So we take the set.
if (i.owner and if (i.owner and
isinstance(i.owner.op, OP) and isinstance(i.owner.op, OP) and
len(set([n for n, idx in i.clients])) == 1): len(set([n for n, idx in i.clients])) == 1 and
# Do not merge elemwise that don't have the same
# broadcastable pattern to don't redo duplicate
# computation due to broadcast.
i.owner.outputs[0].broadcastable == node.outputs[0].broadcastable):
do_fusion = True do_fusion = True
try: try:
......
...@@ -871,12 +871,15 @@ class test_fusion(unittest.TestCase): ...@@ -871,12 +871,15 @@ class test_fusion(unittest.TestCase):
ix, iy, iz = [theano.tensor.tensor(dtype='int32', ix, iy, iz = [theano.tensor.tensor(dtype='int32',
broadcastable=[False] * len(shp), broadcastable=[False] * len(shp),
name=n) for n in 'xyz'] name=n) for n in 'xyz']
fv = fvector('r') fv = fvector('v')
fs = fscalar('s')
fwv = my_init(shp, 'float32', 1) fwv = my_init(shp, 'float32', 1)
fxv = my_init(shp, 'float32', 2) fxv = my_init(shp, 'float32', 2)
fyv = my_init(shp, 'float32', 3) fyv = my_init(shp, 'float32', 3)
fzv = my_init(shp, 'float32', 4) fzv = my_init(shp, 'float32', 4)
fvv = theano._asarray(numpy.random.rand(shp[0]), dtype='float32') fvv = theano._asarray(numpy.random.rand(shp[0]), dtype='float32')
fsv = numpy.asarray(numpy.random.rand(), dtype='float32')
dwv = my_init(shp, 'float64', 5) dwv = my_init(shp, 'float64', 5)
ixv = theano._asarray(my_init(shp, num=60), dtype='int32') ixv = theano._asarray(my_init(shp, num=60), dtype='int32')
iyv = theano._asarray(my_init(shp, num=70), dtype='int32') iyv = theano._asarray(my_init(shp, num=70), dtype='int32')
...@@ -1035,7 +1038,13 @@ class test_fusion(unittest.TestCase): ...@@ -1035,7 +1038,13 @@ class test_fusion(unittest.TestCase):
(theano.tensor.mul(fx,ftanx,ftanx,fx),(fx,),(fxv,), (theano.tensor.mul(fx,ftanx,ftanx,fx),(fx,),(fxv,),
1,fxv*numpy.tan(fxv)*numpy.tan(fxv)*fxv,'float32'), 1,fxv*numpy.tan(fxv)*numpy.tan(fxv)*fxv,'float32'),
(theano.tensor.mul(ftanx,ftanx,fx+fy),(fx,fy),(fxv, (theano.tensor.mul(ftanx,ftanx,fx+fy),(fx,fy),(fxv,
fyv),1,numpy.tan(fxv)*numpy.tan(fxv)*(fxv+fyv),'float32'), fyv),1,numpy.tan(fxv)*numpy.tan(fxv)*(fxv+fyv),'float32'), # 70
#Cases with different broadcast pattern. They should not
#be merged as this would duplicate computation
#The graph should have 2 elemwise and 1 dimshuffle
(fx*theano.tensor.sin(fs),(fx,fs),(fxv,
fsv),3,fxv*numpy.sin(fsv),'float32'),
] ]
if slice: if slice:
cases = cases[slice] cases = cases[slice]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论