提交 9e2f5309 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add DestroyHandler if there are inplace ops

Even in DebugMode, even before the inplace phase.
上级 7b51cfba
...@@ -677,8 +677,13 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False): ...@@ -677,8 +677,13 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
inputs, outputs = gof.graph.clone(orig_inputs, orig_outputs) inputs, outputs = gof.graph.clone(orig_inputs, orig_outputs)
equivalence_tracker = _VariableEquivalenceTracker() equivalence_tracker = _VariableEquivalenceTracker()
fgraph = gof.fg.FunctionGraph(inputs, outputs, fgraph = gof.fg.FunctionGraph(inputs, outputs,
#DestroyHandler is not needed because it is actually installed by an optimization # DestroyHandler may not be needed yet, as there is usually no
# after canonicalization. This variables in a big speed gain. # inplace operation in the graph at this stage. DestroyHandler
# will be installed by an optimization after canonicalization,
# before the inplace operations are applied.
# This results in a big speed gain.
# If inplace operations are accepted and present, however,
# DestroyHandler will be inserted in the loop below.
#features=[equivalence_tracker, gof.DestroyHandler(do_imports_on_attach=False)]) #features=[equivalence_tracker, gof.DestroyHandler(do_imports_on_attach=False)])
features=[equivalence_tracker]) features=[equivalence_tracker])
...@@ -687,6 +692,13 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False): ...@@ -687,6 +692,13 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
if getattr(node.op, 'destroy_map', None): if getattr(node.op, 'destroy_map', None):
raise TypeError("Graph must not contain inplace operations", raise TypeError("Graph must not contain inplace operations",
node) node)
else:
# However, if some inplace ops are already in the graph,
# DestroyHandler is needed for the Supervisor below to work correctly.
for node in fgraph.apply_nodes:
if getattr(node.op, 'destroy_map', None):
fgraph.attach_feature(gof.DestroyHandler())
break
# We need to protect all immutable inputs from inplace operations. # We need to protect all immutable inputs from inplace operations.
fgraph.attach_feature(Supervisor(input for spec, input in zip(input_specs, inputs) fgraph.attach_feature(Supervisor(input for spec, input in zip(input_specs, inputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论