提交 ef65e52e authored 作者: Tim Cooijmans's avatar Tim Cooijmans

map_variables: replace inside OpFromGraph nodes as well

上级 fbbd9eff
......@@ -853,6 +853,7 @@ def map_variables(fn, graphs, additional_inputs=[]):
from opt import TopoOptimizer, local_optimizer
from theano import clone as the_other_clone
from theano.scan_module.scan_op import Scan
from theano.compile import OpFromGraph
graphs = list(graphs)
inputs_ = list(set(inputs(graphs) + list(additional_inputs)))
......@@ -877,17 +878,22 @@ def map_variables(fn, graphs, additional_inputs=[]):
return False
# FIXME: replacing inputs won't work because they are not
# outputs of any Apply node
if isinstance(node.op, Scan):
if isinstance(node.op, (Scan, OpFromGraph)):
# recurse on the inner graph
new_inner_outputs = map_variables(
fn, node.op.outputs,
additional_inputs=additional_inputs)
# reinstantiate the op
new_op = Scan(node.op.inputs,
new_inner_outputs,
node.op.info,
# FIXME: infer this someday?
typeConstructor=None)
if isinstance(node.op, Scan):
new_op = Scan(node.op.inputs,
new_inner_outputs,
node.op.info,
# FIXME: infer this someday?
typeConstructor=None)
elif isinstance(node.op, OpFromGraph):
new_op = OpFromGraph(node.op.inputs,
new_inner_outputs,
**node.op.kwargs)
# make a new node to replace the old one
new_node = new_op.make_node(*node.inputs)
nodes_seen.add(new_node)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论