提交 26889962 authored 作者: Reyhane Askari's avatar Reyhane Askari

removed part of _optcheck_fgraph and called std_fgraph instead

上级 d5123351
......@@ -25,7 +25,7 @@ from theano.gof import (graph, utils, link, ops_with_inner_function)
from theano.gof.link import raise_with_op
from theano.compile.function_module import (
FunctionMaker, Function, infer_reuse_pattern,
SymbolicOutput, Supervisor, std_fgraph)
SymbolicOutput, std_fgraph)
from theano.compile.mode import Mode, register_mode
from theano.compile.ops import OutputGuard
......@@ -613,44 +613,10 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
instances already installed.
"""
orig_inputs = [spec.variable for spec in input_specs]
updates = [spec.update for spec in input_specs if spec.update]
orig_outputs = [spec.variable for spec in output_specs] + updates
equivalence_tracker = _VariableEquivalenceTracker()
fgraph = gof.fg.FunctionGraph(orig_inputs, orig_outputs,
features=[equivalence_tracker])
# DestroyHandler may not be needed yet, as there is usually no
# inplace operation in the graph at this stage. DestroyHandler
# will be installed by an optimization after canonicalization,
# before the inplace operations are applied. This results in a big
# speed gain.
#
# If inplace operations are accepted and present, however,
# DestroyHandler will be inserted in the loop below.
if not accept_inplace:
for node in fgraph.apply_nodes:
if getattr(node.op, 'destroy_map', None):
raise TypeError("Graph must not contain inplace operations",
node)
else:
# However, if some inplace ops are already in the graph,
# DestroyHandler is needed for the Supervisor below to work correctly.
for node in fgraph.apply_nodes:
if getattr(node.op, 'destroy_map', None):
fgraph.attach_feature(gof.DestroyHandler())
break
# We need to protect all immutable inputs from inplace operations.
fgraph.attach_feature(Supervisor(
input for spec, input in zip(input_specs, fgraph.inputs)
if not (spec.mutable or (hasattr(fgraph, 'destroyers') and
fgraph.destroyers(input)))))
for feature in std_fgraph.features:
fgraph.attach_feature(feature())
fgraph, updates = std_fgraph(input_specs, output_specs, accept_inplace)
fgraph.attach_feature(equivalence_tracker)
return fgraph, list(map(SymbolicOutput, updates)), equivalence_tracker
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论