提交 fa327997 authored 作者: abergeron's avatar abergeron

Merge pull request #4468 from nouiz/scan_reintroduced

Scan reintroduced
...@@ -328,6 +328,9 @@ class FunctionGraph(utils.object2): ...@@ -328,6 +328,9 @@ class FunctionGraph(utils.object2):
if output.clients or output in self.outputs] if output.clients or output in self.outputs]
# If the apply node is not used and is not an output # If the apply node is not used and is not an output
if not used_or_output: if not used_or_output:
if not hasattr(apply_node.tag, 'removed_by'):
apply_node.tag.removed_by = []
apply_node.tag.removed_by.append(str(reason))
self.apply_nodes.remove(apply_node) self.apply_nodes.remove(apply_node)
self.variables.difference_update(apply_node.outputs) self.variables.difference_update(apply_node.outputs)
self.execute_callbacks('on_prune', apply_node, reason) self.execute_callbacks('on_prune', apply_node, reason)
...@@ -416,6 +419,9 @@ class FunctionGraph(utils.object2): ...@@ -416,6 +419,9 @@ class FunctionGraph(utils.object2):
assert node not in self.apply_nodes assert node not in self.apply_nodes
self.__setup_node__(node) self.__setup_node__(node)
self.apply_nodes.add(node) self.apply_nodes.add(node)
if not hasattr(node.tag, 'imported_by'):
node.tag.imported_by = []
node.tag.imported_by.append(str(reason))
for output in node.outputs: for output in node.outputs:
self.__setup_r__(output) self.__setup_r__(output)
self.variables.add(output) self.variables.add(output)
......
...@@ -1788,7 +1788,10 @@ class NavigatorOptimizer(Optimizer): ...@@ -1788,7 +1788,10 @@ class NavigatorOptimizer(Optimizer):
if replacements is False or replacements is None: if replacements is False or replacements is None:
return False return False
old_vars = node.outputs old_vars = node.outputs
remove = []
if isinstance(replacements, dict): if isinstance(replacements, dict):
if "remove" in replacements:
remove = replacements.pop("remove")
old_vars = list(replacements.keys()) old_vars = list(replacements.keys())
replacements = list(replacements.values()) replacements = list(replacements.values())
elif not isinstance(replacements, (tuple, list)): elif not isinstance(replacements, (tuple, list)):
...@@ -1811,7 +1814,9 @@ class NavigatorOptimizer(Optimizer): ...@@ -1811,7 +1814,9 @@ class NavigatorOptimizer(Optimizer):
if len(repl_pairs) == 0: if len(repl_pairs) == 0:
return False return False
try: try:
fgraph.replace_all_validate(repl_pairs, reason=lopt) fgraph.replace_all_validate_remove(repl_pairs,
reason=lopt,
remove=remove)
return True return True
except Exception as e: except Exception as e:
# This means the replacements were rejected by the fgraph. # This means the replacements were rejected by the fgraph.
......
...@@ -202,7 +202,7 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -202,7 +202,7 @@ def remove_constants_and_unused_inputs_scan(node):
# DEBUG CHECK # DEBUG CHECK
nwScan = scan_op.Scan(nw_inner, op_outs, nw_info) nwScan = scan_op.Scan(nw_inner, op_outs, nw_info)
nw_outs = nwScan(*nw_outer, **dict(return_list=True)) nw_outs = nwScan(*nw_outer, **dict(return_list=True))
return nw_outs return dict([("remove", [node])] + list(zip(node.outputs, nw_outs)))
else: else:
return False return False
...@@ -1964,8 +1964,10 @@ def scan_merge_inouts(node): ...@@ -1964,8 +1964,10 @@ def scan_merge_inouts(node):
outputs = [outputs] outputs = [outputs]
na = scan_args(outer_inputs, outputs, op.inputs, op.outputs, op.info) na = scan_args(outer_inputs, outputs, op.inputs, op.outputs, op.info)
remove = [node]
else: else:
na = a na = a
remove = []
# Now that the identical external inputs have been merged, we do a new # Now that the identical external inputs have been merged, we do a new
# loop in order to merge external outputs that compute the same things # loop in order to merge external outputs that compute the same things
...@@ -2069,7 +2071,9 @@ def scan_merge_inouts(node): ...@@ -2069,7 +2071,9 @@ def scan_merge_inouts(node):
seen.append((outer_imm, inner_omm, outer_omm, osl)) seen.append((outer_imm, inner_omm, outer_omm, osl))
new_outer_out_mit_mot.append(outer_omm) new_outer_out_mit_mot.append(outer_omm)
na.outer_out_mit_mot = new_outer_out_mit_mot na.outer_out_mit_mot = new_outer_out_mit_mot
if remove:
return dict([("remove", remove)] +
list(zip(node.outputs, na.outer_outputs)))
return na.outer_outputs return na.outer_outputs
...@@ -2253,11 +2257,14 @@ class PushOutDot1(gof.Optimizer): ...@@ -2253,11 +2257,14 @@ class PushOutDot1(gof.Optimizer):
# general I do not expect the sequence to run more then once # general I do not expect the sequence to run more then once
scan_eqopt1 = theano.gof.EquilibriumDB() scan_eqopt1 = theano.gof.EquilibriumDB()
scan_seqopt1 = theano.gof.SequenceDB() scan_seqopt1 = theano.gof.SequenceDB()
scan_eqopt2 = theano.gof.EquilibriumDB() scan_eqopt2 = theano.gof.EquilibriumDB()
# scan_eqopt1 before ShapeOpt at 0.1
# This is needed to don't have ShapeFeature trac old Scan that we
# don't want to reintroduce.
optdb.register('scan_eqopt1', scan_eqopt1, .05, 'fast_run', 'scan')
# We run before blas opt at 1.7 and specialize 2.0 # We run before blas opt at 1.7 and specialize 2.0
# but after stabilize at 1.5. Should we put it before stabilize? # but after stabilize at 1.5. Should we put it before stabilize?
optdb.register('scan_eqopt1', scan_eqopt1, .1, 'fast_run', 'scan')
optdb.register('scan_eqopt2', scan_eqopt2, 1.6, 'fast_run', 'scan') optdb.register('scan_eqopt2', scan_eqopt2, 1.6, 'fast_run', 'scan')
optdb.register('scanOp_make_inplace', optdb.register('scanOp_make_inplace',
ScanInplaceOptimizer(typeInfer=None, ScanInplaceOptimizer(typeInfer=None,
......
...@@ -1430,9 +1430,6 @@ class ShapeFeature(object): ...@@ -1430,9 +1430,6 @@ class ShapeFeature(object):
class ShapeOptimizer(Optimizer): class ShapeOptimizer(Optimizer):
"""Optimizer that serves to add ShapeFeature as an fgraph feature.""" """Optimizer that serves to add ShapeFeature as an fgraph feature."""
def __init__(self):
Optimizer.__init__(self)
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
fgraph.attach_feature(ShapeFeature()) fgraph.attach_feature(ShapeFeature())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论