Commit fb6adab7, authored by Tim Cooijmans

map_variables: detect non_sequences in replacements inside Scan (and…

map_variables: detect non_sequences in replacements inside Scan (and equivalently for OpFromGraph) and properly connect them through inner inputs
Parent commit: d07714f7
......@@ -158,6 +158,12 @@ class FunctionGraph(utils.object2):
self.variable_locks = {}
self.profile = None
def add_input(self, input):
    """Register `input` as an input of this fgraph.

    Does nothing when the variable is already among ``self.inputs``;
    otherwise appends it, runs the fgraph's per-variable setup via
    ``__setup_r__``, and records it in ``self.variables``.
    """
    already_present = input in self.inputs
    if already_present:
        return
    self.inputs.append(input)
    self.__setup_r__(input)
    self.variables.add(input)
# Setup a Variable #
def __setup_r__(self, r):
# sets up r so it belongs to this fgraph
......
......@@ -816,36 +816,53 @@ def clone_get_equiv(inputs, outputs, copy_inputs_and_orphans=True, memo=None):
return memo
def map_variables(fn, graphs, additional_inputs=[]):
"""
Construct new graphs based on 'graphs' with some variables replaced
according to 'fn'.
def local_replacer(fn):
    """Adapt a replacer function for use by `map_variables`.

    FunctionGraph is strict and wants to know its inputs beforehand.  We
    don't always know the inputs beforehand, so pass the `fg` into `fn`
    if it accepts it; it can then use fg.add_input() to add any missing
    inputs.  This is an internal mechanism currently only used by the
    wrapped `fn` in _map_variables_inner.

    :param fn: callable taking either ``(graph)`` or ``(graph, fg)`` and
        returning the replacement for ``graph``
    :return: a callable ``new_fn(graph, fg=None)`` that forwards ``fg``
        to `fn` only when `fn` declares two positional arguments
    """
    import inspect
    # Inspect the signature once at decoration time instead of on every
    # call.  inspect.getargspec was removed in Python 3.11, so prefer
    # getfullargspec when it exists (Python 3), falling back otherwise.
    _getspec = getattr(inspect, "getfullargspec", None) or inspect.getargspec
    accepts_fg = len(_getspec(fn).args) == 2

    def new_fn(graph, fg=None):
        # only forward `fg` when the caller supplied a (truthy) fgraph
        # and `fn` is able to receive it
        if fg and accepts_fg:
            return fn(graph, fg)
        return fn(graph)
    return new_fn
:param fn: function that takes a variable and returns its replacement
def map_variables(replacer, graphs, additional_inputs=[]):
"""Construct new graphs based on 'graphs' with some variables replaced
according to 'replacer'.
:param replacer: `local_replacer` decorated function that takes a
variable and returns its replacement.
:param graphs: an iterable of graphs in which to replace variables
:param additional_inputs: an iterable of graph inputs not used in any
of 'graphs' but possibly used in the graphs returned by 'fn'
of 'graphs' but possibly used in the graphs returned by `replacer`
:return: the new graphs, in the same order as 'graphs'
Example:
.. code-block:: python
import theano.tensor
tag = "replaceme"
a = theano.tensor.scalar("a")
b = theano.tensor.scalar("b")
c = theano.tensor.scalar("c")
a = tensor.scalar("a")
b = tensor.scalar("b")
c = tensor.scalar("c")
ab = a + b
setattr(ab.tag, tag, True)
ab.tag.replacement = a * b
@local_replacer
def replacer(graph):
return getattr(graph.tag, "replacement", graph)
u = ab + c
v, = map_variables(
lambda x: a * b if getattr(x.tag, tag, False) else x,
[u])
v, = map_variables(replacer, [u])
# v is now equal to a * b + c
"""
......@@ -855,19 +872,32 @@ def map_variables(fn, graphs, additional_inputs=[]):
from theano.scan_module.scan_op import Scan
from theano.compile import OpFromGraph
# wrap replacer to avoid replacing things we just put there.
graphs_seen = set()
def wrapped_replacer(graph, fg=None):
if graph in graphs_seen:
return graph
else:
new_graph = replacer(graph, fg)
graphs_seen.add(new_graph)
return new_graph
graphs = list(graphs)
inputs_ = list(set(inputs(graphs) + list(additional_inputs)))
# perform any desired replacement of input variables. these aren't
# replaced by the local optimizer approach because they are not
# outputs of any Apply node.
mapped_inputs_ = list(map(fn, inputs_))
replacements = [(input_, mapped_input_)
for input_, mapped_input_
in zip(inputs_, mapped_inputs_)
if mapped_input_ is not input_]
inputs_ = mapped_inputs_
# perform any desired replacement of input variables. these
# aren't replaced by the local optimizer approach because they are
# not outputs of any Apply node.
# NOTE: we don't need to pass any fgraph into the replacer; we can
# figure out the correct set of inputs from the graph before we
# construct the fgraph.
new_inputs = list(map(wrapped_replacer, inputs_))
replacements = [(input_, new_input)
for input_, new_input
in zip(inputs_, new_inputs)
if new_input is not input_]
graphs = the_other_clone(graphs, share_inputs=True, replace=replacements)
inputs_ = list(set(inputs(graphs) + list(additional_inputs)))
# clone cached constants or FunctionGraph will complain. this has
# to occur in a separate pass from the replacement above because
......@@ -890,34 +920,114 @@ def map_variables(fn, graphs, additional_inputs=[]):
return False
if isinstance(node.op, (Scan, OpFromGraph)):
# recurse on the inner graph
new_inner_outputs = map_variables(
fn, node.op.outputs,
additional_inputs=additional_inputs)
(new_inner_inputs,
new_outer_inputs,
new_inner_outputs) = _map_variables_inner(
wrapped_replacer,
inner_inputs=node.op.inputs,
outer_inputs=node.inputs,
inner_outputs=node.op.outputs)
# reinstantiate the op
if isinstance(node.op, Scan):
new_op = Scan(node.op.inputs,
new_op = Scan(new_inner_inputs,
new_inner_outputs,
node.op.info,
# FIXME: infer this someday?
typeConstructor=None)
elif isinstance(node.op, OpFromGraph):
new_op = OpFromGraph(node.op.inputs,
new_op = OpFromGraph(new_inner_inputs,
new_inner_outputs,
**node.op.kwargs)
# make a new node to replace the old one
new_node = new_op.make_node(*node.inputs)
new_node = new_op.make_node(*new_outer_inputs)
nodes_seen.add(new_node)
return new_node.outputs
return list(map(fn, node.outputs))
else:
nodes_seen.add(node)
return [wrapped_replacer(graph, fg) for graph in node.outputs]
topo_transform = TopoOptimizer(local_transform, 'out_to_in')
topo_transform.optimize(fg)
new_graphs = fg.outputs
theano.printing.debugprint(new_graphs)
fg.disown()
return new_graphs
def _map_variables_inner(replacer, inner_inputs, outer_inputs, inner_outputs):
    """Apply `replacer` inside the inner graph of a Scan/OpFromGraph node.

    The replacements returned by the replacer may involve variables that
    are already owned by the outer fgraph (`fg` in the caller) and so
    cannot be added to the inner fgraph (`fg` in the recursive call).
    Wrap the replacer to catch these before they are added.
    Additionally, some of these may be fgraph inputs or shared
    variables, which we cannot directly use inside the inner graph; we
    need to create inner inputs to access them through.

    :param replacer: a `local_replacer`-style callable taking
        ``(graph, fg)``
    :param inner_inputs: inputs of the op's inner graph, aligned
        pairwise with `outer_inputs`
    :param outer_inputs: the apply node's inputs
    :param inner_outputs: outputs of the op's inner graph
    :return: tuple ``(new_inner_inputs, new_outer_inputs,
        new_inner_outputs)`` where the input lists are the originals
        extended with any newly connected pairs
    """
    # maps each outer input to the inner input it is fed through;
    # extended as inner_replacer discovers foreign inputs
    outer_to_inner = dict(zip(outer_inputs, inner_inputs))
    extra_inner_inputs = []
    extra_outer_inputs = []

    from theano.scan_module import scan_utils
    from itertools import chain
    from theano import gof

    def inner_replacer(graph, inner_fg):
        new_graph = replacer(graph, inner_fg)

        # classify the ancestors of the replacement graph
        other_inputs = []
        constants = []
        for input_ in gof.graph.inputs([new_graph]):
            if isinstance(input_, gof.Variable):
                if isinstance(input_, Constant):
                    constants.append(input_)
                else:
                    other_inputs.append(input_)

        # foreign inputs are fgraph inputs and shared variables that we
        # need to access through inner inputs
        foreign_inputs = list(set(other_inputs) - set(outer_to_inner.values()))

        replacements = []

        # constants just need to be replaced by copies that the inner
        # `fg` can take ownership of
        for input_ in constants:
            new_input = input_.clone()
            # NOTE(review): unnamed constants yield the name
            # "None_copied" here -- confirm that is acceptable
            new_input.name = "%s_copied" % new_input.name
            replacements.append((input_, new_input))

        for outer_input in foreign_inputs:
            # if this foreign input is not already available as an
            # inner input, connect it through a new inner input
            if outer_input not in outer_to_inner.keys():
                inner_input = scan_utils.safe_new(outer_input, tag="_copy")
                outer_to_inner[outer_input] = inner_input
                extra_inner_inputs.append(inner_input)
                extra_outer_inputs.append(outer_input)
                inner_fg.add_input(inner_input)

        replacements.extend(outer_to_inner.items())
        new_graph, = theano.clone(
            [new_graph], share_inputs=True, replace=replacements)
        return new_graph

    new_inner_outputs = map_variables(inner_replacer, inner_outputs)
    new_inner_inputs = list(chain(inner_inputs, extra_inner_inputs))
    new_outer_inputs = list(chain(outer_inputs, extra_outer_inputs))
    return new_inner_inputs, new_outer_inputs, new_inner_outputs
def general_toposort(r_out, deps, debug_print=False,
compute_deps_cache=None, deps_cache=None):
"""
......
......@@ -11,7 +11,7 @@ from theano import (
from theano.gof.graph import (
Apply,
as_string, clone, general_toposort, inputs, io_toposort,
is_same_graph, Variable, map_variables)
is_same_graph, Variable, map_variables, local_replacer)
from theano.gof.op import Op
from theano.gof.type import Type
from theano.sandbox.cuda.var import (
......@@ -170,10 +170,12 @@ class TestMapVariables(X):
b.tag.replacement = c
@local_replacer
def replacer(graph):
return getattr(graph.tag, "replacement", graph)
u = a + b
v, = map_variables(
lambda x: getattr(x.tag, "replacement", x),
[u])
v, = map_variables(replacer, [u])
assert u.owner.inputs == [a, b]
assert v.owner.inputs == [a, c]
......@@ -190,10 +192,12 @@ class TestMapVariables(X):
c = tensor.scalar()
d = tensor.scalar()
@local_replacer
def replacer(graph):
return getattr(graph.tag, "replacement", graph)
u = OpFromGraph([a, b], [r])(c, d)
v, = map_variables(
lambda x: getattr(x.tag, "replacement", x),
[u])
v, = map_variables(replacer, [u])
f = function([c, d], [u, v])
for m, n in itertools.combinations(range(10), 2):
......@@ -221,6 +225,7 @@ class TestMapVariables(X):
r.tag.replacement = z * (a - x)
return r
@local_replacer
def replacer(graph):
return getattr(graph.tag, "replacement", graph)
......@@ -246,10 +251,12 @@ class TestMapVariables(X):
y.tag.replacement = z
@local_replacer
def replacer(graph):
return getattr(graph.tag, "replacement", graph)
s, _ = scan(lambda x: x * y, sequences=x)
s2, = map_variables(
lambda x: getattr(x.tag, "replacement", x),
[s])
s2, = map_variables(replacer, [s])
f = function([x, y, z], [s, s2])
assert numpy.array_equal(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论