Commit 4c2ca5f0 authored by carriepl

Merge pull request #3314 from cooijmanstim/map_variables

map_variables
...@@ -158,6 +158,12 @@ class FunctionGraph(utils.object2): ...@@ -158,6 +158,12 @@ class FunctionGraph(utils.object2):
self.variable_locks = {} self.variable_locks = {}
self.profile = None self.profile = None
    def add_input(self, input):
        """Add a new input variable to this FunctionGraph.

        No-op if `input` is already registered in ``self.inputs``;
        otherwise it is appended to ``self.inputs``, set up to belong
        to this fgraph via ``__setup_r__`` and recorded in
        ``self.variables``.
        """
        if input not in self.inputs:
            self.inputs.append(input)
            # take ownership: marks `input` as belonging to this fgraph
            self.__setup_r__(input)
            self.variables.add(input)
# Setup a Variable # # Setup a Variable #
def __setup_r__(self, r): def __setup_r__(self, r):
# sets up r so it belongs to this fgraph # sets up r so it belongs to this fgraph
......
...@@ -256,6 +256,214 @@ def clone(output, ...@@ -256,6 +256,214 @@ def clone(output,
return outs return outs
def map_variables(replacer, graphs, additional_inputs=None):
    """Construct new graphs based on `graphs` with some variables replaced
    according to `replacer`.

    :param replacer: function that takes a variable and returns its
         replacement.
    :param graphs: an iterable of graphs in which to replace variables
    :param additional_inputs: an iterable of graph inputs not used in any
         of `graphs` but possibly used in the graphs returned by `replacer`
    :return: the new graphs, in the same order as `graphs`

    Example:

    .. code-block:: python

        tag = "replaceme"

        a = tensor.scalar("a")
        b = tensor.scalar("b")
        c = tensor.scalar("c")

        ab = a + b
        ab.tag.replacement = a * b

        u = ab + c
        v, = map_variables(
            lambda graph: getattr(graph.tag, "replacement", graph),
            [u])

        # v is now equal to a * b + c
    """
    # None sentinel instead of a mutable default argument.
    if additional_inputs is None:
        additional_inputs = []

    # wrap replacer to avoid replacing things we just put there.
    graphs_seen = set()

    def wrapped_replacer(graph):
        if graph in graphs_seen:
            return graph
        else:
            new_graph = replacer(graph)
            graphs_seen.add(new_graph)
            return new_graph

    graphs = list(graphs)
    inputs_ = list(set(gof.graph.inputs(graphs) + list(additional_inputs)))

    # perform any desired replacement of input variables.  these
    # aren't replaced by the local optimizer approach because they are
    # not outputs of any Apply node.
    new_inputs = list(map(wrapped_replacer, inputs_))
    replacements = [(input_, new_input)
                    for input_, new_input
                    in zip(inputs_, new_inputs)
                    if new_input is not input_]
    graphs = clone(graphs, share_inputs=True, replace=replacements)
    inputs_ = list(set(gof.graph.inputs(graphs) + list(additional_inputs)))

    # clone cached constants or FunctionGraph will complain.  this has
    # to occur in a separate pass from the replacement above because
    # both may suggest different replacements for the same variables.
    # since the replacements introduced above may involve cached
    # constants, the replacement of said constants has to come after.
    cached_constants = [x for x in inputs_ if getattr(x, "cached", False)]
    copied_constants = clone(cached_constants, share_inputs=False)
    replacements = list(zip(cached_constants, copied_constants))
    inputs_ = list(set(inputs_) - set(cached_constants)) + list(copied_constants)
    graphs = clone(graphs, share_inputs=True, replace=replacements)

    fg = gof.fg.FunctionGraph(inputs_, graphs, clone=False)

    nodes_seen = set()

    @gof.opt.local_optimizer(None)
    def local_transform(node):
        if node in nodes_seen:
            return False
        # importing Scan into module scope would be circular
        from theano.scan_module.scan_op import Scan
        from theano.compile import OpFromGraph
        if isinstance(node.op, (Scan, OpFromGraph)):
            # recurse on the inner graph
            (new_inner_inputs,
             new_outer_inputs,
             new_inner_outputs) = _map_variables_inner(
                 wrapped_replacer,
                 inner_inputs=node.op.inputs,
                 outer_inputs=node.inputs,
                 inner_outputs=node.op.outputs,
                 containing_op=node.op)
            # reinstantiate the op
            if isinstance(node.op, Scan):
                new_op = Scan(new_inner_inputs,
                              new_inner_outputs,
                              node.op.info,
                              # FIXME: infer this someday?
                              typeConstructor=None)
            elif isinstance(node.op, OpFromGraph):
                new_op = OpFromGraph(new_inner_inputs,
                                     new_inner_outputs,
                                     **node.op.kwargs)
            # make a new node to replace the old one
            new_node = new_op.make_node(*new_outer_inputs)
            nodes_seen.add(new_node)
            return new_node.outputs
        else:
            nodes_seen.add(node)
            return list(map(wrapped_replacer, node.outputs))

    topo_transform = gof.opt.TopoOptimizer(local_transform, 'out_to_in')
    topo_transform.optimize(fg)

    new_graphs = fg.outputs
    fg.disown()
    return new_graphs
def _map_variables_inner(replacer, inner_inputs, outer_inputs,
                         inner_outputs, containing_op):
    """Apply `replacer` inside the inner graph of a Scan/OpFromGraph node.

    :param replacer: the (wrapped) replacement function from the caller
    :param inner_inputs: the op's inner-graph input variables
    :param outer_inputs: the node's outer inputs, aligned with
        `inner_inputs`
    :param inner_outputs: the op's inner-graph output variables
    :param containing_op: the op owning the inner graph (used only for
        error messages)
    :return: (new_inner_inputs, new_outer_inputs, new_inner_outputs),
        with any extra inputs required by the replacements appended to
        the first two lists.
    """
    # the replacements returned by the replacer may involve variables
    # that are already owned by the outer fgraph (`fg` in the caller)
    # and so cannot be added to the inner fgraph (`fg` in the
    # recursive call).  wrap the replacer to catch these before they
    # are added.

    # additionally, some of these may be fgraph inputs or shared
    # variables, which we cannot directly use inside the inner graph.
    # we need to create inner inputs to access them through.

    outer_to_inner = dict(zip(outer_inputs, inner_inputs))
    # extra input pairs discovered while replacing; mutated by
    # inner_replacer below and collected at the end.
    extra_inner_inputs = []
    extra_outer_inputs = []

    from theano.scan_module import scan_utils
    from itertools import chain
    from theano import gof

    def inner_replacer(graph):
        new_graph = replacer(graph)

        other_inputs = []
        constants = []
        for input_ in gof.graph.inputs([new_graph]):
            if isinstance(input_, gof.Variable):
                if isinstance(input_, gof.Constant):
                    constants.append(input_)
                else:
                    other_inputs.append(input_)

        # foreign inputs are fgraph inputs and shared variables that we need
        # to access through inner inputs
        foreign_inputs = list(set(other_inputs) - set(outer_to_inner.values()))

        # skip further processing if there is nothing to do
        if not constants and not foreign_inputs:
            return new_graph

        replacements = []

        # constants just need to be replaced by copies that the inner
        # `fg` can take ownership of
        for input_ in constants:
            new_input = input_.clone()
            # NOTE(review): an unnamed constant yields "None_copied"
            # here — presumably harmless; confirm.
            new_input.name = "%s_copied" % new_input.name
            replacements.append((input_, new_input))

        for outer_input in foreign_inputs:
            if getattr(outer_input, "update", False):
                # when theano.scan() constructs a scan node, it detects
                # shared variables with updates and returns these updates
                # to the user.  we need to do the same thing for every new
                # use of such a variable that is introduced.  it's hard to
                # do that at this point.
                # shared variables with updates inside the inner graph of
                # OpFromGraph are not supported at all, so we don't support
                # introducing those either.
                raise NotImplementedError(
                    "Replacement introduces shared variable %s "
                    "which has an update associated with it into "
                    "the inner graph of %s. This is not currently "
                    "supported." % (outer_input, containing_op))
            # if this foreign input is not already available
            # as an inner input, connect it through a new
            # inner input
            if outer_input not in outer_to_inner.keys():
                inner_input = scan_utils.safe_new(outer_input, tag="_copy")
                outer_to_inner[outer_input] = inner_input
                extra_inner_inputs.append(inner_input)
                extra_outer_inputs.append(outer_input)

        # the inner FunctionGraph wants to know its inputs
        # beforehand, but we don't always know.  so add them
        # as we discover them.
        graph.owner.fgraph.add_input(inner_input)

        replacements.extend(outer_to_inner.items())
        new_graph, = theano.clone([new_graph],
                                  share_inputs=True,
                                  replace=replacements)
        return new_graph

    # recurse: the inner graph is itself mapped, with the wrapped
    # replacer above rerouting foreign/constant inputs as it goes.
    new_inner_outputs = map_variables(inner_replacer, inner_outputs)
    new_inner_inputs = list(chain(inner_inputs, extra_inner_inputs))
    new_outer_inputs = list(chain(outer_inputs, extra_outer_inputs))

    return new_inner_inputs, new_outer_inputs, new_inner_outputs
def get_updates_and_outputs(ls): def get_updates_and_outputs(ls):
""" """
This function tries to recognize the updates OrderedDict, the This function tries to recognize the updates OrderedDict, the
......
import itertools
import unittest
import numpy
import theano import theano
from theano.scan_module.scan_utils import equal_computations from theano import tensor
from theano.scan_module.scan_utils import equal_computations, map_variables
from theano.tensor.type_other import NoneConst from theano.tensor.type_other import NoneConst
...@@ -11,3 +15,152 @@ def test_equal_compuations(): ...@@ -11,3 +15,152 @@ def test_equal_compuations():
max_argmax1 = theano.tensor.max_and_argmax(m) max_argmax1 = theano.tensor.max_and_argmax(m)
max_argmax2 = theano.tensor.max_and_argmax(m) max_argmax2 = theano.tensor.max_and_argmax(m)
assert equal_computations(max_argmax1, max_argmax2) assert equal_computations(max_argmax1, max_argmax2)
#################
# map_variables #
#################
class TestMapVariables(unittest.TestCase):
    """Tests for scan_utils.map_variables."""

    @staticmethod
    def replacer(graph):
        # Replacement rule used by all tests: a variable is replaced by
        # whatever was stashed on its tag, or left alone.
        return getattr(graph.tag, "replacement", graph)

    def test_leaf(self):
        # Replacing a graph input (a leaf, not an Apply output).
        a = tensor.scalar("a")
        b = tensor.scalar("b")
        c = tensor.scalar("c")

        b.tag.replacement = c

        u = a + b
        v, = map_variables(self.replacer, [u])

        assert u.owner.inputs == [a, b]
        assert v.owner.inputs == [a, c]

    def test_leaf_inside_scan(self):
        # Replacing a non-sequence input used inside a scan's inner graph.
        x = tensor.vector('x')
        y = tensor.scalar('y')
        z = tensor.scalar('z')

        y.tag.replacement = z

        s, _ = theano.scan(lambda x: x * y, sequences=x)
        s2, = map_variables(self.replacer, [s])

        f = theano.function([x, y, z], [s, s2])
        rval = f(x=numpy.array([1, 2, 3], dtype=numpy.float32), y=1, z=2)
        assert numpy.array_equal(rval, [[1, 2, 3], [2, 4, 6]])

    def test_scan(self):
        x = tensor.vector('x')

        # we will insert a subgraph involving these variables into the inner
        # graph of scan. since they were not previously in the inner graph,
        # they are like non_sequences to scan(). scan() infers these and
        # imports them into the inner graph properly, and map_variables()
        # should do this as well.
        outer = tensor.scalar("outer")
        shared = theano.shared(
            numpy.array(1., dtype=theano.config.floatX),
            name="shared")
        constant = tensor.constant(1, name="constant")

        # z will equal 1 so multiplying by it doesn't change any values
        z = outer * (shared + constant)

        def step(x, a):
            r = a + x
            r.tag.replacement = z * (a - x)
            return r

        s, _ = theano.scan(step, sequences=x,
                           outputs_info=[numpy.array(0.)])
        # ensure z is owned by the outer graph so map_variables() will need to
        # jump through additional hoops to placate FunctionGraph.
        t = z * s
        s2, = map_variables(self.replacer, [t])
        t2 = z * s2

        f = theano.function([x, outer], [t, t2])
        rval = f(x=numpy.array([1, 2, 3], dtype=numpy.float32), outer=0.5)
        assert numpy.array_equal(rval, [[1, 3, 6], [-1, -3, -6]])

    def test_scan_with_shared_update(self):
        x = tensor.vector('x')

        # counts how many times its value is used
        counter = theano.shared(0, name="shared")
        counter.update = counter + 1

        def step(x, a):
            r = a + x
            # introducing a shared variable with an update into the
            # inner graph is unsupported and the code must crash rather
            # than silently produce the wrong result.
            r.tag.replacement = counter * (a - x)
            return r

        s, _ = theano.scan(step, sequences=x,
                           outputs_info=[numpy.array(0.)])
        self.assertRaises(NotImplementedError,
                          map_variables, self.replacer, [s])

    def test_scan_with_shared_update2(self):
        x = tensor.vector('x')

        # counts how many times its value is used
        counter = theano.shared(0, name="shared")
        counter.update = counter + 1

        def step(x, a):
            r = a + x
            # introducing a shared variable with an update into the
            # inner graph is unsupported and the code must crash rather
            # than silently produce the wrong result.
            r.tag.replacement = counter * (a - x)
            # the shared variable was already present, but the
            # replacement changes the number of times it is used,
            # which would have to change the updates, which is
            # unsupported.
            return r + counter

        s, _ = theano.scan(step, sequences=x,
                           outputs_info=[numpy.array(0.)])
        self.assertRaises(NotImplementedError,
                          map_variables, self.replacer, [s])

    def test_opfromgraph(self):
        # as with the scan tests above, insert foreign inputs into the
        # inner graph.
        outer = tensor.scalar("outer")
        shared = theano.shared(
            numpy.array(1., dtype=theano.config.floatX),
            name="shared")
        constant = tensor.constant(1., name="constant")
        z = outer * (shared + constant)

        # construct the inner graph
        a = tensor.scalar()
        b = tensor.scalar()
        r = a + b
        r.tag.replacement = z * (a - b)

        # construct the outer graph
        c = tensor.scalar()
        d = tensor.scalar()
        u = theano.OpFromGraph([a, b], [r])(c, d)
        t = z * u
        v, = map_variables(self.replacer, [t])
        t2 = z * v

        f = theano.function([c, d, outer], [t, t2])
        for m, n in itertools.combinations(range(10), 2):
            assert f(m, n, outer=0.5) == [m + n, m - n]

        # test that the unsupported case of replacement with a shared
        # variable with updates crashes
        shared.update = shared + 1
        self.assertRaises(NotImplementedError,
                          map_variables, self.replacer, [t])
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment