Commit 4c2ca5f0, authored by carriepl

Merge pull request #3314 from cooijmanstim/map_variables

map_variables
......@@ -158,6 +158,12 @@ class FunctionGraph(utils.object2):
self.variable_locks = {}
self.profile = None
def add_input(self, input):
    """Register `input` as an input of this graph unless it already is one.

    A new input is appended to ``self.inputs``, set up through
    ``self.__setup_r__`` and added to ``self.variables``. An input that is
    already present in ``self.inputs`` is left untouched.
    """
    if input in self.inputs:
        # already registered; nothing to do
        return
    self.inputs.append(input)
    self.__setup_r__(input)
    self.variables.add(input)
# Setup a Variable #
def __setup_r__(self, r):
# sets up r so it belongs to this fgraph
......
......@@ -256,6 +256,214 @@ def clone(output,
return outs
def map_variables(replacer, graphs, additional_inputs=()):
    """Construct new graphs based on 'graphs' with some variables replaced
    according to 'replacer'.

    The graph inputs are passed through `replacer` first; the remaining
    variables are then visited node by node. Nodes whose op is a `Scan` or
    an `OpFromGraph` are rebuilt with `replacer` applied to their inner
    graph as well. A variable that was itself returned by `replacer` is
    not handed to `replacer` again.

    :param replacer: function that takes a variable and returns its
         replacement.
    :param graphs: an iterable of graphs in which to replace variables
    :param additional_inputs: an iterable of graph inputs not used in any
         of 'graphs' but possibly used in the graphs returned by `replacer`.
         It is materialized once, so a one-shot iterator is acceptable.
    :return: the new graphs, in the same order as 'graphs'
    :raises NotImplementedError: if a replacement introduces a shared
         variable with an update into the inner graph of a `Scan` or
         `OpFromGraph` (raised while recursing into the inner graph).

    Example:

    .. code-block:: python

        a = tensor.scalar("a")
        b = tensor.scalar("b")
        c = tensor.scalar("c")

        ab = a + b
        ab.tag.replacement = a * b

        u = ab + c
        v, = map_variables(
            lambda graph: getattr(graph.tag, "replacement", graph),
            [u])

        # v is now equal to a * b + c
    """
    # both arguments are used more than once below; materialize them so
    # that one-shot iterators are not silently empty on the second use.
    graphs = list(graphs)
    additional_inputs = list(additional_inputs)

    # wrap replacer to avoid replacing things we just put there.
    graphs_seen = set()

    def wrapped_replacer(graph):
        if graph in graphs_seen:
            return graph
        new_graph = replacer(graph)
        graphs_seen.add(new_graph)
        return new_graph

    inputs_ = list(set(gof.graph.inputs(graphs) + additional_inputs))

    # perform any desired replacement of input variables.  these
    # aren't replaced by the local optimizer approach because they are
    # not outputs of any Apply node.
    new_inputs = list(map(wrapped_replacer, inputs_))
    replacements = [(input_, new_input)
                    for input_, new_input
                    in zip(inputs_, new_inputs)
                    if new_input is not input_]
    graphs = clone(graphs, share_inputs=True, replace=replacements)
    inputs_ = list(set(gof.graph.inputs(graphs) + additional_inputs))

    # clone cached constants or FunctionGraph will complain.  this has
    # to occur in a separate pass from the replacement above because
    # both may suggest different replacements for the same variables.
    # since the replacements introduced above may involve cached
    # constants, the replacement of said constants has to come after.
    cached_constants = [x for x in inputs_ if getattr(x, "cached", False)]
    copied_constants = clone(cached_constants, share_inputs=False)
    replacements = list(zip(cached_constants, copied_constants))
    inputs_ = (list(set(inputs_) - set(cached_constants)) +
               list(copied_constants))
    graphs = clone(graphs, share_inputs=True, replace=replacements)

    fg = gof.fg.FunctionGraph(inputs_, graphs, clone=False)

    # nodes that were already handled (or created) by local_transform and
    # must therefore not be transformed again
    nodes_seen = set()

    @gof.opt.local_optimizer(None)
    def local_transform(node):
        if node in nodes_seen:
            return False

        # importing Scan into module scope would be circular
        from theano.scan_module.scan_op import Scan
        from theano.compile import OpFromGraph

        if isinstance(node.op, (Scan, OpFromGraph)):
            # recurse on the inner graph
            (new_inner_inputs,
             new_outer_inputs,
             new_inner_outputs) = _map_variables_inner(
                 wrapped_replacer,
                 inner_inputs=node.op.inputs,
                 outer_inputs=node.inputs,
                 inner_outputs=node.op.outputs,
                 containing_op=node.op)
            # reinstantiate the op
            if isinstance(node.op, Scan):
                new_op = Scan(new_inner_inputs,
                              new_inner_outputs,
                              node.op.info,
                              # FIXME: infer this someday?
                              typeConstructor=None)
            elif isinstance(node.op, OpFromGraph):
                new_op = OpFromGraph(new_inner_inputs,
                                     new_inner_outputs,
                                     **node.op.kwargs)
            # make a new node to replace the old one
            new_node = new_op.make_node(*new_outer_inputs)
            nodes_seen.add(new_node)
            return new_node.outputs
        else:
            nodes_seen.add(node)
            return list(map(wrapped_replacer, node.outputs))

    topo_transform = gof.opt.TopoOptimizer(local_transform, 'out_to_in')
    topo_transform.optimize(fg)

    new_graphs = fg.outputs
    # release the variables so that they can be used in other graphs
    # NOTE(review): presumed purpose of disown(); confirm against
    # FunctionGraph.disown.
    fg.disown()
    return new_graphs
def _map_variables_inner(replacer, inner_inputs, outer_inputs,
                         inner_outputs, containing_op):
    """Apply `replacer` to the inner graph of `containing_op`.

    Replacements returned by `replacer` may involve variables that are
    already owned by the outer fgraph (`fg` in the caller) and therefore
    cannot be added to the inner fgraph (`fg` in the recursive call), so
    `replacer` is wrapped to catch these before they are added. Some of them
    may be fgraph inputs or shared variables, which cannot be used directly
    inside the inner graph; new inner inputs are created to access them.

    :param replacer: function that takes a variable and returns its
         replacement.
    :param inner_inputs: inputs of the inner graph.
    :param outer_inputs: outer variables paired with `inner_inputs`.
    :param inner_outputs: outputs of the inner graph.
    :param containing_op: the op owning the inner graph; only used in the
         error message.
    :return: a tuple ``(new_inner_inputs, new_outer_inputs,
         new_inner_outputs)`` where the two input lists are the originals
         followed by any inputs added during replacement.
    :raises NotImplementedError: if a replacement introduces a shared
         variable with an update into the inner graph.
    """
    from theano import gof
    from theano.scan_module import scan_utils

    # maps every outer variable to the inner input through which it is
    # visible inside the inner graph; grows as new inputs are discovered
    outer_to_inner = dict(zip(outer_inputs, inner_inputs))
    added_inner_inputs = []
    added_outer_inputs = []

    def inner_replacer(graph):
        replaced = replacer(graph)

        leaves = [leaf for leaf in gof.graph.inputs([replaced])
                  if isinstance(leaf, gof.Variable)]
        constants = [leaf for leaf in leaves
                     if isinstance(leaf, gof.Constant)]
        other_inputs = [leaf for leaf in leaves
                        if not isinstance(leaf, gof.Constant)]

        # foreign inputs are fgraph inputs and shared variables that we need
        # to access through inner inputs
        foreign_inputs = list(set(other_inputs) -
                              set(outer_to_inner.values()))

        # skip further processing if there is nothing to do
        if not (constants or foreign_inputs):
            return replaced

        replacements = []

        # constants just need to be replaced by copies that the inner
        # `fg` can take ownership of
        for constant in constants:
            constant_copy = constant.clone()
            constant_copy.name = "%s_copied" % constant_copy.name
            replacements.append((constant, constant_copy))

        for outer_input in foreign_inputs:
            if getattr(outer_input, "update", False):
                # when theano.scan() constructs a scan node, it detects
                # shared variables with updates and returns these updates
                # to the user.  we need to do the same thing for every new
                # use of such a variable that is introduced.  it's hard to
                # do that at this point.
                # shared variables with updates inside the inner graph of
                # OpFromGraph are not supported at all, so we don't support
                # introducing those either.
                raise NotImplementedError(
                    "Replacement introduces shared variable %s "
                    "which has an update associated with it into "
                    "the inner graph of %s. This is not currently "
                    "supported." % (outer_input, containing_op))

            # this foreign input is already available as an inner input
            if outer_input in outer_to_inner:
                continue

            # otherwise connect it through a new inner input
            inner_input = scan_utils.safe_new(outer_input, tag="_copy")
            outer_to_inner[outer_input] = inner_input
            added_inner_inputs.append(inner_input)
            added_outer_inputs.append(outer_input)
            # the inner FunctionGraph wants to know its inputs
            # beforehand, but we don't always know.  so add them
            # as we discover them.
            # NOTE(review): assumes `graph` has an owner that is attached
            # to a FunctionGraph at this point -- confirm this holds when
            # `graph` is itself an inner input.
            graph.owner.fgraph.add_input(inner_input)

        replacements.extend(outer_to_inner.items())

        replaced, = theano.clone([replaced],
                                 share_inputs=True,
                                 replace=replacements)
        return replaced

    # the added_* lists are filled in as a side effect of this call, so
    # they must only be read afterwards
    new_inner_outputs = map_variables(inner_replacer, inner_outputs)
    new_inner_inputs = list(inner_inputs) + added_inner_inputs
    new_outer_inputs = list(outer_inputs) + added_outer_inputs
    return new_inner_inputs, new_outer_inputs, new_inner_outputs
def get_updates_and_outputs(ls):
"""
This function tries to recognize the updates OrderedDict, the
......
import itertools
import unittest
import numpy
import theano
from theano.scan_module.scan_utils import equal_computations
from theano import tensor
from theano.scan_module.scan_utils import equal_computations, map_variables
from theano.tensor.type_other import NoneConst
......@@ -11,3 +15,152 @@ def test_equal_compuations():
max_argmax1 = theano.tensor.max_and_argmax(m)
max_argmax2 = theano.tensor.max_and_argmax(m)
assert equal_computations(max_argmax1, max_argmax2)
#################
# map_variables #
#################
class TestMapVariables(unittest.TestCase):
    # Tests for map_variables(); replacements are attached to variables
    # through their `tag.replacement` attribute.

    @staticmethod
    def replacer(graph):
        # a variable without a `replacement` tag is mapped to itself
        replacement = getattr(graph.tag, "replacement", graph)
        return replacement

    def test_leaf(self):
        # replacing a graph input must not modify the original graph
        a = tensor.scalar("a")
        b = tensor.scalar("b")
        c = tensor.scalar("c")

        b.tag.replacement = c

        original = a + b
        mapped, = map_variables(self.replacer, [original])

        assert original.owner.inputs == [a, b]
        assert mapped.owner.inputs == [a, c]

    def test_leaf_inside_scan(self):
        # a replaced variable that is only used inside a scan
        x = tensor.vector('x')
        y = tensor.scalar('y')
        z = tensor.scalar('z')

        y.tag.replacement = z

        def scale(x_t):
            return x_t * y

        s, _ = theano.scan(scale, sequences=x)
        s2, = map_variables(self.replacer, [s])

        f = theano.function([x, y, z], [s, s2])
        x_value = numpy.array([1, 2, 3], dtype=numpy.float32)
        rval = f(x=x_value, y=1, z=2)
        expected = [[1, 2, 3], [2, 4, 6]]
        assert numpy.array_equal(rval, expected)

    def test_scan(self):
        x = tensor.vector('x')

        # we will insert a subgraph involving these variables into the inner
        # graph of scan. since they were not previously in the inner graph,
        # they are like non_sequences to scan(). scan() infers these and
        # imports them into the inner graph properly, and map_variables()
        # should do this as well.
        outer = tensor.scalar("outer")
        shared = theano.shared(
            numpy.array(1., dtype=theano.config.floatX),
            name="shared")
        constant = tensor.constant(1, name="constant")

        # z will equal 1 so multiplying by it doesn't change any values
        z = outer * (shared + constant)

        def step(x_t, acc):
            total = acc + x_t
            total.tag.replacement = z * (acc - x_t)
            return total

        s, _ = theano.scan(step, sequences=x,
                           outputs_info=[numpy.array(0.)])
        # ensure z is owned by the outer graph so map_variables() will need to
        # jump through additional hoops to placate FunctionGraph.
        t = z * s
        s2, = map_variables(self.replacer, [t])
        t2 = z * s2

        f = theano.function([x, outer], [t, t2])
        x_value = numpy.array([1, 2, 3], dtype=numpy.float32)
        rval = f(x=x_value, outer=0.5)
        expected = [[1, 3, 6], [-1, -3, -6]]
        assert numpy.array_equal(rval, expected)

    def test_scan_with_shared_update(self):
        x = tensor.vector('x')

        # counts how many times its value is used
        counter = theano.shared(0, name="shared")
        counter.update = counter + 1

        def step(x_t, acc):
            total = acc + x_t
            # introducing a shared variable with an update into the
            # inner graph is unsupported and the code must crash rather
            # than silently produce the wrong result.
            total.tag.replacement = counter * (acc - x_t)
            return total

        s, _ = theano.scan(step, sequences=x,
                           outputs_info=[numpy.array(0.)])
        with self.assertRaises(NotImplementedError):
            map_variables(self.replacer, [s])

    def test_scan_with_shared_update2(self):
        x = tensor.vector('x')

        # counts how many times its value is used
        counter = theano.shared(0, name="shared")
        counter.update = counter + 1

        def step(x_t, acc):
            total = acc + x_t
            # introducing a shared variable with an update into the
            # inner graph is unsupported and the code must crash rather
            # than silently produce the wrong result.
            total.tag.replacement = counter * (acc - x_t)
            # the shared variable was already present, but the
            # replacement changes the number of times it is used,
            # which would have to change the updates, which is
            # unsupported.
            return total + counter

        s, _ = theano.scan(step, sequences=x,
                           outputs_info=[numpy.array(0.)])
        with self.assertRaises(NotImplementedError):
            map_variables(self.replacer, [s])

    def test_opfromgraph(self):
        # as with the scan tests above, insert foreign inputs into the
        # inner graph.
        outer = tensor.scalar("outer")
        shared = theano.shared(
            numpy.array(1., dtype=theano.config.floatX),
            name="shared")
        constant = tensor.constant(1., name="constant")
        z = outer * (shared + constant)

        # construct the inner graph
        a = tensor.scalar()
        b = tensor.scalar()
        r = a + b
        r.tag.replacement = z * (a - b)

        # construct the outer graph
        c = tensor.scalar()
        d = tensor.scalar()
        inner_op = theano.OpFromGraph([a, b], [r])
        u = inner_op(c, d)
        t = z * u
        v, = map_variables(self.replacer, [t])
        t2 = z * v

        f = theano.function([c, d, outer], [t, t2])
        for m, n in itertools.combinations(range(10), 2):
            assert f(m, n, outer=0.5) == [m + n, m - n]

        # test that the unsupported case of replacement with a shared
        # variable with updates crashes
        shared.update = shared + 1
        with self.assertRaises(NotImplementedError):
            map_variables(self.replacer, [t])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论