提交 4bcb0779 authored 作者: Frederic Bastien's avatar Frederic Bastien

if the env variable THEANO_FLAGS have the token local_elemwise_fusion we enable…

if the env variable THEANO_FLAGS have the token local_elemwise_fusion we enable the fusion. Each token is separated by comma.
上级 cdffaa87
...@@ -14,7 +14,7 @@ import inplace as I ...@@ -14,7 +14,7 @@ import inplace as I
import numpy as N import numpy as N
import operator import operator
import itertools import itertools
import sys import sys, os
from theano import compile #to register the optimizer built by this file from theano import compile #to register the optimizer built by this file
from theano.gof.python25 import any, all from theano.gof.python25 import any, all
...@@ -1237,8 +1237,7 @@ def local_elemwise_fusion(node): ...@@ -1237,8 +1237,7 @@ def local_elemwise_fusion(node):
The number of dimension is validated at call time by theano itself. The number of dimension is validated at call time by theano itself.
TODO:The broadcast flag? TODO:The broadcast flag?
""" """
# TODO:implement Composite.__eq__ by using CLinker.cmodule_key() to compare the graph. #TODO: Merge with multiple output to merge when an inputs have multiple clients. This can't be done with a local optimiser.
#TODO: Merge when nb_clients>1? When this optimisation could introduce duplication of computation? When this will be faster?
if not isinstance(node.op, T.Elemwise): if not isinstance(node.op, T.Elemwise):
return False return False
...@@ -1305,8 +1304,16 @@ def local_elemwise_fusion(node): ...@@ -1305,8 +1304,16 @@ def local_elemwise_fusion(node):
# print "local_elemwise_fusion: FUSED",nb_elemwise+1,"elemwise!" # print "local_elemwise_fusion: FUSED",nb_elemwise+1,"elemwise!"
return n.outputs return n.outputs
#register_specialize(local_elemwise_fusion) flags=os.getenv('THEANO_FLAGS',None)
if flags:
flags=flags.split(',')
if 'local_elemwise_fusion' in flags:
print "Will fusion elemwise"
register_specialize(local_elemwise_fusion)
else:
print "Won't fuse elemwise"
# def make_composite(inputs, outputs): # def make_composite(inputs, outputs):
# scalar_inputs = [scalar.Scalar(dtype = i.type.dtype)() for i in inputs] # scalar_inputs = [scalar.Scalar(dtype = i.type.dtype)() for i in inputs]
# def transform(r): # def transform(r):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论