提交 cebafe0d authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3403 from nouiz/faster_merge

Faster merge
...@@ -810,6 +810,17 @@ class MergeOptimizer(Optimizer): ...@@ -810,6 +810,17 @@ class MergeOptimizer(Optimizer):
# No need to compare the op again, as it don't change. # No need to compare the op again, as it don't change.
if not inputs_match: if not inputs_match:
continue continue
if hasattr(pairs[0][0].fgraph, 'destroy_handler'):
# If both nodes have clients that destroy
# them, we can't merge them.
clients = pairs[0][0].clients + pairs[0][1].clients
if sum([i in utils.flatten(c.op.destroy_map.values())
for c, i in clients
if c != 'output' and
hasattr(c.op, 'destroy_map')]) > 1:
continue
try: try:
fgraph.replace_all_validate(pairs, 'MergeOptimizer') fgraph.replace_all_validate(pairs, 'MergeOptimizer')
except InconsistencyError: except InconsistencyError:
......
...@@ -3,10 +3,12 @@ from __future__ import print_function ...@@ -3,10 +3,12 @@ from __future__ import print_function
import six.moves.builtins as builtins import six.moves.builtins as builtins
import logging import logging
import time import time
import traceback
import warnings import warnings
import numpy # for numeric_grad import numpy # for numeric_grad
from six import itervalues from six import itervalues
from six.moves import StringIO
import theano import theano
...@@ -515,6 +517,17 @@ def grad(cost, wrt, consider_constant=None, ...@@ -515,6 +517,17 @@ def grad(cost, wrt, consider_constant=None,
elif disconnected_inputs == 'warn': elif disconnected_inputs == 'warn':
warnings.warn(message, stacklevel=2) warnings.warn(message, stacklevel=2)
elif disconnected_inputs == 'raise': elif disconnected_inputs == 'raise':
# Add the var trace
tr = getattr(var.tag, 'trace', [])
if len(tr) > 0:
message += "\nBacktrace when the node is created:\n"
# Print separate message for each element in the list of batcktraces
sio = StringIO()
for subtr in tr:
traceback.print_list(subtr, sio)
message += str(sio.getvalue())
raise DisconnectedInputError(message) raise DisconnectedInputError(message)
else: else:
raise ValueError("Invalid value for keyword " raise ValueError("Invalid value for keyword "
......
...@@ -132,14 +132,8 @@ def debugprint(obj, depth=-1, print_type=False, ...@@ -132,14 +132,8 @@ def debugprint(obj, depth=-1, print_type=False,
obj) obj)
scan_ops = [] scan_ops = []
for r, p in zip(results_to_print, profile_list): if any([p for p in profile_list if p is not None and p.fct_callcount > 0]):
# Add the parent scan op to the list as well print("""
if (hasattr(r.owner, 'op') and
isinstance(r.owner.op, theano.scan_module.scan_op.Scan)):
scan_ops.append(r)
if p is not None:
print("""
Timing Info Timing Info
----------- -----------
--> <time> <% time> - <total time> <% total time>' --> <time> <% time> - <total time> <% total time>'
...@@ -157,6 +151,12 @@ N.B.: ...@@ -157,6 +151,12 @@ N.B.:
to remove when optimizing a graph because their <total time> is very low. to remove when optimizing a graph because their <total time> is very low.
""", file=_file) """, file=_file)
for r, p in zip(results_to_print, profile_list):
# Add the parent scan op to the list as well
if (hasattr(r.owner, 'op') and
isinstance(r.owner.op, theano.scan_module.scan_op.Scan)):
scan_ops.append(r)
debugmode.debugprint(r, depth=depth, done=done, print_type=print_type, debugmode.debugprint(r, depth=depth, done=done, print_type=print_type,
file=_file, order=order, ids=ids, file=_file, order=order, ids=ids,
scan_ops=scan_ops, stop_on_name=stop_on_name, scan_ops=scan_ops, stop_on_name=stop_on_name,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论