提交 9e501182 authored 作者: Frederic's avatar Frederic

Add the printing of replace_all_validate with the flag time_seq_optimizer.

上级 f7ca8364
...@@ -154,6 +154,7 @@ class SeqOptimizer(Optimizer, list): ...@@ -154,6 +154,7 @@ class SeqOptimizer(Optimizer, list):
Applies each L{Optimizer} in self in turn. Applies each L{Optimizer} in self in turn.
""" """
l = [] l = []
replace_all_validate_before = env.replace_all_validate_time
nb_node_before = len(env.nodes) nb_node_before = len(env.nodes)
for optimizer in self: for optimizer in self:
try: try:
...@@ -174,7 +175,8 @@ class SeqOptimizer(Optimizer, list): ...@@ -174,7 +175,8 @@ class SeqOptimizer(Optimizer, list):
if hasattr(self,"name"): print self.name, if hasattr(self,"name"): print self.name,
elif hasattr(self,"__name__"): print self.__name__, elif hasattr(self,"__name__"): print self.__name__,
print " time %.3fs for %d/%d nodes before/after optimization"%(sum(l),nb_node_before,len(env.nodes)) print " time %.3fs for %d/%d nodes before/after optimization"%(sum(l),nb_node_before,len(env.nodes))
print " time %.3fs for replace_all_validate " % (
env.replace_all_validate_time - replace_all_validate_before)
ll=[] ll=[]
for opt in self: for opt in self:
if hasattr(opt,"__name__"): if hasattr(opt,"__name__"):
......
import sys import sys
import time
from theano.gof.python25 import partial from theano.gof.python25 import partial
...@@ -85,23 +86,27 @@ class ReplaceValidate(History, Validator): ...@@ -85,23 +86,27 @@ class ReplaceValidate(History, Validator):
def on_attach(self, env): def on_attach(self, env):
History.on_attach(self, env) History.on_attach(self, env)
Validator.on_attach(self, env) Validator.on_attach(self, env)
for attr in ('replace_validate', 'replace_all_validate'): for attr in ('replace_validate', 'replace_all_validate',
'replace_all_validate_time'):
if hasattr(env, attr): if hasattr(env, attr):
raise AlreadyThere("ReplaceValidate feature is already present" raise AlreadyThere("ReplaceValidate feature is already present"
" or in conflict with another plugin.") " or in conflict with another plugin.")
env.replace_validate = partial(self.replace_validate, env) env.replace_validate = partial(self.replace_validate, env)
env.replace_all_validate = partial(self.replace_all_validate, env) env.replace_all_validate = partial(self.replace_all_validate, env)
env.replace_all_validate_time = 0
def on_detach(self, env): def on_detach(self, env):
History.on_detach(self, env) History.on_detach(self, env)
Validator.on_detach(self, env) Validator.on_detach(self, env)
del env.replace_validate del env.replace_validate
del env.replace_all_validate del env.replace_all_validate
del env.replace_all_validate_time
def replace_validate(self, env, r, new_r, reason=None): def replace_validate(self, env, r, new_r, reason=None):
self.replace_all_validate(env, [(r, new_r)], reason=reason) self.replace_all_validate(env, [(r, new_r)], reason=reason)
def replace_all_validate(self, env, replacements, reason=None): def replace_all_validate(self, env, replacements, reason=None):
t0 = time.time()
chk = env.checkpoint() chk = env.checkpoint()
for r, new_r in replacements: for r, new_r in replacements:
try: try:
...@@ -121,6 +126,8 @@ class ReplaceValidate(History, Validator): ...@@ -121,6 +126,8 @@ class ReplaceValidate(History, Validator):
except Exception, e: except Exception, e:
env.revert(chk) env.revert(chk)
raise raise
t1 = time.time()
env.replace_all_validate_time += t1 - t0
class NodeFinder(dict, Bookkeeper): class NodeFinder(dict, Bookkeeper):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论