提交 0f4ced9d authored 作者: Frederic Bastien's avatar Frederic Bastien

added the theano flags op.set_flops that will make convop in profile mode print…

added the theano flags op.set_flops that will make convop in profile mode print the number of flops.
上级 487bfc2b
...@@ -8,6 +8,7 @@ default_={ ...@@ -8,6 +8,7 @@ default_={
'ProfileMode.n_ops_to_print':20, 'ProfileMode.n_ops_to_print':20,
'tensor_opt.local_elemwise_fusion':False, 'tensor_opt.local_elemwise_fusion':False,
'lib.amdlibm':False, 'lib.amdlibm':False,
'op.set_flops':False,#currently used only in ConvOp. The profile mode will print the flops/s for the op.
} }
#default value taked from env variable #default value taked from env variable
...@@ -38,6 +39,8 @@ THEANO_DEBUGMODE_CHECK_PY = bool(int(os.getenv('THEANO_DEBUGMODE_CHECK_PY', 1))) ...@@ -38,6 +39,8 @@ THEANO_DEBUGMODE_CHECK_PY = bool(int(os.getenv('THEANO_DEBUGMODE_CHECK_PY', 1)))
THEANO_DEBUGMODE_CHECK_FINITE = bool(int(os.getenv('THEANO_DEBUGMODE_CHECK_FINITE', 1))) THEANO_DEBUGMODE_CHECK_FINITE = bool(int(os.getenv('THEANO_DEBUGMODE_CHECK_FINITE', 1)))
THEANO_DEBUGMODE_CHECK_STRIDES = bool(int(os.getenv('THEANO_DEBUGMODE_CHECK_STRIDES', 1))) THEANO_DEBUGMODE_CHECK_STRIDES = bool(int(os.getenv('THEANO_DEBUGMODE_CHECK_STRIDES', 1)))
THEANO_FLAGS=os.getenv("THEANO_FLAGS","")
class TheanoConfig(object): class TheanoConfig(object):
"""Return the value for a key after parsing ~/.theano.cfg and """Return the value for a key after parsing ~/.theano.cfg and
the THEANO_FLAGS environment variable. the THEANO_FLAGS environment variable.
...@@ -72,7 +75,7 @@ class TheanoConfig(object): ...@@ -72,7 +75,7 @@ class TheanoConfig(object):
#user config file override the default value #user config file override the default value
self.config.read(['theano.cfg', os.path.expanduser('~/.theano.cfg')]) self.config.read(['theano.cfg', os.path.expanduser('~/.theano.cfg')])
self.env_flags=os.getenv("THEANO_FLAGS","") self.env_flags=THEANO_FLAGS
#The value in the env variable THEANO_FLAGS override the previous value #The value in the env variable THEANO_FLAGS override the previous value
for flag in self.env_flags.split(','): for flag in self.env_flags.split(','):
if not flag: if not flag:
......
import numpy as N import numpy as N
import theano import theano
import theano.tensor as T import theano.tensor as T
from theano import gof, Op, tensor from theano import gof, Op, tensor, config
from theano.printing import Print from theano.printing import Print
def getFilterOutShp(inshp, kshp, (dx,dy)=(1,1), mode='valid'): def getFilterOutShp(inshp, kshp, (dx,dy)=(1,1), mode='valid'):
...@@ -131,6 +131,8 @@ class ConvOp(Op): ...@@ -131,6 +131,8 @@ class ConvOp(Op):
"'valid' mode)")%(self.imshp_logical,self.kshp_logical)) "'valid' mode)")%(self.imshp_logical,self.kshp_logical))
self._rehash() self._rehash()
if config.config.getboolean('op.set_flops'):
self.set_flops()
def __eq__(self, other): def __eq__(self, other):
if type(self) != type(other): if type(self) != type(other):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论