提交 00e0d806 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove default_updates from local variables of a Scan

This adds `SharedVariable` construction tracking that allows one to determine which variables were created within a specific scope (e.g. within a Python function). With this ability, we're able to determine which shared variable update should and shouldn't be performed within the iterations of a `Scan` node.
上级 110e345f
......@@ -4,7 +4,8 @@ Provide a simple user friendly API to Aesara-managed memory.
"""
import copy
import logging
from contextlib import contextmanager
from typing import List, Optional
import numpy as np
......@@ -14,8 +15,20 @@ from aesara.link.basic import Container
from aesara.link.c.type import generic
_logger = logging.getLogger("aesara.compile.sharedvalue")
__docformat__ = "restructuredtext en"
__SHARED_CONTEXT__: Optional[List[Variable]] = None
@contextmanager
def collect_new_shareds():
r"""Return all the `SharedVariable`\s created within this context manager."""
global __SHARED_CONTEXT__
old_context = __SHARED_CONTEXT__
context = []
try:
__SHARED_CONTEXT__ = context
yield context
finally:
__SHARED_CONTEXT__ = old_context
class SharedVariable(Variable):
......@@ -85,6 +98,11 @@ class SharedVariable(Variable):
allow_downcast=allow_downcast,
)
global __SHARED_CONTEXT__
if isinstance(__SHARED_CONTEXT__, list):
__SHARED_CONTEXT__.append(self)
def get_value(self, borrow=False, return_internal_type=False):
"""
Get the non-symbolic value associated with this SharedVariable.
......
......@@ -3,8 +3,8 @@ import warnings
import numpy as np
import aesara.tensor as at
from aesara.compile import SharedVariable
from aesara.compile.function.pfunc import construct_pfunc_ins_and_outs
from aesara.compile.sharedvalue import SharedVariable, collect_new_shareds
from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable, clone_replace, graph_inputs
from aesara.graph.op import get_test_value
......@@ -861,7 +861,10 @@ def scan(
# when we apply the lambda expression we get a mixture of update rules
# and outputs that needs to be separated
condition, outputs, updates = get_updates_and_outputs(fn(*args))
with collect_new_shareds() as new_shareds:
raw_inner_outputs = fn(*args)
condition, outputs, updates = get_updates_and_outputs(raw_inner_outputs)
if condition is not None:
as_while = True
else:
......@@ -974,13 +977,36 @@ def scan(
shared_inner_inputs = []
shared_inner_outputs = []
sit_sot_shared = []
no_update_shared_inputs = []
for input in dummy_inputs:
if isinstance(input.variable, SharedVariable) and input.update:
if not isinstance(input.variable, SharedVariable):
continue
is_local = input.variable in new_shareds
# We only want to add shared variable updates that were either
# user-specified within the inner-function (e.g. by returning an update
# `dict`) or the `SharedVariable.default_update`s of a shared variable
# created in the inner-function.
if input.update and (is_local or input.variable in updates):
# We need to remove the `default_update`s on the shared
# variables created within the context of the loop function
# (e.g. via use of `RandomStream`); otherwise, they'll get
# picked up during compilation and produce errors when the
# updates include inner-graph variables.
# We also don't want to remove a default update that applies to
# the scope/context containing this `Scan`, so we only remove
# default updates on "local" variables.
if is_local and hasattr(input.variable, "default_update"):
del input.variable.default_update
new_var = safe_new(input.variable)
if getattr(input.variable, "name", None) is not None:
new_var.name = input.variable.name + "_copy"
inner_replacements[input.variable] = new_var
if isinstance(new_var.type, TensorType):
sit_sot_inner_inputs.append(new_var)
sit_sot_scan_inputs.append(
......@@ -989,6 +1015,7 @@ def scan(
actual_n_steps,
)
)
tensor_update = at.as_tensor_variable(input.update)
sit_sot_inner_outputs.append(tensor_update)
# Note that `pos` is not a negative index. The sign of `pos` is used
......@@ -1000,14 +1027,14 @@ def scan(
# refers to the update rule with index `-1 - pos`.
sit_sot_rightOrder.append(-1 - len(sit_sot_shared))
sit_sot_shared.append(input.variable)
inner_replacements[input.variable] = new_var
else:
shared_inner_inputs.append(new_var)
shared_scan_inputs.append(input.variable)
shared_inner_outputs.append(input.update)
inner_replacements[input.variable] = new_var
n_shared_outs += 1
else:
no_update_shared_inputs.append(input)
n_sit_sot = len(sit_sot_inner_inputs)
......@@ -1048,33 +1075,20 @@ def scan(
other_shared_scan_args = [
arg.variable
for arg in dummy_inputs
if (
isinstance(arg.variable, SharedVariable)
and not arg.update
and arg.variable in non_seqs_set
)
for arg in no_update_shared_inputs
if arg.variable in non_seqs_set
]
other_shared_inner_args = [
safe_new(arg.variable, "_copy")
for arg in dummy_inputs
if (
isinstance(arg.variable, SharedVariable)
and not arg.update
and arg.variable in non_seqs_set
)
for arg in no_update_shared_inputs
if arg.variable in non_seqs_set
]
else:
other_shared_scan_args = [
arg.variable
for arg in dummy_inputs
if (isinstance(arg.variable, SharedVariable) and not arg.update)
]
other_shared_scan_args = [arg.variable for arg in no_update_shared_inputs]
other_shared_inner_args = [
safe_new(arg.variable, "_copy")
for arg in dummy_inputs
if (isinstance(arg.variable, SharedVariable) and not arg.update)
safe_new(arg.variable, "_copy") for arg in no_update_shared_inputs
]
inner_replacements.update(
dict(zip(other_shared_scan_args, other_shared_inner_args))
)
......
......@@ -42,6 +42,7 @@ from aesara.tensor.math import dot, mean, sigmoid
from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tanh
from aesara.tensor.nnet import categorical_crossentropy
from aesara.tensor.random import normal
from aesara.tensor.random.utils import RandomStream
from aesara.tensor.shape import Shape_i, reshape, specify_shape
from aesara.tensor.sharedvar import SharedVariable
......@@ -240,6 +241,54 @@ def scan_nodes_from_fct(fct):
class TestScan:
@pytest.mark.parametrize(
"rng_type",
[
np.random.default_rng,
np.random.RandomState,
],
)
def test_inner_graph_cloning(self, rng_type):
r"""Scan should remove the updates-providing special properties on `RandomType`\s."""
inner_inner_rng = shared(rng_type(), name="inner_inner_rng")
y = shared(np.array(1.0, dtype=config.floatX), name="y")
y.default_update = y + 1
z_rng = shared(rng_type(), name="z_rng")
z = normal(0, 1, rng=z_rng, name="z")
z_rng_update = z.owner.outputs[0]
z_rng_update.name = "z_rng_update"
z_rng.default_update = z_rng_update
inner_rng = None
def inner_fn(x):
inner_rng = shared(rng_type(), name="inner_rng")
inner_rng.default_update = inner_inner_rng
inner_inner_rng.default_update = inner_rng
r = normal(x, rng=inner_rng)
return r + y + z, z
out, out_updates = scan(
inner_fn,
outputs_info=[at.as_tensor(0.0, dtype=config.floatX), None],
n_steps=4,
)
assert not hasattr(inner_rng, "default_update")
assert hasattr(inner_inner_rng, "default_update")
assert hasattr(y, "default_update")
assert hasattr(z_rng, "default_update")
out_fn = function([], out, mode=Mode(optimizer=None))
res, z_res = out_fn()
assert len(set(res)) == 4
assert len(set(z_res)) == 1
@pytest.mark.skipif(
isinstance(get_default_mode(), DebugMode),
reason="This test fails in DebugMode, because it is not yet picklable.",
......
......@@ -15,10 +15,7 @@ def set_aesara_flags():
def create_test_hmm():
rng_state = np.random.default_rng(23422)
rng_tt = aesara.shared(rng_state, name="rng", borrow=True)
rng_tt.tag.is_rng = True
rng_tt.default_update = rng_tt
srng = at.random.RandomStream()
N_tt = at.iscalar("N")
N_tt.tag.test_value = 10
......@@ -33,20 +30,20 @@ def create_test_hmm():
sigmas_tt = at.ones((N_tt,))
sigmas_tt.name = "sigmas"
pi_0_rv = at.random.dirichlet(at.ones((M_tt,)), rng=rng_tt, name="pi_0")
Gamma_rv = at.random.dirichlet(at.ones((M_tt, M_tt)), rng=rng_tt, name="Gamma")
pi_0_rv = srng.dirichlet(at.ones((M_tt,)), name="pi_0")
Gamma_rv = srng.dirichlet(at.ones((M_tt, M_tt)), name="Gamma")
S_0_rv = at.random.categorical(pi_0_rv, rng=rng_tt, name="S_0")
S_0_rv = srng.categorical(pi_0_rv, name="S_0")
def scan_fn(mus_t, sigma_t, S_tm1, Gamma_t, rng):
S_t = at.random.categorical(Gamma_t[S_tm1], rng=rng, name="S_t")
Y_t = at.random.normal(mus_t[S_t], sigma_t, rng=rng, name="Y_t")
def scan_fn(mus_t, sigma_t, S_tm1, Gamma_t):
S_t = srng.categorical(Gamma_t[S_tm1], name="S_t")
Y_t = srng.normal(mus_t[S_t], sigma_t, name="Y_t")
return S_t, Y_t
(S_rv, Y_rv), scan_updates = aesara.scan(
fn=scan_fn,
sequences=[mus_tt, sigmas_tt],
non_sequences=[Gamma_rv, rng_tt],
non_sequences=[Gamma_rv],
outputs_info=[{"initial": S_0_rv, "taps": [-1]}, {}],
strict=True,
name="scan_rv",
......@@ -63,8 +60,6 @@ def create_test_hmm():
S_t = scan_args.inner_out_sit_sot[0]
rng_in = scan_args.inner_out_shared[0]
rng_updates = scan_updates[rng_tt]
rng_updates.name = "rng_updates"
mus_in = Y_rv.owner.inputs[1]
mus_in.name = "mus_in"
sigmas_in = Y_rv.owner.inputs[2]
......@@ -140,10 +135,7 @@ def test_ScanArgs():
def test_ScanArgs_basics_mit_sot():
rng_state = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234)))
rng_tt = aesara.shared(rng_state, name="rng", borrow=True)
rng_tt.tag.is_rng = True
rng_tt.default_update = rng_tt
srng = at.random.RandomStream()
N_tt = at.iscalar("N")
N_tt.tag.test_value = 10
......@@ -158,20 +150,20 @@ def test_ScanArgs_basics_mit_sot():
sigmas_tt = at.ones((N_tt,))
sigmas_tt.name = "sigmas"
pi_0_rv = at.random.dirichlet(at.ones((M_tt,)), rng=rng_tt, name="pi_0")
Gamma_rv = at.random.dirichlet(at.ones((M_tt, M_tt)), rng=rng_tt, name="Gamma")
pi_0_rv = srng.dirichlet(at.ones((M_tt,)), name="pi_0")
Gamma_rv = srng.dirichlet(at.ones((M_tt, M_tt)), name="Gamma")
S_0_rv = at.random.categorical(pi_0_rv, rng=rng_tt, name="S_0")
S_0_rv = srng.categorical(pi_0_rv, name="S_0")
def scan_fn(mus_t, sigma_t, S_tm2, S_tm1, Gamma_t, rng):
S_t = at.random.categorical(Gamma_t[S_tm2], rng=rng, name="S_t")
Y_t = at.random.normal(mus_t[S_tm1], sigma_t, rng=rng, name="Y_t")
def scan_fn(mus_t, sigma_t, S_tm2, S_tm1, Gamma_t):
S_t = srng.categorical(Gamma_t[S_tm2], name="S_t")
Y_t = srng.normal(mus_t[S_tm1], sigma_t, name="Y_t")
return S_t, Y_t
(S_rv, Y_rv), scan_updates = aesara.scan(
fn=scan_fn,
sequences=[mus_tt, sigmas_tt],
non_sequences=[Gamma_rv, rng_tt],
non_sequences=[Gamma_rv],
outputs_info=[{"initial": at.stack([S_0_rv, S_0_rv]), "taps": [-2, -1]}, {}],
strict=True,
name="scan_rv",
......@@ -181,8 +173,6 @@ def test_ScanArgs_basics_mit_sot():
# This `S_rv` outer-output is actually a `Subtensor` of the "real" output
S_rv = S_rv.owner.inputs[0]
S_rv.name = "S_rv"
rng_updates = scan_updates[rng_tt]
rng_updates.name = "rng_updates"
mus_in = Y_rv.owner.inputs[1]
mus_in.name = "mus_in"
sigmas_in = Y_rv.owner.inputs[2]
......@@ -223,9 +213,8 @@ def test_ScanArgs_remove_inner_input():
hmm_model_env["S_rv"]
S_in = hmm_model_env["S_in"]
S_t = hmm_model_env["S_t"]
rng_tt = hmm_model_env["rng_tt"]
rng_in = hmm_model_env["rng_in"]
rng_updates = hmm_model_env["rng_updates"]
scan_updates = hmm_model_env["scan_updates"]
# Check `ScanArgs.remove_from_fields` by removing `sigmas[t]` (i.e. the
# inner-graph input)
......@@ -266,9 +255,8 @@ def test_ScanArgs_remove_inner_input():
assert S_in in scan_args_copy.outer_out_sit_sot
assert Gamma_in in scan_args_copy.inner_in_non_seqs
assert Gamma_rv in scan_args_copy.outer_in_non_seqs
assert rng_tt in scan_args_copy.outer_in_shared
assert rng_in in scan_args_copy.inner_out_shared
assert rng_updates in scan_args.outer_out_shared
assert list(scan_updates.values()) == scan_args.outer_out_shared
# The other `Y_rv`-related inputs currently aren't removed, even though
# they're no longer needed.
......@@ -296,9 +284,8 @@ def test_ScanArgs_remove_outer_input():
hmm_model_env["S_rv"]
S_in = hmm_model_env["S_in"]
S_t = hmm_model_env["S_t"]
rng_tt = hmm_model_env["rng_tt"]
rng_in = hmm_model_env["rng_in"]
rng_updates = hmm_model_env["rng_updates"]
scan_updates = hmm_model_env["scan_updates"]
# Remove `sigmas` (i.e. the outer-input)
scan_args_copy = copy(scan_args)
......@@ -326,9 +313,8 @@ def test_ScanArgs_remove_outer_input():
assert S_in in scan_args_copy.outer_out_sit_sot
assert Gamma_in in scan_args_copy.inner_in_non_seqs
assert Gamma_rv in scan_args_copy.outer_in_non_seqs
assert rng_tt in scan_args_copy.outer_in_shared
assert rng_in in scan_args_copy.inner_out_shared
assert rng_updates in scan_args.outer_out_shared
assert list(scan_updates.values()) == scan_args.outer_out_shared
def test_ScanArgs_remove_inner_output():
......@@ -345,9 +331,8 @@ def test_ScanArgs_remove_inner_output():
hmm_model_env["S_rv"]
S_in = hmm_model_env["S_in"]
S_t = hmm_model_env["S_t"]
rng_tt = hmm_model_env["rng_tt"]
rng_in = hmm_model_env["rng_in"]
rng_updates = hmm_model_env["rng_updates"]
scan_updates = hmm_model_env["scan_updates"]
# Remove `Y_t` (i.e. the inner-output)
scan_args_copy = copy(scan_args)
......@@ -367,9 +352,8 @@ def test_ScanArgs_remove_inner_output():
assert S_in in scan_args_copy.outer_out_sit_sot
assert Gamma_in in scan_args_copy.inner_in_non_seqs
assert Gamma_rv in scan_args_copy.outer_in_non_seqs
assert rng_tt in scan_args_copy.outer_in_shared
assert rng_in in scan_args_copy.inner_out_shared
assert rng_updates in scan_args.outer_out_shared
assert list(scan_updates.values()) == scan_args.outer_out_shared
def test_ScanArgs_remove_outer_output():
......@@ -385,9 +369,8 @@ def test_ScanArgs_remove_outer_output():
Gamma_in = hmm_model_env["Gamma_in"]
S_in = hmm_model_env["S_in"]
S_t = hmm_model_env["S_t"]
rng_tt = hmm_model_env["rng_tt"]
rng_in = hmm_model_env["rng_in"]
rng_updates = hmm_model_env["rng_updates"]
scan_updates = hmm_model_env["scan_updates"]
# Remove `Y_rv` (i.e. a nit-sot outer-output)
scan_args_copy = copy(scan_args)
......@@ -407,9 +390,8 @@ def test_ScanArgs_remove_outer_output():
assert S_in in scan_args_copy.outer_out_sit_sot
assert Gamma_in in scan_args_copy.inner_in_non_seqs
assert Gamma_rv in scan_args_copy.outer_in_non_seqs
assert rng_tt in scan_args_copy.outer_in_shared
assert rng_in in scan_args_copy.inner_out_shared
assert rng_updates in scan_args.outer_out_shared
assert list(scan_updates.values()) == scan_args.outer_out_shared
def test_ScanArgs_remove_nonseq_outer_input():
......@@ -427,9 +409,7 @@ def test_ScanArgs_remove_nonseq_outer_input():
Gamma_in = hmm_model_env["Gamma_in"]
S_in = hmm_model_env["S_in"]
S_t = hmm_model_env["S_t"]
rng_tt = hmm_model_env["rng_tt"]
rng_in = hmm_model_env["rng_in"]
rng_updates = hmm_model_env["rng_updates"]
# Remove `Gamma` (i.e. a non-sequence outer-input)
scan_args_copy = copy(scan_args)
......@@ -448,9 +428,8 @@ def test_ScanArgs_remove_nonseq_outer_input():
assert sigmas_in in scan_args_copy.outer_in_seqs
assert mus_t in scan_args_copy.inner_in_seqs
assert sigmas_t in scan_args_copy.inner_in_seqs
assert rng_tt in scan_args_copy.outer_in_shared
assert rng_in in scan_args_copy.inner_out_shared
assert rng_updates in scan_args.outer_out_shared
assert rng_in not in scan_args_copy.inner_out_shared
assert not scan_args_copy.outer_out_shared
def test_ScanArgs_remove_nonseq_inner_input():
......@@ -468,9 +447,8 @@ def test_ScanArgs_remove_nonseq_inner_input():
Gamma_in = hmm_model_env["Gamma_in"]
S_in = hmm_model_env["S_in"]
S_t = hmm_model_env["S_t"]
rng_tt = hmm_model_env["rng_tt"]
rng_in = hmm_model_env["rng_in"]
rng_updates = hmm_model_env["rng_updates"]
scan_updates = hmm_model_env["scan_updates"]
# Remove `Gamma` (i.e. a non-sequence inner-input)
scan_args_copy = copy(scan_args)
......@@ -487,9 +465,8 @@ def test_ScanArgs_remove_nonseq_inner_input():
assert sigmas_in in scan_args_copy.outer_in_seqs
assert mus_t in scan_args_copy.inner_in_seqs
assert sigmas_t in scan_args_copy.inner_in_seqs
assert rng_tt in scan_args_copy.outer_in_shared
assert rng_in in scan_args_copy.inner_out_shared
assert rng_updates in scan_args.outer_out_shared
assert rng_in not in scan_args_copy.inner_out_shared
assert list(scan_updates.values()) == scan_args.outer_out_shared
def test_ScanArgs_remove_shared_inner_output():
......@@ -507,19 +484,16 @@ def test_ScanArgs_remove_shared_inner_output():
hmm_model_env["Gamma_in"]
S_in = hmm_model_env["S_in"]
hmm_model_env["S_t"]
rng_tt = hmm_model_env["rng_tt"]
rng_in = hmm_model_env["rng_in"]
rng_updates = hmm_model_env["rng_updates"]
# Remove `rng` (i.e. a shared inner-output)
# Remove a shared inner-output
scan_update = scan_args.inner_out_shared[0]
scan_args_copy = copy(scan_args)
test_v = rng_updates
rm_info = scan_args_copy.remove_from_fields(test_v, rm_dependents=True)
rm_info = scan_args_copy.remove_from_fields(scan_update, rm_dependents=True)
removed_nodes, _ = zip(*rm_info)
assert rng_tt in removed_nodes
assert rng_in in removed_nodes
assert rng_updates in removed_nodes
assert all(v in removed_nodes for v in scan_args.inner_out_shared)
assert Y_rv in removed_nodes
assert S_in in removed_nodes
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论