提交 a500f3a9 authored 作者: Frederic Bastien's avatar Frederic Bastien

implemented theano.gof.utils.get_theano_flag to parse the THEANO_FLAGS env variable and use it.

上级 c807a893
......@@ -5,6 +5,7 @@ from theano.gof.cutils import run_cthunk
from theano.compile.mode import Mode, register_mode, predefined_modes, predefined_linkers, predefined_optimizers, default_linker, default_optimizer
from theano.gof.cc import OpWiseCLinker
from theano import gof
from theano.gof.utils import get_theano_flag
class ProfileMode(Mode):
def __init__(self, linker=default_linker, optimizer=default_optimizer):
......@@ -85,6 +86,10 @@ class ProfileMode(Mode):
param: n_ops_to_print the number of ops to print. Default 20.
"""
n_ops_to_print=int(get_theano_flag("n_ops_to_print", n_ops_to_print))
n_apply_to_print=int(get_theano_flag("n_apply_to_print", n_apply_to_print))
local_time = self.local_time[0]
compile_time = self.compile_time
apply_time = self.apply_time
......
......@@ -2,7 +2,7 @@
# import op
# import variable
import re
import re, os
def hashgen():
hashgen.next += 1
......@@ -369,3 +369,27 @@ def type_guard(type1):
return new_f
return wrap
def get_theano_flag(key, default=None):
"""
This function parse the environement variable THEANO_FLAGS.
if the variable don't exist return None
if the key is not in the variable return None
if the key is in the variable but without a value return True
if the key is in the variable with a value return the value
if the key appear many times, we return the last value
the THEANO_FLAGS environement variable is a list of key[=value] that is separated by comma.
"""
f=os.getenv("THEANO_FLAGS")
ret = default
key2=key+"="
for fl in f.split(','):
if fl==key:
ret = True
elif fl.startswith(key2):
ret = fl.split('=',1)[1]
return ret
......@@ -8,7 +8,7 @@ _logger = logging.getLogger('theano.tensor.opt')
from theano import gof
from theano.gof import opt, InconsistencyError, TopoOptimizer, graph
from theano.gof.utils import MethodNotDefined
from theano.gof.utils import MethodNotDefined, get_theano_flag
from elemwise import Elemwise, DimShuffle
from theano import scalar
import basic as T
......@@ -1339,14 +1339,11 @@ def local_elemwise_fusion(node):
# print "local_elemwise_fusion: FUSED",nb_elemwise+1,"elemwise!"
return n.outputs
flags=os.getenv('THEANO_FLAGS',None)
if flags:
flags=flags.split(',')
if 'local_elemwise_fusion' in flags:
_logger.debug("enabling optimization: fusion elemwise")
register_specialize(local_elemwise_fusion)
else:
_logger.debug("not enabling optimization: fusion elemwise")
if get_theano_flag('local_elemwise_fusion',False):
_logger.debug("enabling optimization: fusion elemwise")
register_specialize(local_elemwise_fusion)
else:
_logger.debug("not enabling optimization: fusion elemwise")
# def make_composite(inputs, outputs):
# scalar_inputs = [scalar.Scalar(dtype = i.type.dtype)() for i in inputs]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论