提交 321c0108 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove unused aesara.scan.utils.map_variables

上级 97f9ad48
......@@ -20,9 +20,7 @@ from aesara.graph.basic import (
equal_computations,
graph_inputs,
)
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value
from aesara.graph.opt import TopoOptimizer, local_optimizer
from aesara.graph.utils import TestValueError
from aesara.tensor.basic import AllocEmpty, get_scalar_constant_value
from aesara.tensor.subtensor import set_subtensor
......@@ -229,217 +227,6 @@ def traverse(out, x, x_copy, d, visited=None):
return d
def map_variables(replacer, graphs, additional_inputs=None):
"""Construct new graphs based on 'graphs' with some variables replaced
according to 'replacer'.
:param replacer: function that takes a variable and returns its
replacement.
:param graphs: an iterable of graphs in which to replace variables
:param additional_inputs: an iterable of graph inputs not used in any
of 'graphs' but possibly used in the graphs returned by `replacer`
:return: the new graphs, in the same order as 'graphs'
Example:
.. code-block:: python
tag = "replaceme"
a = aesara.tensor.type.scalar("a")
b = aesara.tensor.type.scalar("b")
c = aesara.tensor.type.scalar("c")
ab = a + b
ab.tag.replacement = a * b
u = ab + c
v, = map_variables(lambda graph:
return getattr(graph.tag, "replacement", graph),
[u])
# v is now equal to a * b + c
"""
if additional_inputs is None:
additional_inputs = []
# wrap replacer to avoid replacing things we just put there.
graphs_seen = set()
def wrapped_replacer(graph):
if graph in graphs_seen:
return graph
else:
new_graph = replacer(graph)
graphs_seen.add(new_graph)
return new_graph
graphs = list(graphs)
inputs_ = list(set(list(graph_inputs(graphs)) + list(additional_inputs)))
# perform any desired replacement of input variables. these
# aren't replaced by the local optimizer approach because they are
# not outputs of any Apply node.
new_inputs = [wrapped_replacer(i) for i in inputs_]
replacements = [
(input_, new_input)
for input_, new_input in zip(inputs_, new_inputs)
if new_input is not input_
]
graphs = clone_replace(graphs, share_inputs=True, replace=replacements)
inputs_ = list(set(list(graph_inputs(graphs)) + list(additional_inputs)))
fg = FunctionGraph(inputs_, graphs, clone=False)
nodes_seen = set()
@local_optimizer(None)
def local_transform(fgraph, node):
if node in nodes_seen:
return False
# importing Scan into module scope would be circular
from aesara.compile.builders import OpFromGraph
from aesara.scan.op import Scan
if isinstance(node.op, (Scan, OpFromGraph)):
# recurse on the inner graph
(
new_inner_inputs,
new_outer_inputs,
new_inner_outputs,
) = _map_variables_inner(
wrapped_replacer,
inner_inputs=node.op.inputs,
outer_inputs=node.inputs,
inner_outputs=node.op.outputs,
containing_op=node.op,
)
# reinstantiate the op
if isinstance(node.op, Scan):
new_op = Scan(
new_inner_inputs,
new_inner_outputs,
node.op.info,
node.op.mode,
# FIXME: infer this someday?
typeConstructor=None,
)
elif isinstance(node.op, OpFromGraph):
new_op = OpFromGraph(
new_inner_inputs, new_inner_outputs, **node.op.kwargs
)
# make a new node to replace the old one
new_node = new_op.make_node(*new_outer_inputs)
nodes_seen.add(new_node)
return new_node.outputs
else:
nodes_seen.add(node)
replacements = [wrapped_replacer(o) for o in node.outputs]
# Add inputs to replacement graphs as inputs to this `fgraph`
for i in graph_inputs(replacements):
fgraph.add_input(i)
return replacements
topo_transform = TopoOptimizer(local_transform, "out_to_in")
topo_transform.optimize(fg)
new_graphs = fg.outputs
fg.disown()
return new_graphs
def _map_variables_inner(
replacer, inner_inputs, outer_inputs, inner_outputs, containing_op
):
# the replacements returned by the replacer may involve variables
# that are already owned by the outer fgraph (`fg` in the caller)
# and so cannot be added to the inner fgraph (`fg` in the
# recursive call). wrap the replacer to catch these before they
# are added.
# additionally, some of these may be fgraph inputs or shared
# variables, which we cannot directly use inside the inner graph.
# we need to create inner inputs to access them through.
outer_to_inner = dict(zip(outer_inputs, inner_inputs))
extra_inner_inputs = []
extra_outer_inputs = []
from itertools import chain
from aesara.scan import utils
def inner_replacer(graph):
new_graph = replacer(graph)
other_inputs = []
constants = []
for input_ in graph_inputs([new_graph]):
if isinstance(input_, Variable):
if isinstance(input_, Constant):
constants.append(input_)
else:
other_inputs.append(input_)
# foreign inputs are fgraph inputs and shared variables that we need
# to access through inner inputs
foreign_inputs = list(set(other_inputs) - set(outer_to_inner.values()))
# skip further processing if there is nothing to do
if not constants and not foreign_inputs:
return new_graph
replacements = []
# constants just need to be replaced by copies that the inner
# `fg` can take ownership of
for input_ in constants:
new_input = input_.clone()
new_input.name = f"{new_input.name}_copied"
replacements.append((input_, new_input))
for outer_input in foreign_inputs:
if getattr(outer_input, "update", False):
# when aesara.scan() constructs a scan node, it detects
# shared variables with updates and returns these updates
# to the user. we need to do the same thing for every new
# use of such a variable that is introduced. it's hard to
# do that at this point.
# shared variables with updates inside the inner graph of
# OpFromGraph are not supported at all, so we don't support
# introducing those either.
raise NotImplementedError(
f"Replacement introduces shared variable {outer_input} "
"which has an update associated with it into "
f"the inner graph of {containing_op}. This is not currently "
"supported."
)
# if this foreign input is not already available
# as an inner input, connect it through a new
# inner input
if outer_input not in outer_to_inner.keys():
inner_input = utils.safe_new(outer_input, tag="_copy")
outer_to_inner[outer_input] = inner_input
extra_inner_inputs.append(inner_input)
extra_outer_inputs.append(outer_input)
replacements.extend(outer_to_inner.items())
(new_graph,) = clone_replace(
[new_graph], share_inputs=True, replace=replacements
)
return new_graph
new_inner_outputs = map_variables(inner_replacer, inner_outputs)
new_inner_inputs = list(chain(inner_inputs, extra_inner_inputs))
new_outer_inputs = list(chain(outer_inputs, extra_outer_inputs))
return new_inner_inputs, new_outer_inputs, new_inner_outputs
def get_updates_and_outputs(ls):
"""
This function tries to recognize the updates OrderedDict, the
......
import itertools
from copy import copy
import numpy as np
......@@ -6,8 +5,7 @@ import pytest
import aesara
from aesara import tensor as at
from aesara.scan.utils import ScanArgs, map_variables
from aesara.tensor.type import scalar, vector
from aesara.scan.utils import ScanArgs
@pytest.fixture(scope="module", autouse=True)
......@@ -16,144 +14,6 @@ def set_aesara_flags():
yield
class TestMapVariables:
@staticmethod
def replacer(graph):
return getattr(graph.tag, "replacement", graph)
def test_leaf(self):
a = scalar("a")
b = scalar("b")
c = scalar("c")
b.tag.replacement = c
u = a + b
(v,) = map_variables(self.replacer, [u])
assert u.owner.inputs == [a, b]
assert v.owner.inputs == [a, c]
def test_leaf_inside_scan(self):
x = vector("x")
y = scalar("y")
z = scalar("z")
y.tag.replacement = z
s, _ = aesara.scan(lambda x: x * y, sequences=x)
(s2,) = map_variables(self.replacer, [s])
f = aesara.function([x, y, z], [s, s2])
rval = f(x=np.array([1, 2, 3], dtype=np.float32), y=1, z=2)
assert np.array_equal(rval, [[1, 2, 3], [2, 4, 6]])
def test_scan(self):
x = vector("x")
# we will insert a subgraph involving these variables into the inner
# graph of scan. since they were not previously in the inner graph,
# they are like non_sequences to scan(). scan() infers these and
# imports them into the inner graph properly, and map_variables()
# should do this as well.
outer = scalar("outer")
shared = aesara.shared(np.array(1.0, dtype=aesara.config.floatX), name="shared")
constant = at.constant(1, name="constant")
# z will equal 1 so multiplying by it doesn't change any values
z = outer * (shared + constant)
def step(x, a):
r = a + x
r.tag.replacement = z * (a - x)
return r
s, _ = aesara.scan(step, sequences=x, outputs_info=[np.array(0.0)])
# ensure z is owned by the outer graph so map_variables() will need to
# jump through additional hoops to placate FunctionGraph.
t = z * s
(s2,) = map_variables(self.replacer, [t])
t2 = z * s2
f = aesara.function([x, outer], [t, t2])
rval = f(x=np.array([1, 2, 3], dtype=np.float32), outer=0.5)
assert np.array_equal(rval, [[1, 3, 6], [-1, -3, -6]])
def test_scan_with_shared_update(self):
x = vector("x")
# counts how many times its value is used
counter = aesara.shared(0, name="shared")
counter.update = counter + 1
def step(x, a):
r = a + x
# introducing a shared variable with an update into the
# inner graph is unsupported and the code must crash rather
# than silently produce the wrong result.
r.tag.replacement = counter * (a - x)
return r
s, _ = aesara.scan(step, sequences=x, outputs_info=[np.array(0.0)])
with pytest.raises(NotImplementedError):
map_variables(self.replacer, [s])
def test_scan_with_shared_update2(self):
x = vector("x")
# counts how many times its value is used
counter = aesara.shared(0, name="shared")
counter.update = counter + 1
def step(x, a):
r = a + x
# introducing a shared variable with an update into the
# inner graph is unsupported and the code must crash rather
# than silently produce the wrong result.
r.tag.replacement = counter * (a - x)
# the shared variable was already present, but the
# replacement changes the number of times it is used,
# which would have to change the updates, which is
# unsupported.
return r + counter
s, _ = aesara.scan(step, sequences=x, outputs_info=[np.array(0.0)])
with pytest.raises(NotImplementedError):
map_variables(self.replacer, [s])
def test_opfromgraph(self):
# as with the scan tests above, insert foreign inputs into the
# inner graph.
outer = scalar("outer")
shared = aesara.shared(np.array(1.0, dtype=aesara.config.floatX), name="shared")
constant = at.constant(1.0, name="constant")
z = outer * (shared + constant)
# construct the inner graph
a = scalar()
b = scalar()
r = a + b
r.tag.replacement = z * (a - b)
# construct the outer graph
c = scalar()
d = scalar()
u = aesara.compile.builders.OpFromGraph([a, b], [r])(c, d)
t = z * u
(v,) = map_variables(self.replacer, [t])
t2 = z * v
f = aesara.function([c, d, outer], [t, t2])
for m, n in itertools.combinations(range(10), 2):
assert f(m, n, outer=0.5) == [m + n, m - n]
# test that the unsupported case of replacement with a shared
# variable with updates crashes
shared.update = shared + 1
with pytest.raises(NotImplementedError):
map_variables(self.replacer, [t])
def create_test_hmm():
rng_state = np.random.default_rng(23422)
rng_tt = aesara.shared(rng_state, name="rng", borrow=True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论