提交 c5502dab authored 作者: Tim Cooijmans's avatar Tim Cooijmans

map_variables: simplify

上级 fb6adab7
...@@ -816,29 +816,12 @@ def clone_get_equiv(inputs, outputs, copy_inputs_and_orphans=True, memo=None): ...@@ -816,29 +816,12 @@ def clone_get_equiv(inputs, outputs, copy_inputs_and_orphans=True, memo=None):
return memo return memo
def local_replacer(fn):
# FunctionGraph is strict and wants to know its inputs beforehand. we
# don't always know the inputs beforehand, so pass the `fg` into `fn`
# if it accepts it. it can then use fg.add_input() to add any missing
# inputs. this is an internal mechanism currently only used by the
# wrapped `fn` in _map_variables_inner.
def new_fn(graph, fg=None):
import inspect
argspec = inspect.getargspec(fn)
if fg and len(argspec.args) == 2:
new_graph = fn(graph, fg)
else:
new_graph = fn(graph)
return new_graph
return new_fn
def map_variables(replacer, graphs, additional_inputs=[]): def map_variables(replacer, graphs, additional_inputs=[]):
"""Construct new graphs based on 'graphs' with some variables replaced """Construct new graphs based on 'graphs' with some variables replaced
according to 'replacer'. according to 'replacer'.
:param replacer: `local_replacer` decorated function that takes a :param replacer: function that takes a variable and returns its
variable and returns its replacement. replacement.
:param graphs: an iterable of graphs in which to replace variables :param graphs: an iterable of graphs in which to replace variables
:param additional_inputs: an iterable of graph inputs not used in any :param additional_inputs: an iterable of graph inputs not used in any
of 'graphs' but possibly used in the graphs returned by `replacer` of 'graphs' but possibly used in the graphs returned by `replacer`
...@@ -857,12 +840,10 @@ def map_variables(replacer, graphs, additional_inputs=[]): ...@@ -857,12 +840,10 @@ def map_variables(replacer, graphs, additional_inputs=[]):
ab = a + b ab = a + b
ab.tag.replacement = a * b ab.tag.replacement = a * b
@local_replacer
def replacer(graph):
return getattr(graph.tag, "replacement", graph)
u = ab + c u = ab + c
v, = map_variables(replacer, [u]) v, = map_variables(lambda graph:
return getattr(graph.tag, "replacement", graph),
[u])
# v is now equal to a * b + c # v is now equal to a * b + c
""" """
...@@ -874,11 +855,11 @@ def map_variables(replacer, graphs, additional_inputs=[]): ...@@ -874,11 +855,11 @@ def map_variables(replacer, graphs, additional_inputs=[]):
# wrap replacer to avoid replacing things we just put there. # wrap replacer to avoid replacing things we just put there.
graphs_seen = set() graphs_seen = set()
def wrapped_replacer(graph, fg=None): def wrapped_replacer(graph):
if graph in graphs_seen: if graph in graphs_seen:
return graph return graph
else: else:
new_graph = replacer(graph, fg) new_graph = replacer(graph)
graphs_seen.add(new_graph) graphs_seen.add(new_graph)
return new_graph return new_graph
...@@ -888,9 +869,6 @@ def map_variables(replacer, graphs, additional_inputs=[]): ...@@ -888,9 +869,6 @@ def map_variables(replacer, graphs, additional_inputs=[]):
# perform any desired replacement of input variables. these # perform any desired replacement of input variables. these
# aren't replaced by the local optimizer approach because they are # aren't replaced by the local optimizer approach because they are
# not outputs of any Apply node. # not outputs of any Apply node.
# NOTE: we don't need to pass any fgraph into the replacer; we can
# figure out the correct set of inputs from the graph before we
# construct the fgraph.
new_inputs = list(map(wrapped_replacer, inputs_)) new_inputs = list(map(wrapped_replacer, inputs_))
replacements = [(input_, new_input) replacements = [(input_, new_input)
for input_, new_input for input_, new_input
...@@ -944,7 +922,7 @@ def map_variables(replacer, graphs, additional_inputs=[]): ...@@ -944,7 +922,7 @@ def map_variables(replacer, graphs, additional_inputs=[]):
return new_node.outputs return new_node.outputs
else: else:
nodes_seen.add(node) nodes_seen.add(node)
return [wrapped_replacer(graph, fg) for graph in node.outputs] return list(map(wrapped_replacer, node.outputs))
topo_transform = TopoOptimizer(local_transform, 'out_to_in') topo_transform = TopoOptimizer(local_transform, 'out_to_in')
topo_transform.optimize(fg) topo_transform.optimize(fg)
...@@ -974,8 +952,8 @@ def _map_variables_inner(replacer, inner_inputs, outer_inputs, inner_outputs): ...@@ -974,8 +952,8 @@ def _map_variables_inner(replacer, inner_inputs, outer_inputs, inner_outputs):
from itertools import chain from itertools import chain
from theano import gof from theano import gof
def inner_replacer(graph, inner_fg): def inner_replacer(graph):
new_graph = replacer(graph, inner_fg) new_graph = replacer(graph)
other_inputs = [] other_inputs = []
constants = [] constants = []
...@@ -1012,7 +990,10 @@ def _map_variables_inner(replacer, inner_inputs, outer_inputs, inner_outputs): ...@@ -1012,7 +990,10 @@ def _map_variables_inner(replacer, inner_inputs, outer_inputs, inner_outputs):
outer_to_inner[outer_input] = inner_input outer_to_inner[outer_input] = inner_input
extra_inner_inputs.append(inner_input) extra_inner_inputs.append(inner_input)
extra_outer_inputs.append(outer_input) extra_outer_inputs.append(outer_input)
inner_fg.add_input(inner_input) # the inner FunctionGraph wants to know its inputs
# beforehand, but we don't always know. so add them
# as we discover them.
graph.owner.fgraph.add_input(inner_input)
replacements.extend(outer_to_inner.items()) replacements.extend(outer_to_inner.items())
......
...@@ -11,7 +11,7 @@ from theano import ( ...@@ -11,7 +11,7 @@ from theano import (
from theano.gof.graph import ( from theano.gof.graph import (
Apply, Apply,
as_string, clone, general_toposort, inputs, io_toposort, as_string, clone, general_toposort, inputs, io_toposort,
is_same_graph, Variable, map_variables, local_replacer) is_same_graph, Variable, map_variables)
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.type import Type from theano.gof.type import Type
from theano.sandbox.cuda.var import ( from theano.sandbox.cuda.var import (
...@@ -163,6 +163,10 @@ class TestClone(X): ...@@ -163,6 +163,10 @@ class TestClone(X):
################# #################
class TestMapVariables(X): class TestMapVariables(X):
@staticmethod
def replacer(graph):
return getattr(graph.tag, "replacement", graph)
def test_leaf(self): def test_leaf(self):
a = tensor.scalar("a") a = tensor.scalar("a")
b = tensor.scalar("b") b = tensor.scalar("b")
...@@ -170,12 +174,8 @@ class TestMapVariables(X): ...@@ -170,12 +174,8 @@ class TestMapVariables(X):
b.tag.replacement = c b.tag.replacement = c
@local_replacer
def replacer(graph):
return getattr(graph.tag, "replacement", graph)
u = a + b u = a + b
v, = map_variables(replacer, [u]) v, = map_variables(self.replacer, [u])
assert u.owner.inputs == [a, b] assert u.owner.inputs == [a, b]
assert v.owner.inputs == [a, c] assert v.owner.inputs == [a, c]
...@@ -192,12 +192,8 @@ class TestMapVariables(X): ...@@ -192,12 +192,8 @@ class TestMapVariables(X):
c = tensor.scalar() c = tensor.scalar()
d = tensor.scalar() d = tensor.scalar()
@local_replacer
def replacer(graph):
return getattr(graph.tag, "replacement", graph)
u = OpFromGraph([a, b], [r])(c, d) u = OpFromGraph([a, b], [r])(c, d)
v, = map_variables(replacer, [u]) v, = map_variables(self.replacer, [u])
f = function([c, d], [u, v]) f = function([c, d], [u, v])
for m, n in itertools.combinations(range(10), 2): for m, n in itertools.combinations(range(10), 2):
...@@ -225,16 +221,12 @@ class TestMapVariables(X): ...@@ -225,16 +221,12 @@ class TestMapVariables(X):
r.tag.replacement = z * (a - x) r.tag.replacement = z * (a - x)
return r return r
@local_replacer
def replacer(graph):
return getattr(graph.tag, "replacement", graph)
s, _ = scan(step, sequences=x, s, _ = scan(step, sequences=x,
outputs_info=[numpy.array(0.)]) outputs_info=[numpy.array(0.)])
# ensure z is owned by the outer graph so map_variables() will need to # ensure z is owned by the outer graph so map_variables() will need to
# jump through additional hoops to placate FunctionGraph. # jump through additional hoops to placate FunctionGraph.
t = z * s t = z * s
s2, = map_variables(replacer, [t]) s2, = map_variables(self.replacer, [t])
t2 = z * s2 t2 = z * s2
f = function([x, outer], [t, t2]) f = function([x, outer], [t, t2])
...@@ -251,12 +243,8 @@ class TestMapVariables(X): ...@@ -251,12 +243,8 @@ class TestMapVariables(X):
y.tag.replacement = z y.tag.replacement = z
@local_replacer
def replacer(graph):
return getattr(graph.tag, "replacement", graph)
s, _ = scan(lambda x: x * y, sequences=x) s, _ = scan(lambda x: x * y, sequences=x)
s2, = map_variables(replacer, [s]) s2, = map_variables(self.replacer, [s])
f = function([x, y, z], [s, s2]) f = function([x, y, z], [s, s2])
assert numpy.array_equal( assert numpy.array_equal(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论