提交 fbbd9eff authored 作者: Tim Cooijmans's avatar Tim Cooijmans

map_variables: replace inside scan nodes

上级 723dad41
...@@ -851,19 +851,47 @@ def map_variables(fn, graphs, additional_inputs=[]): ...@@ -851,19 +851,47 @@ def map_variables(fn, graphs, additional_inputs=[]):
""" """
from fg import FunctionGraph from fg import FunctionGraph
from opt import TopoOptimizer, local_optimizer from opt import TopoOptimizer, local_optimizer
from theano import clone as the_other_clone
from theano.scan_module.scan_op import Scan
graphs = list(graphs) graphs = list(graphs)
inputs_ = list(set(inputs(graphs) + list(additional_inputs))) inputs_ = list(set(inputs(graphs) + list(additional_inputs)))
# work on a copy of the graph, but ensure it is still in terms of # work on a copy of the graph, but ensure it is still in terms of
# the user's inputs # the user's inputs. we can use clone with copy_inputs=False to
inputs_, graphs = clone(inputs_, graphs, copy_inputs=False) # achieve this, but first we should clone those inputs that are
# cached constants, or FunctionGraph will complain about them.
cached_constants = [x for x in inputs_ if getattr(x, "cached", False)]
copied_constants, _ = clone(cached_constants, [], copy_inputs=True)
graphs = the_other_clone(graphs,
share_inputs=True,
replace=zip(cached_constants, copied_constants))
inputs_ = list(set(inputs_) - set(cached_constants)) + list(copied_constants)
fg = FunctionGraph(inputs_, graphs, clone=False) fg = FunctionGraph(inputs_, graphs, clone=False)
nodes_seen = set()
@local_optimizer(None) @local_optimizer(None)
def local_transform(node): def local_transform(node):
if node in nodes_seen:
return False
# FIXME: replacing inputs won't work because they are not # FIXME: replacing inputs won't work because they are not
# outputs of any Apply node # outputs of any Apply node
if isinstance(node.op, Scan):
# recurse on the inner graph
new_inner_outputs = map_variables(
fn, node.op.outputs,
additional_inputs=additional_inputs)
# reinstantiate the op
new_op = Scan(node.op.inputs,
new_inner_outputs,
node.op.info,
# FIXME: infer this someday?
typeConstructor=None)
# make a new node to replace the old one
new_node = new_op.make_node(*node.inputs)
nodes_seen.add(new_node)
return new_node.outputs
return list(map(fn, node.outputs)) return list(map(fn, node.outputs))
topo_transform = TopoOptimizer(local_transform, 'out_to_in') topo_transform = TopoOptimizer(local_transform, 'out_to_in')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论