提交 79fd771a authored 作者: Frederic's avatar Frederic

print the timming of env.validate and not env.replace_all_validate and print in with profile=True.

上级 f0fac6fb
...@@ -120,6 +120,11 @@ class ProfileStats(object): ...@@ -120,6 +120,11 @@ class ProfileStats(object):
optimizer_time = 0.0 optimizer_time = 0.0
# time spent optimizing graph (FunctionMaker.__init__) # time spent optimizing graph (FunctionMaker.__init__)
validate_time = 0.0
# time spent in env.validate
# This is a subset of optimizer_time that is dominated by toposort()
# when the destorymap feature is included.
linker_time = 0.0 linker_time = 0.0
# time spent linking graph (FunctionMaker.create) # time spent linking graph (FunctionMaker.create)
...@@ -392,11 +397,15 @@ class ProfileStats(object): ...@@ -392,11 +397,15 @@ class ProfileStats(object):
local_time, 100*local_time / self.fct_call_time) local_time, 100*local_time / self.fct_call_time)
print >> file, ' Total compile time: %es' % self.compile_time print >> file, ' Total compile time: %es' % self.compile_time
print >> file, ' Theano Optimizer time: %es' % self.optimizer_time print >> file, ' Theano Optimizer time: %es' % self.optimizer_time
print >> file, ' Theano validate time: %es' % self.validate_time
print >> file, (' Theano Linker time (includes C,' print >> file, (' Theano Linker time (includes C,'
' CUDA code generation/compiling): %es' % ' CUDA code generation/compiling): %es' %
self.linker_time) self.linker_time)
print >> file, '' print >> file, ''
# The validation time is a subset of optimizer_time
assert self.validate_time < self.optimizer_time
def summary(self, file=sys.stderr, n_ops_to_print=20, def summary(self, file=sys.stderr, n_ops_to_print=20,
n_applies_to_print=20): n_applies_to_print=20):
self.summary_function(file) self.summary_function(file)
......
...@@ -154,7 +154,7 @@ class SeqOptimizer(Optimizer, list): ...@@ -154,7 +154,7 @@ class SeqOptimizer(Optimizer, list):
Applies each L{Optimizer} in self in turn. Applies each L{Optimizer} in self in turn.
""" """
l = [] l = []
replace_all_validate_before = env.replace_all_validate_time validate_before = env.validate_time
nb_node_before = len(env.nodes) nb_node_before = len(env.nodes)
for optimizer in self: for optimizer in self:
try: try:
...@@ -175,8 +175,8 @@ class SeqOptimizer(Optimizer, list): ...@@ -175,8 +175,8 @@ class SeqOptimizer(Optimizer, list):
if hasattr(self,"name"): print self.name, if hasattr(self,"name"): print self.name,
elif hasattr(self,"__name__"): print self.__name__, elif hasattr(self,"__name__"): print self.__name__,
print " time %.3fs for %d/%d nodes before/after optimization"%(sum(l),nb_node_before,len(env.nodes)) print " time %.3fs for %d/%d nodes before/after optimization"%(sum(l),nb_node_before,len(env.nodes))
print " time %.3fs for replace_all_validate " % ( print " time %.3fs for validate " % (
env.replace_all_validate_time - replace_all_validate_before) env.validate_time - validate_before)
ll=[] ll=[]
for opt in self: for opt in self:
if hasattr(opt,"__name__"): if hasattr(opt,"__name__"):
......
...@@ -63,10 +63,20 @@ class History: ...@@ -63,10 +63,20 @@ class History:
class Validator: class Validator:
def on_attach(self, env): def on_attach(self, env):
if hasattr(env, 'validate'): for attr in ('validate', 'validate_time'):
raise AlreadyThere("Validator feature is already present or in" if hasattr(env, attr):
" conflict with another plugin.") raise AlreadyThere("Validator feature is already present or in"
env.validate = lambda: env.execute_callbacks('validate') " conflict with another plugin.")
def validate():
t0 = time.time()
ret = env.execute_callbacks('validate')
t1 = time.time()
env.validate_time += t1 - t0
return ret
env.validate = validate
env.validate_time = 0
def consistent(): def consistent():
try: try:
...@@ -79,6 +89,7 @@ class Validator: ...@@ -79,6 +89,7 @@ class Validator:
def on_detach(self, env): def on_detach(self, env):
del env.validate del env.validate
del env.consistent del env.consistent
del env.validate_time
class ReplaceValidate(History, Validator): class ReplaceValidate(History, Validator):
...@@ -86,27 +97,23 @@ class ReplaceValidate(History, Validator): ...@@ -86,27 +97,23 @@ class ReplaceValidate(History, Validator):
def on_attach(self, env): def on_attach(self, env):
History.on_attach(self, env) History.on_attach(self, env)
Validator.on_attach(self, env) Validator.on_attach(self, env)
for attr in ('replace_validate', 'replace_all_validate', for attr in ('replace_validate', 'replace_all_validate'):
'replace_all_validate_time'):
if hasattr(env, attr): if hasattr(env, attr):
raise AlreadyThere("ReplaceValidate feature is already present" raise AlreadyThere("ReplaceValidate feature is already present"
" or in conflict with another plugin.") " or in conflict with another plugin.")
env.replace_validate = partial(self.replace_validate, env) env.replace_validate = partial(self.replace_validate, env)
env.replace_all_validate = partial(self.replace_all_validate, env) env.replace_all_validate = partial(self.replace_all_validate, env)
env.replace_all_validate_time = 0
def on_detach(self, env): def on_detach(self, env):
History.on_detach(self, env) History.on_detach(self, env)
Validator.on_detach(self, env) Validator.on_detach(self, env)
del env.replace_validate del env.replace_validate
del env.replace_all_validate del env.replace_all_validate
del env.replace_all_validate_time
def replace_validate(self, env, r, new_r, reason=None): def replace_validate(self, env, r, new_r, reason=None):
self.replace_all_validate(env, [(r, new_r)], reason=reason) self.replace_all_validate(env, [(r, new_r)], reason=reason)
def replace_all_validate(self, env, replacements, reason=None): def replace_all_validate(self, env, replacements, reason=None):
t0 = time.time()
chk = env.checkpoint() chk = env.checkpoint()
for r, new_r in replacements: for r, new_r in replacements:
try: try:
...@@ -126,8 +133,6 @@ class ReplaceValidate(History, Validator): ...@@ -126,8 +133,6 @@ class ReplaceValidate(History, Validator):
except Exception, e: except Exception, e:
env.revert(chk) env.revert(chk)
raise raise
t1 = time.time()
env.replace_all_validate_time += t1 - t0
class NodeFinder(dict, Bookkeeper): class NodeFinder(dict, Bookkeeper):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论