提交 429b4469 authored 作者: --global's avatar --global

Add optional parameter update_mapping to FunctionGraph

上级 f762278f
...@@ -159,10 +159,22 @@ def std_fgraph(input_specs, output_specs, accept_inplace=False): ...@@ -159,10 +159,22 @@ def std_fgraph(input_specs, output_specs, accept_inplace=False):
""" """
orig_inputs = [spec.variable for spec in input_specs] orig_inputs = [spec.variable for spec in input_specs]
updates = [spec.update for spec in input_specs if spec.update]
# Extract the updates and the mapping between update outputs and
# the updated inputs.
updates = []
update_mapping = {}
out_idx = len(output_specs)
for inp_idx in range(len(input_specs)):
if input_specs[inp_idx].update:
updates.append(input_specs[inp_idx].update)
update_mapping[out_idx] = inp_idx
out_idx += 1
orig_outputs = [spec.variable for spec in output_specs] + updates orig_outputs = [spec.variable for spec in output_specs] + updates
fgraph = gof.fg.FunctionGraph(orig_inputs, orig_outputs) fgraph = gof.fg.FunctionGraph(orig_inputs, orig_outputs,
update_mapping=update_mapping)
for node in fgraph.apply_nodes: for node in fgraph.apply_nodes:
if getattr(node.op, 'destroy_map', None): if getattr(node.op, 'destroy_map', None):
......
...@@ -109,7 +109,19 @@ class FunctionGraph(utils.object2): ...@@ -109,7 +109,19 @@ class FunctionGraph(utils.object2):
""" """
def __init__(self, inputs, outputs, features=None, clone=True): def __init__(self, inputs, outputs, features=None, clone=True,
update_mapping=None):
"""
Create an FunctionGraph which operates on the subgraph bound by the
inputs and outputs sets.
:param inputs: inputs nodes of the graph, usually declared by the user
:param outputs: outputs nodes of the graph.
:param clone: If true, we will clone the graph. This is
useful to remove the constant cache problem.
:param update_mapping: dictionnary describing the mapping between
the inputs with updates and the outputs corresponding to their updates
"""
if clone: if clone:
inputs, outputs = graph.clone(inputs, outputs) inputs, outputs = graph.clone(inputs, outputs)
...@@ -157,6 +169,7 @@ class FunctionGraph(utils.object2): ...@@ -157,6 +169,7 @@ class FunctionGraph(utils.object2):
self.node_locks = {} self.node_locks = {}
self.variable_locks = {} self.variable_locks = {}
self.profile = None self.profile = None
self.update_mapping = update_mapping
# Setup a Variable # # Setup a Variable #
def __setup_r__(self, r): def __setup_r__(self, r):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论