提交 c5502dab authored 作者: Tim Cooijmans's avatar Tim Cooijmans

map_variables: simplify

上级 fb6adab7
......@@ -816,29 +816,12 @@ def clone_get_equiv(inputs, outputs, copy_inputs_and_orphans=True, memo=None):
return memo
def local_replacer(fn):
    """Adapt a replacer callable so it can optionally receive a FunctionGraph.

    FunctionGraph is strict and wants to know its inputs beforehand.  We
    don't always know the inputs beforehand, so pass the `fg` into `fn`
    if it accepts a second argument; `fn` can then use ``fg.add_input()``
    to add any missing inputs.  This is an internal mechanism currently
    only used by the wrapped `fn` in ``_map_variables_inner``.

    :param fn: callable taking either ``(graph)`` or ``(graph, fg)`` and
        returning the replacement for ``graph``.
    :return: callable with signature ``(graph, fg=None)`` that forwards
        ``fg`` to `fn` only when `fn` accepts it.
    """
    import functools
    import inspect

    # Introspect once at decoration time rather than on every call.
    # NOTE: inspect.getargspec() was deprecated and removed in Python 3.11;
    # inspect.signature() covers the same cases (plus keyword-only args).
    try:
        accepts_fg = len(inspect.signature(fn).parameters) == 2
    except (TypeError, ValueError):
        # Some callables (e.g. certain builtins) have no introspectable
        # signature; treat them as unary replacers.
        accepts_fg = False

    @functools.wraps(fn)
    def new_fn(graph, fg=None):
        # Only forward `fg` when one was supplied AND `fn` can take it.
        if fg and accepts_fg:
            return fn(graph, fg)
        return fn(graph)

    return new_fn
def map_variables(replacer, graphs, additional_inputs=[]):
"""Construct new graphs based on 'graphs' with some variables replaced
according to 'replacer'.
:param replacer: `local_replacer` decorated function that takes a
variable and returns its replacement.
:param replacer: function that takes a variable and returns its
replacement.
:param graphs: an iterable of graphs in which to replace variables
:param additional_inputs: an iterable of graph inputs not used in any
of 'graphs' but possibly used in the graphs returned by `replacer`
......@@ -857,12 +840,10 @@ def map_variables(replacer, graphs, additional_inputs=[]):
ab = a + b
ab.tag.replacement = a * b
@local_replacer
def replacer(graph):
return getattr(graph.tag, "replacement", graph)
u = ab + c
v, = map_variables(replacer, [u])
v, = map_variables(
    lambda graph: getattr(graph.tag, "replacement", graph),
    [u])
# v is now equal to a * b + c
"""
......@@ -874,11 +855,11 @@ def map_variables(replacer, graphs, additional_inputs=[]):
# wrap replacer to avoid replacing things we just put there.
graphs_seen = set()
def wrapped_replacer(graph, fg=None):
def wrapped_replacer(graph):
if graph in graphs_seen:
return graph
else:
new_graph = replacer(graph, fg)
new_graph = replacer(graph)
graphs_seen.add(new_graph)
return new_graph
......@@ -888,9 +869,6 @@ def map_variables(replacer, graphs, additional_inputs=[]):
# perform any desired replacement of input variables. these
# aren't replaced by the local optimizer approach because they are
# not outputs of any Apply node.
# NOTE: we don't need to pass any fgraph into the replacer; we can
# figure out the correct set of inputs from the graph before we
# construct the fgraph.
new_inputs = list(map(wrapped_replacer, inputs_))
replacements = [(input_, new_input)
for input_, new_input
......@@ -944,7 +922,7 @@ def map_variables(replacer, graphs, additional_inputs=[]):
return new_node.outputs
else:
nodes_seen.add(node)
return [wrapped_replacer(graph, fg) for graph in node.outputs]
return list(map(wrapped_replacer, node.outputs))
topo_transform = TopoOptimizer(local_transform, 'out_to_in')
topo_transform.optimize(fg)
......@@ -974,8 +952,8 @@ def _map_variables_inner(replacer, inner_inputs, outer_inputs, inner_outputs):
from itertools import chain
from theano import gof
def inner_replacer(graph, inner_fg):
new_graph = replacer(graph, inner_fg)
def inner_replacer(graph):
new_graph = replacer(graph)
other_inputs = []
constants = []
......@@ -1012,7 +990,10 @@ def _map_variables_inner(replacer, inner_inputs, outer_inputs, inner_outputs):
outer_to_inner[outer_input] = inner_input
extra_inner_inputs.append(inner_input)
extra_outer_inputs.append(outer_input)
inner_fg.add_input(inner_input)
# the inner FunctionGraph wants to know its inputs
# beforehand, but we don't always know. so add them
# as we discover them.
graph.owner.fgraph.add_input(inner_input)
replacements.extend(outer_to_inner.items())
......
......@@ -11,7 +11,7 @@ from theano import (
from theano.gof.graph import (
Apply,
as_string, clone, general_toposort, inputs, io_toposort,
is_same_graph, Variable, map_variables, local_replacer)
is_same_graph, Variable, map_variables)
from theano.gof.op import Op
from theano.gof.type import Type
from theano.sandbox.cuda.var import (
......@@ -163,6 +163,10 @@ class TestClone(X):
#################
class TestMapVariables(X):
@staticmethod
def replacer(graph):
return getattr(graph.tag, "replacement", graph)
def test_leaf(self):
a = tensor.scalar("a")
b = tensor.scalar("b")
......@@ -170,12 +174,8 @@ class TestMapVariables(X):
b.tag.replacement = c
@local_replacer
def replacer(graph):
return getattr(graph.tag, "replacement", graph)
u = a + b
v, = map_variables(replacer, [u])
v, = map_variables(self.replacer, [u])
assert u.owner.inputs == [a, b]
assert v.owner.inputs == [a, c]
......@@ -192,12 +192,8 @@ class TestMapVariables(X):
c = tensor.scalar()
d = tensor.scalar()
@local_replacer
def replacer(graph):
return getattr(graph.tag, "replacement", graph)
u = OpFromGraph([a, b], [r])(c, d)
v, = map_variables(replacer, [u])
v, = map_variables(self.replacer, [u])
f = function([c, d], [u, v])
for m, n in itertools.combinations(range(10), 2):
......@@ -225,16 +221,12 @@ class TestMapVariables(X):
r.tag.replacement = z * (a - x)
return r
@local_replacer
def replacer(graph):
return getattr(graph.tag, "replacement", graph)
s, _ = scan(step, sequences=x,
outputs_info=[numpy.array(0.)])
# ensure z is owned by the outer graph so map_variables() will need to
# jump through additional hoops to placate FunctionGraph.
t = z * s
s2, = map_variables(replacer, [t])
s2, = map_variables(self.replacer, [t])
t2 = z * s2
f = function([x, outer], [t, t2])
......@@ -251,12 +243,8 @@ class TestMapVariables(X):
y.tag.replacement = z
@local_replacer
def replacer(graph):
return getattr(graph.tag, "replacement", graph)
s, _ = scan(lambda x: x * y, sequences=x)
s2, = map_variables(replacer, [s])
s2, = map_variables(self.replacer, [s])
f = function([x, y, z], [s, s2])
assert numpy.array_equal(
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论