提交 8ec730aa authored 作者: Frederic Bastien's avatar Frederic Bastien

Don't fuse elemwise where the scalar don't have a c_code implemented.

上级 e568c558
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
from theano import gof from theano import gof
from theano.gof import opt, InconsistencyError, TopoOptimizer, graph from theano.gof import opt, InconsistencyError, TopoOptimizer, graph
from theano.gof.utils import MethodNotDefined
from elemwise import Elemwise, DimShuffle from elemwise import Elemwise, DimShuffle
from theano import scalar from theano import scalar
import basic as T import basic as T
...@@ -1246,7 +1247,23 @@ def local_elemwise_fusion(node): ...@@ -1246,7 +1247,23 @@ def local_elemwise_fusion(node):
s_inputs = []#inputs of the new scalar op. s_inputs = []#inputs of the new scalar op.
s_g=[]#graph of scalar, what will by done in the inner loop. s_g=[]#graph of scalar, what will by done in the inner loop.
for i in node.inputs: for i in node.inputs:
do_fusion = False
catch = False
if i.owner and isinstance(i.owner.op,T.Elemwise) and len(i.clients)<=1: if i.owner and isinstance(i.owner.op,T.Elemwise) and len(i.clients)<=1:
#if the scalar_op don't have a c implementation, we skip its fusion to allow the fusion of the other ops.
do_fusion=True
try:
i.owner.op.scalar_op.c_code(i,"test_presence_of_c_code",
i.owner.inputs,i.owner.outputs,{})
except MethodNotDefined:
catch = True
except NotImplementedError:
catch = True
if catch:
print "OPTIMISATION WARNING: ",i.owner.op.scalar_op,"don't implement the c_code fonction. This is not fast and disable the fusion of loop."
do_fusion=False
if do_fusion:
if len(i.clients)>1: if len(i.clients)>1:
#should we put this in the first if, then we would go to the elif to don't fuse it? #should we put this in the first if, then we would go to the elif to don't fuse it?
#if one of the inputs have more then 1 clients and it is an intermediate result. We don't fuse. #if one of the inputs have more then 1 clients and it is an intermediate result. We don't fuse.
......
...@@ -874,7 +874,7 @@ class test_fusion(unittest.TestCase): ...@@ -874,7 +874,7 @@ class test_fusion(unittest.TestCase):
(fx+fy+theano.tensor.exp(fz),(fx,fy,fz),(fxv,fyv,fzv),1,fxv+fyv+numpy.exp(fzv),'float32'),#35 (fx+fy+theano.tensor.exp(fz),(fx,fy,fz),(fxv,fyv,fzv),1,fxv+fyv+numpy.exp(fzv),'float32'),#35
(fx-fy-fz,(fx,fy,fz),(fxv,fyv,fzv),1,fxv-fyv-fzv,'float32'), (fx-fy-fz,(fx,fy,fz),(fxv,fyv,fzv),1,fxv-fyv-fzv,'float32'),
(fx-(fy/fz),(fx,fy,fz),(fxv,fyv,fzv),1,fxv-(fyv/fzv),'float32'), (fx-(fy/fz),(fx,fy,fz),(fxv,fyv,fzv),1,fxv-(fyv/fzv),'float32'),
# (fx-(fy%fz),(fx,fy,fz),(fxv,fyv,fzv),1,fxv-(fyv%fzv),'float32'),#TODO: c_code not implemented for % (fx-(fy%fz),(fx,fy,fz),(fxv,fyv,fzv),2,fxv-(fyv%fzv),'float32'),#TODO: c_code not implemented for %
(fx-(fy>fz),(fx,fy,fz),(fxv,fyv,fzv),1,fxv-(fyv>fzv),'float32'), (fx-(fy>fz),(fx,fy,fz),(fxv,fyv,fzv),1,fxv-(fyv>fzv),'float32'),
(fx-(fy>=fz),(fx,fy,fz),(fxv,fyv,fzv),1,fxv-(fyv>=fzv),'float32'), (fx-(fy>=fz),(fx,fy,fz),(fxv,fyv,fzv),1,fxv-(fyv>=fzv),'float32'),
(fx-(fy<fz),(fx,fy,fz),(fxv,fyv,fzv),1,fxv-(fyv<fzv),'float32'), (fx-(fy<fz),(fx,fy,fz),(fxv,fyv,fzv),1,fxv-(fyv<fzv),'float32'),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论