提交 fa327997 authored 作者: abergeron's avatar abergeron

Merge pull request #4468 from nouiz/scan_reintroduced

Scan reintroduced
......@@ -328,6 +328,9 @@ class FunctionGraph(utils.object2):
if output.clients or output in self.outputs]
# If the apply node is not used and is not an output
if not used_or_output:
if not hasattr(apply_node.tag, 'removed_by'):
apply_node.tag.removed_by = []
apply_node.tag.removed_by.append(str(reason))
self.apply_nodes.remove(apply_node)
self.variables.difference_update(apply_node.outputs)
self.execute_callbacks('on_prune', apply_node, reason)
......@@ -416,6 +419,9 @@ class FunctionGraph(utils.object2):
assert node not in self.apply_nodes
self.__setup_node__(node)
self.apply_nodes.add(node)
if not hasattr(node.tag, 'imported_by'):
node.tag.imported_by = []
node.tag.imported_by.append(str(reason))
for output in node.outputs:
self.__setup_r__(output)
self.variables.add(output)
......
......@@ -1788,7 +1788,10 @@ class NavigatorOptimizer(Optimizer):
if replacements is False or replacements is None:
return False
old_vars = node.outputs
remove = []
if isinstance(replacements, dict):
if "remove" in replacements:
remove = replacements.pop("remove")
old_vars = list(replacements.keys())
replacements = list(replacements.values())
elif not isinstance(replacements, (tuple, list)):
......@@ -1811,7 +1814,9 @@ class NavigatorOptimizer(Optimizer):
if len(repl_pairs) == 0:
return False
try:
fgraph.replace_all_validate(repl_pairs, reason=lopt)
fgraph.replace_all_validate_remove(repl_pairs,
reason=lopt,
remove=remove)
return True
except Exception as e:
# This means the replacements were rejected by the fgraph.
......
......@@ -202,7 +202,7 @@ def remove_constants_and_unused_inputs_scan(node):
# DEBUG CHECK
nwScan = scan_op.Scan(nw_inner, op_outs, nw_info)
nw_outs = nwScan(*nw_outer, **dict(return_list=True))
return nw_outs
return dict([("remove", [node])] + list(zip(node.outputs, nw_outs)))
else:
return False
......@@ -1964,8 +1964,10 @@ def scan_merge_inouts(node):
outputs = [outputs]
na = scan_args(outer_inputs, outputs, op.inputs, op.outputs, op.info)
remove = [node]
else:
na = a
remove = []
# Now that the identical external inputs have been merged, we do a new
# loop in order to merge external outputs that compute the same things
......@@ -2069,7 +2071,9 @@ def scan_merge_inouts(node):
seen.append((outer_imm, inner_omm, outer_omm, osl))
new_outer_out_mit_mot.append(outer_omm)
na.outer_out_mit_mot = new_outer_out_mit_mot
if remove:
return dict([("remove", remove)] +
list(zip(node.outputs, na.outer_outputs)))
return na.outer_outputs
......@@ -2253,11 +2257,14 @@ class PushOutDot1(gof.Optimizer):
# general I do not expect the sequence to run more then once
scan_eqopt1 = theano.gof.EquilibriumDB()
scan_seqopt1 = theano.gof.SequenceDB()
scan_eqopt2 = theano.gof.EquilibriumDB()
# scan_eqopt1 before ShapeOpt at 0.1
# This is needed to don't have ShapeFeature trac old Scan that we
# don't want to reintroduce.
optdb.register('scan_eqopt1', scan_eqopt1, .05, 'fast_run', 'scan')
# We run before blas opt at 1.7 and specialize 2.0
# but after stabilize at 1.5. Should we put it before stabilize?
optdb.register('scan_eqopt1', scan_eqopt1, .1, 'fast_run', 'scan')
optdb.register('scan_eqopt2', scan_eqopt2, 1.6, 'fast_run', 'scan')
optdb.register('scanOp_make_inplace',
ScanInplaceOptimizer(typeInfer=None,
......
......@@ -1430,9 +1430,6 @@ class ShapeFeature(object):
class ShapeOptimizer(Optimizer):
"""Optimizer that serves to add ShapeFeature as an fgraph feature."""
def __init__(self):
Optimizer.__init__(self)
def add_requirements(self, fgraph):
fgraph.attach_feature(ShapeFeature())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论