提交 b2d82a01 authored 作者: Frederic Bastien's avatar Frederic Bastien

put the fusion of elemwise after the gemm optimizer as this make the gemm…

put the fusion of elemwise after the gemm optimizer as this make the gemm optimizer not able to make its work.
上级 7ea849d0
......@@ -597,6 +597,7 @@ class GemmOptimizer(Optimizer):
#TODO: retry other applications of gemm (see comment in _gemm_from_node
pass
#neede to make the gemm optimisation(step 70) happen before the fusion of elemwise(step 71)
compile.optdb.register('inplace_gemm', GemmOptimizer(), 70.00, 'fast_run', 'inplace', 'gemm')
......
......@@ -20,7 +20,8 @@ import sys, os
from theano import compile #to register the optimizer built by this file
from theano.gof.python25 import any, all
from theano.gof.opt import Optimizer
from theano.gof import toolbox, DestroyHandler
# Utilities
def out2in(*local_opts):
......@@ -1231,7 +1232,6 @@ register_canonicalize(local_transposed_dot, name='local_transposed_dot')
# # Loop fusion #
# ###############
@gof.local_optimizer([T.Elemwise, T.Elemwise])
def local_elemwise_fusion(node):
"""
As part of specialisation, we fusion two consecutif elemwise op of the same shape.
......@@ -1339,9 +1339,38 @@ def local_elemwise_fusion(node):
# print "local_elemwise_fusion: FUSED",nb_elemwise+1,"elemwise!"
return n.outputs
class FusionOptimizer(Optimizer):
"""Graph optimizer for Fusion of elemwise operations"""
def __init__(self):
Optimizer.__init__(self)
def add_requirements(self, env):
env.extend(toolbox.ReplaceValidate())
env.extend(DestroyHandler())
def apply(self, env):
did_something = True
while did_something:
nodelist = list(env.toposort())
did_something = False
for node in nodelist:
new_outputs = local_elemwise_fusion(node)
if new_outputs:
assert len(new_outputs) == len(node.outputs)
try:
env.replace_all_validate(
zip(node.outputs, new_outputs),
reason = self.__class__.__name__)
did_something = True
break
except InconsistencyError, e:
#TODO: retry other applications of gemm (see comment in _gemm_from_node
pass
if config.getboolean('tensor_opt.local_elemwise_fusion'):
_logger.debug("enabling optimization: fusion elemwise")
register_specialize(local_elemwise_fusion)
compile.optdb.register('elemwise_fusion', FusionOptimizer(), 71.00, 'fast_run', 'fusion')
else:
_logger.debug("not enabling optimization: fusion elemwise")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论