提交 a552d638 authored 作者: Frederic Bastien's avatar Frederic Bastien

Make the SequenceOptimizer still have its profile printed when there is an error.

上级 acd4a90f
......@@ -1466,6 +1466,10 @@ class FunctionMaker(object):
theano.config.traceback.limit = theano.config.traceback.compile_limit
start_optimizer = time.time()
# In case there is an error during optimization.
optimizer_profile = None
opt_time = None
# now optimize the graph
if theano.config.cache_optimizations:
optimizer_profile = self.optimize_graph_with_cache(
......@@ -1475,8 +1479,23 @@ class FunctionMaker(object):
end_optimizer = time.time()
opt_time = end_optimizer - start_optimizer
_logger.debug('Optimizing took %f seconds', opt_time)
# Add deep copy to respect the memory interface
insert_deepcopy(fgraph, inputs, outputs + additional_outputs)
finally:
theano.config.compute_test_value = compute_test_value_orig
theano.config.traceback.limit = limit_orig
# If the optimizer got interrupted
if opt_time is None:
end_optimizer = time.time()
opt_time = end_optimizer - start_optimizer
theano.compile.profiling.total_graph_opt_time += opt_time
if profile:
if (optimizer_profile is None and
hasattr(optimizer, 'pre_profile')):
optimizer_profile = optimizer.pre_profile
profile.optimizer_time += opt_time
if theano.config.profile_optimizer:
profile.optimizer_profile = (optimizer,
......@@ -1485,13 +1504,6 @@ class FunctionMaker(object):
warnings.warn((
"config.profile_optimizer requires config.profile to "
" be set to True as well"), stacklevel=3)
_logger.debug('Optimizing took %f seconds', opt_time)
# Add deep copy to respect the memory interface
insert_deepcopy(fgraph, inputs, outputs + additional_outputs)
finally:
theano.config.compute_test_value = compute_test_value_orig
theano.config.traceback.limit = limit_orig
# initialize the linker
if not hasattr(linker, 'accept'):
......@@ -1783,7 +1795,7 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
if isinstance(mode, (list, tuple)): # "mode comparison" semantics
raise Exception("We do not support the passing of multiple modes")
else:
try:
Maker = getattr(mode, 'function_maker', FunctionMaker)
fn = Maker(inputs,
outputs,
......@@ -1793,11 +1805,12 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
on_unused_input=on_unused_input,
output_keys=output_keys).create(
defaults)
t2 = time.time()
if profile:
profile.compile_time += t2 - t1
profile.nb_nodes = len(fn.maker.fgraph.apply_nodes)
finally:
t2 = time.time()
if profile:
profile.compile_time += t2 - t1
# TODO: append
profile.nb_nodes = len(fn.maker.fgraph.apply_nodes)
fn.name = name
fn.maker.fgraph.name = name
......
......@@ -228,45 +228,53 @@ class SeqOptimizer(Optimizer, list):
nb_node_before = len(fgraph.apply_nodes)
sub_profs = []
nb_nodes = []
for optimizer in self:
try:
nb_nodes_before = len(fgraph.apply_nodes)
t0 = time.time()
sub_prof = optimizer.optimize(fgraph)
l.append(float(time.time() - t0))
sub_profs.append(sub_prof)
nb_nodes.append((nb_nodes_before,
len(fgraph.apply_nodes)))
if fgraph.profile:
sub_validate_time.append(fgraph.profile.validate_time)
except AssertionError:
# do not catch Assertion failures
raise
except Exception as e:
if self.failure_callback:
self.failure_callback(e, self, optimizer)
continue
else:
raise
if fgraph.profile:
validate_time = fgraph.profile.validate_time - validate_before
callbacks_time = {}
for k, v in iteritems(fgraph.execute_callbacks_times):
if k in callbacks_before:
t = v - callbacks_before[k]
if t > 0:
callbacks_time[k] = t
else:
callbacks_time[k] = v
else:
validate_time = None
callbacks_time = {}
self.pre_profile = (
self, l, -1, -1, nb_node_before,
-1, sub_profs, sub_validate_time,
nb_nodes, {})
try:
for optimizer in self:
try:
nb_nodes_before = len(fgraph.apply_nodes)
t0 = time.time()
sub_prof = optimizer.optimize(fgraph)
l.append(float(time.time() - t0))
sub_profs.append(sub_prof)
nb_nodes.append((nb_nodes_before,
len(fgraph.apply_nodes)))
if fgraph.profile:
sub_validate_time.append(fgraph.profile.validate_time)
except AssertionError:
# do not catch Assertion failures
raise
except Exception as e:
if self.failure_callback:
self.failure_callback(e, self, optimizer)
continue
else:
raise
finally:
callback_time = fgraph.execute_callbacks_time - callback_before
return (self, l, validate_time, callback_time, nb_node_before,
if fgraph.profile:
validate_time = fgraph.profile.validate_time - validate_before
callbacks_time = {}
for k, v in iteritems(fgraph.execute_callbacks_times):
if k in callbacks_before:
t = v - callbacks_before[k]
if t > 0:
callbacks_time[k] = t
else:
callbacks_time[k] = v
else:
validate_time = None
callbacks_time = {}
callback_time = fgraph.execute_callbacks_time - callback_before
self.pre_profile = (
self, l, validate_time, callback_time, nb_node_before,
len(fgraph.apply_nodes), sub_profs, sub_validate_time,
nb_nodes, callbacks_time)
return self.pre_profile
def __str__(self):
return "SeqOpt(%s)" % list.__str__(self)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论