提交 321c0108 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove unused aesara.scan.utils.map_variables

上级 97f9ad48
...@@ -20,9 +20,7 @@ from aesara.graph.basic import ( ...@@ -20,9 +20,7 @@ from aesara.graph.basic import (
equal_computations, equal_computations,
graph_inputs, graph_inputs,
) )
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value from aesara.graph.op import get_test_value
from aesara.graph.opt import TopoOptimizer, local_optimizer
from aesara.graph.utils import TestValueError from aesara.graph.utils import TestValueError
from aesara.tensor.basic import AllocEmpty, get_scalar_constant_value from aesara.tensor.basic import AllocEmpty, get_scalar_constant_value
from aesara.tensor.subtensor import set_subtensor from aesara.tensor.subtensor import set_subtensor
...@@ -229,217 +227,6 @@ def traverse(out, x, x_copy, d, visited=None): ...@@ -229,217 +227,6 @@ def traverse(out, x, x_copy, d, visited=None):
return d return d
def map_variables(replacer, graphs, additional_inputs=None):
"""Construct new graphs based on 'graphs' with some variables replaced
according to 'replacer'.
:param replacer: function that takes a variable and returns its
replacement.
:param graphs: an iterable of graphs in which to replace variables
:param additional_inputs: an iterable of graph inputs not used in any
of 'graphs' but possibly used in the graphs returned by `replacer`
:return: the new graphs, in the same order as 'graphs'
Example:
.. code-block:: python
tag = "replaceme"
a = aesara.tensor.type.scalar("a")
b = aesara.tensor.type.scalar("b")
c = aesara.tensor.type.scalar("c")
ab = a + b
ab.tag.replacement = a * b
u = ab + c
v, = map_variables(lambda graph:
return getattr(graph.tag, "replacement", graph),
[u])
# v is now equal to a * b + c
"""
if additional_inputs is None:
additional_inputs = []
# wrap replacer to avoid replacing things we just put there.
graphs_seen = set()
def wrapped_replacer(graph):
if graph in graphs_seen:
return graph
else:
new_graph = replacer(graph)
graphs_seen.add(new_graph)
return new_graph
graphs = list(graphs)
inputs_ = list(set(list(graph_inputs(graphs)) + list(additional_inputs)))
# perform any desired replacement of input variables. these
# aren't replaced by the local optimizer approach because they are
# not outputs of any Apply node.
new_inputs = [wrapped_replacer(i) for i in inputs_]
replacements = [
(input_, new_input)
for input_, new_input in zip(inputs_, new_inputs)
if new_input is not input_
]
graphs = clone_replace(graphs, share_inputs=True, replace=replacements)
inputs_ = list(set(list(graph_inputs(graphs)) + list(additional_inputs)))
fg = FunctionGraph(inputs_, graphs, clone=False)
nodes_seen = set()
@local_optimizer(None)
def local_transform(fgraph, node):
if node in nodes_seen:
return False
# importing Scan into module scope would be circular
from aesara.compile.builders import OpFromGraph
from aesara.scan.op import Scan
if isinstance(node.op, (Scan, OpFromGraph)):
# recurse on the inner graph
(
new_inner_inputs,
new_outer_inputs,
new_inner_outputs,
) = _map_variables_inner(
wrapped_replacer,
inner_inputs=node.op.inputs,
outer_inputs=node.inputs,
inner_outputs=node.op.outputs,
containing_op=node.op,
)
# reinstantiate the op
if isinstance(node.op, Scan):
new_op = Scan(
new_inner_inputs,
new_inner_outputs,
node.op.info,
node.op.mode,
# FIXME: infer this someday?
typeConstructor=None,
)
elif isinstance(node.op, OpFromGraph):
new_op = OpFromGraph(
new_inner_inputs, new_inner_outputs, **node.op.kwargs
)
# make a new node to replace the old one
new_node = new_op.make_node(*new_outer_inputs)
nodes_seen.add(new_node)
return new_node.outputs
else:
nodes_seen.add(node)
replacements = [wrapped_replacer(o) for o in node.outputs]
# Add inputs to replacement graphs as inputs to this `fgraph`
for i in graph_inputs(replacements):
fgraph.add_input(i)
return replacements
topo_transform = TopoOptimizer(local_transform, "out_to_in")
topo_transform.optimize(fg)
new_graphs = fg.outputs
fg.disown()
return new_graphs
def _map_variables_inner(
replacer, inner_inputs, outer_inputs, inner_outputs, containing_op
):
# the replacements returned by the replacer may involve variables
# that are already owned by the outer fgraph (`fg` in the caller)
# and so cannot be added to the inner fgraph (`fg` in the
# recursive call). wrap the replacer to catch these before they
# are added.
# additionally, some of these may be fgraph inputs or shared
# variables, which we cannot directly use inside the inner graph.
# we need to create inner inputs to access them through.
outer_to_inner = dict(zip(outer_inputs, inner_inputs))
extra_inner_inputs = []
extra_outer_inputs = []
from itertools import chain
from aesara.scan import utils
def inner_replacer(graph):
new_graph = replacer(graph)
other_inputs = []
constants = []
for input_ in graph_inputs([new_graph]):
if isinstance(input_, Variable):
if isinstance(input_, Constant):
constants.append(input_)
else:
other_inputs.append(input_)
# foreign inputs are fgraph inputs and shared variables that we need
# to access through inner inputs
foreign_inputs = list(set(other_inputs) - set(outer_to_inner.values()))
# skip further processing if there is nothing to do
if not constants and not foreign_inputs:
return new_graph
replacements = []
# constants just need to be replaced by copies that the inner
# `fg` can take ownership of
for input_ in constants:
new_input = input_.clone()
new_input.name = f"{new_input.name}_copied"
replacements.append((input_, new_input))
for outer_input in foreign_inputs:
if getattr(outer_input, "update", False):
# when aesara.scan() constructs a scan node, it detects
# shared variables with updates and returns these updates
# to the user. we need to do the same thing for every new
# use of such a variable that is introduced. it's hard to
# do that at this point.
# shared variables with updates inside the inner graph of
# OpFromGraph are not supported at all, so we don't support
# introducing those either.
raise NotImplementedError(
f"Replacement introduces shared variable {outer_input} "
"which has an update associated with it into "
f"the inner graph of {containing_op}. This is not currently "
"supported."
)
# if this foreign input is not already available
# as an inner input, connect it through a new
# inner input
if outer_input not in outer_to_inner.keys():
inner_input = utils.safe_new(outer_input, tag="_copy")
outer_to_inner[outer_input] = inner_input
extra_inner_inputs.append(inner_input)
extra_outer_inputs.append(outer_input)
replacements.extend(outer_to_inner.items())
(new_graph,) = clone_replace(
[new_graph], share_inputs=True, replace=replacements
)
return new_graph
new_inner_outputs = map_variables(inner_replacer, inner_outputs)
new_inner_inputs = list(chain(inner_inputs, extra_inner_inputs))
new_outer_inputs = list(chain(outer_inputs, extra_outer_inputs))
return new_inner_inputs, new_outer_inputs, new_inner_outputs
def get_updates_and_outputs(ls): def get_updates_and_outputs(ls):
""" """
This function tries to recognize the updates OrderedDict, the This function tries to recognize the updates OrderedDict, the
......
import itertools
from copy import copy from copy import copy
import numpy as np import numpy as np
...@@ -6,8 +5,7 @@ import pytest ...@@ -6,8 +5,7 @@ import pytest
import aesara import aesara
from aesara import tensor as at from aesara import tensor as at
from aesara.scan.utils import ScanArgs, map_variables from aesara.scan.utils import ScanArgs
from aesara.tensor.type import scalar, vector
@pytest.fixture(scope="module", autouse=True) @pytest.fixture(scope="module", autouse=True)
...@@ -16,144 +14,6 @@ def set_aesara_flags(): ...@@ -16,144 +14,6 @@ def set_aesara_flags():
yield yield
class TestMapVariables:
@staticmethod
def replacer(graph):
return getattr(graph.tag, "replacement", graph)
def test_leaf(self):
a = scalar("a")
b = scalar("b")
c = scalar("c")
b.tag.replacement = c
u = a + b
(v,) = map_variables(self.replacer, [u])
assert u.owner.inputs == [a, b]
assert v.owner.inputs == [a, c]
def test_leaf_inside_scan(self):
x = vector("x")
y = scalar("y")
z = scalar("z")
y.tag.replacement = z
s, _ = aesara.scan(lambda x: x * y, sequences=x)
(s2,) = map_variables(self.replacer, [s])
f = aesara.function([x, y, z], [s, s2])
rval = f(x=np.array([1, 2, 3], dtype=np.float32), y=1, z=2)
assert np.array_equal(rval, [[1, 2, 3], [2, 4, 6]])
def test_scan(self):
x = vector("x")
# we will insert a subgraph involving these variables into the inner
# graph of scan. since they were not previously in the inner graph,
# they are like non_sequences to scan(). scan() infers these and
# imports them into the inner graph properly, and map_variables()
# should do this as well.
outer = scalar("outer")
shared = aesara.shared(np.array(1.0, dtype=aesara.config.floatX), name="shared")
constant = at.constant(1, name="constant")
# z will equal 1 so multiplying by it doesn't change any values
z = outer * (shared + constant)
def step(x, a):
r = a + x
r.tag.replacement = z * (a - x)
return r
s, _ = aesara.scan(step, sequences=x, outputs_info=[np.array(0.0)])
# ensure z is owned by the outer graph so map_variables() will need to
# jump through additional hoops to placate FunctionGraph.
t = z * s
(s2,) = map_variables(self.replacer, [t])
t2 = z * s2
f = aesara.function([x, outer], [t, t2])
rval = f(x=np.array([1, 2, 3], dtype=np.float32), outer=0.5)
assert np.array_equal(rval, [[1, 3, 6], [-1, -3, -6]])
def test_scan_with_shared_update(self):
x = vector("x")
# counts how many times its value is used
counter = aesara.shared(0, name="shared")
counter.update = counter + 1
def step(x, a):
r = a + x
# introducing a shared variable with an update into the
# inner graph is unsupported and the code must crash rather
# than silently produce the wrong result.
r.tag.replacement = counter * (a - x)
return r
s, _ = aesara.scan(step, sequences=x, outputs_info=[np.array(0.0)])
with pytest.raises(NotImplementedError):
map_variables(self.replacer, [s])
def test_scan_with_shared_update2(self):
x = vector("x")
# counts how many times its value is used
counter = aesara.shared(0, name="shared")
counter.update = counter + 1
def step(x, a):
r = a + x
# introducing a shared variable with an update into the
# inner graph is unsupported and the code must crash rather
# than silently produce the wrong result.
r.tag.replacement = counter * (a - x)
# the shared variable was already present, but the
# replacement changes the number of times it is used,
# which would have to change the updates, which is
# unsupported.
return r + counter
s, _ = aesara.scan(step, sequences=x, outputs_info=[np.array(0.0)])
with pytest.raises(NotImplementedError):
map_variables(self.replacer, [s])
def test_opfromgraph(self):
# as with the scan tests above, insert foreign inputs into the
# inner graph.
outer = scalar("outer")
shared = aesara.shared(np.array(1.0, dtype=aesara.config.floatX), name="shared")
constant = at.constant(1.0, name="constant")
z = outer * (shared + constant)
# construct the inner graph
a = scalar()
b = scalar()
r = a + b
r.tag.replacement = z * (a - b)
# construct the outer graph
c = scalar()
d = scalar()
u = aesara.compile.builders.OpFromGraph([a, b], [r])(c, d)
t = z * u
(v,) = map_variables(self.replacer, [t])
t2 = z * v
f = aesara.function([c, d, outer], [t, t2])
for m, n in itertools.combinations(range(10), 2):
assert f(m, n, outer=0.5) == [m + n, m - n]
# test that the unsupported case of replacement with a shared
# variable with updates crashes
shared.update = shared + 1
with pytest.raises(NotImplementedError):
map_variables(self.replacer, [t])
def create_test_hmm(): def create_test_hmm():
rng_state = np.random.default_rng(23422) rng_state = np.random.default_rng(23422)
rng_tt = aesara.shared(rng_state, name="rng", borrow=True) rng_tt = aesara.shared(rng_state, name="rng", borrow=True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论