提交 a500f3a9 authored 作者: Frederic Bastien's avatar Frederic Bastien

implemented theano.gof.utils.get_theano_flag to parse the THEANO_FLAGS env variable and use it.

上级 c807a893
...@@ -5,6 +5,7 @@ from theano.gof.cutils import run_cthunk ...@@ -5,6 +5,7 @@ from theano.gof.cutils import run_cthunk
from theano.compile.mode import Mode, register_mode, predefined_modes, predefined_linkers, predefined_optimizers, default_linker, default_optimizer from theano.compile.mode import Mode, register_mode, predefined_modes, predefined_linkers, predefined_optimizers, default_linker, default_optimizer
from theano.gof.cc import OpWiseCLinker from theano.gof.cc import OpWiseCLinker
from theano import gof from theano import gof
from theano.gof.utils import get_theano_flag
class ProfileMode(Mode): class ProfileMode(Mode):
def __init__(self, linker=default_linker, optimizer=default_optimizer): def __init__(self, linker=default_linker, optimizer=default_optimizer):
...@@ -85,6 +86,10 @@ class ProfileMode(Mode): ...@@ -85,6 +86,10 @@ class ProfileMode(Mode):
param: n_ops_to_print the number of ops to print. Default 20. param: n_ops_to_print the number of ops to print. Default 20.
""" """
n_ops_to_print=int(get_theano_flag("n_ops_to_print", n_ops_to_print))
n_apply_to_print=int(get_theano_flag("n_apply_to_print", n_apply_to_print))
local_time = self.local_time[0] local_time = self.local_time[0]
compile_time = self.compile_time compile_time = self.compile_time
apply_time = self.apply_time apply_time = self.apply_time
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# import op # import op
# import variable # import variable
import re import re, os
def hashgen(): def hashgen():
hashgen.next += 1 hashgen.next += 1
...@@ -369,3 +369,27 @@ def type_guard(type1): ...@@ -369,3 +369,27 @@ def type_guard(type1):
return new_f return new_f
return wrap return wrap
def get_theano_flag(key, default=None):
"""
This function parse the environement variable THEANO_FLAGS.
if the variable don't exist return None
if the key is not in the variable return None
if the key is in the variable but without a value return True
if the key is in the variable with a value return the value
if the key appear many times, we return the last value
the THEANO_FLAGS environement variable is a list of key[=value] that is separated by comma.
"""
f=os.getenv("THEANO_FLAGS")
ret = default
key2=key+"="
for fl in f.split(','):
if fl==key:
ret = True
elif fl.startswith(key2):
ret = fl.split('=',1)[1]
return ret
...@@ -8,7 +8,7 @@ _logger = logging.getLogger('theano.tensor.opt') ...@@ -8,7 +8,7 @@ _logger = logging.getLogger('theano.tensor.opt')
from theano import gof from theano import gof
from theano.gof import opt, InconsistencyError, TopoOptimizer, graph from theano.gof import opt, InconsistencyError, TopoOptimizer, graph
from theano.gof.utils import MethodNotDefined from theano.gof.utils import MethodNotDefined, get_theano_flag
from elemwise import Elemwise, DimShuffle from elemwise import Elemwise, DimShuffle
from theano import scalar from theano import scalar
import basic as T import basic as T
...@@ -1339,13 +1339,10 @@ def local_elemwise_fusion(node): ...@@ -1339,13 +1339,10 @@ def local_elemwise_fusion(node):
# print "local_elemwise_fusion: FUSED",nb_elemwise+1,"elemwise!" # print "local_elemwise_fusion: FUSED",nb_elemwise+1,"elemwise!"
return n.outputs return n.outputs
flags=os.getenv('THEANO_FLAGS',None) if get_theano_flag('local_elemwise_fusion',False):
if flags:
flags=flags.split(',')
if 'local_elemwise_fusion' in flags:
_logger.debug("enabling optimization: fusion elemwise") _logger.debug("enabling optimization: fusion elemwise")
register_specialize(local_elemwise_fusion) register_specialize(local_elemwise_fusion)
else: else:
_logger.debug("not enabling optimization: fusion elemwise") _logger.debug("not enabling optimization: fusion elemwise")
# def make_composite(inputs, outputs): # def make_composite(inputs, outputs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论