提交 4bcb0779 authored 作者: Frederic Bastien's avatar Frederic Bastien

if the env variable THEANO_FLAGS have the token local_elemwise_fusion we enable…

if the env variable THEANO_FLAGS have the token local_elemwise_fusion we enable the fusion. Each token is separated by comma.
上级 cdffaa87
......@@ -14,7 +14,7 @@ import inplace as I
import numpy as N
import operator
import itertools
import sys
import sys, os
from theano import compile #to register the optimizer built by this file
from theano.gof.python25 import any, all
......@@ -1237,8 +1237,7 @@ def local_elemwise_fusion(node):
The number of dimension is validated at call time by theano itself.
TODO:The broadcast flag?
"""
# TODO:implement Composite.__eq__ by using CLinker.cmodule_key() to compare the graph.
#TODO: Merge when nb_clients>1? When this optimisation could introduce duplication of computation? When this will be faster?
#TODO: Merge with multiple output to merge when an inputs have multiple clients. This can't be done with a local optimiser.
if not isinstance(node.op, T.Elemwise):
return False
......@@ -1305,8 +1304,16 @@ def local_elemwise_fusion(node):
# print "local_elemwise_fusion: FUSED",nb_elemwise+1,"elemwise!"
return n.outputs
#register_specialize(local_elemwise_fusion)
flags=os.getenv('THEANO_FLAGS',None)
if flags:
flags=flags.split(',')
if 'local_elemwise_fusion' in flags:
print "Will fusion elemwise"
register_specialize(local_elemwise_fusion)
else:
print "Won't fuse elemwise"
# def make_composite(inputs, outputs):
# scalar_inputs = [scalar.Scalar(dtype = i.type.dtype)() for i in inputs]
# def transform(r):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论