提交 3a6152d9 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

map_variables: map inputs

上级 ef65e52e
......@@ -858,16 +858,31 @@ def map_variables(fn, graphs, additional_inputs=[]):
graphs = list(graphs)
inputs_ = list(set(inputs(graphs) + list(additional_inputs)))
# work on a copy of the graph, but ensure it is still in terms of
# the user's inputs. we can use clone with copy_inputs=False to
# achieve this, but first we should clone those inputs that are
# cached constants, or FunctionGraph will complain about them.
# perform any desired replacement of input variables. these aren't
# replaced by the local optimizer approach because they are not
# outputs of any Apply node.
mapped_inputs_ = list(map(fn, inputs_))
replacements = [(input_, mapped_input_)
for input_, mapped_input_
in zip(inputs_, mapped_inputs_)
if mapped_input_ is not input_]
inputs_ = mapped_inputs_
graphs = the_other_clone(graphs,
share_inputs=True,
replace=replacements)
# clone cached constants or FunctionGraph will complain. this has
# to occur in a separate pass from the replacement above because
# both may suggest different replacements for the same variables.
# since the replacements introduced above may involve cached
# constants, the replacement of said constants has to come after.
cached_constants = [x for x in inputs_ if getattr(x, "cached", False)]
copied_constants, _ = clone(cached_constants, [], copy_inputs=True)
inputs_ = list(set(inputs_) - set(cached_constants)) + list(copied_constants)
graphs = the_other_clone(graphs,
share_inputs=True,
replace=zip(cached_constants, copied_constants))
inputs_ = list(set(inputs_) - set(cached_constants)) + list(copied_constants)
fg = FunctionGraph(inputs_, graphs, clone=False)
nodes_seen = set()
......@@ -876,8 +891,6 @@ def map_variables(fn, graphs, additional_inputs=[]):
def local_transform(node):
if node in nodes_seen:
return False
# FIXME: replacing inputs won't work because they are not
# outputs of any Apply node
if isinstance(node.op, (Scan, OpFromGraph)):
# recurse on the inner graph
new_inner_outputs = map_variables(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论