提交 00e0d806 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove default_updates from local variables of a Scan

This adds `SharedVariable` construction tracking that allows one to determine which variables were created within a specific scope (e.g. within a Python function). With this ability, we're able to determine which shared variable update should and shouldn't be performed within the iterations of a `Scan` node.
上级 110e345f
......@@ -4,7 +4,8 @@ Provide a simple user friendly API to Aesara-managed memory.
"""
import copy
import logging
from contextlib import contextmanager
from typing import List, Optional
import numpy as np
......@@ -14,8 +15,20 @@ from aesara.link.basic import Container
from aesara.link.c.type import generic
_logger = logging.getLogger("aesara.compile.sharedvalue")
__docformat__ = "restructuredtext en"
# Global registry used by `collect_new_shareds` to track the `SharedVariable`s
# constructed while a collection context is active; ``None`` means no context
# is active.  `SharedVariable.__init__` appends to this list when it is set.
# NOTE: this is a plain module global, so collection is not thread-safe.
# The element type is quoted as a forward reference so that evaluating this
# annotation does not require `Variable` to already be in scope.
__SHARED_CONTEXT__: Optional[List["Variable"]] = None


@contextmanager
def collect_new_shareds():
    r"""Return all the `SharedVariable`\s created within this context manager.

    Yields
    ------
    list
        A list that is populated (by ``SharedVariable.__init__``) with every
        shared variable constructed while this context is active.

    Notes
    -----
    Contexts may be nested: entering a new context replaces the active
    collection list, and exiting restores the previous one (or ``None``),
    so an outer context does not observe variables created inside an inner
    context.  The previous context is restored even if the body raises.
    """
    global __SHARED_CONTEXT__
    # Remember the enclosing context (if any) so nested use restores it.
    old_context = __SHARED_CONTEXT__
    context: List["Variable"] = []
    try:
        __SHARED_CONTEXT__ = context
        yield context
    finally:
        # Always restore, even when the managed block raises.
        __SHARED_CONTEXT__ = old_context
class SharedVariable(Variable):
......@@ -85,6 +98,11 @@ class SharedVariable(Variable):
allow_downcast=allow_downcast,
)
global __SHARED_CONTEXT__
if isinstance(__SHARED_CONTEXT__, list):
__SHARED_CONTEXT__.append(self)
def get_value(self, borrow=False, return_internal_type=False):
"""
Get the non-symbolic value associated with this SharedVariable.
......
......@@ -3,8 +3,8 @@ import warnings
import numpy as np
import aesara.tensor as at
from aesara.compile import SharedVariable
from aesara.compile.function.pfunc import construct_pfunc_ins_and_outs
from aesara.compile.sharedvalue import SharedVariable, collect_new_shareds
from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable, clone_replace, graph_inputs
from aesara.graph.op import get_test_value
......@@ -861,7 +861,10 @@ def scan(
# when we apply the lambda expression we get a mixture of update rules
# and outputs that needs to be separated
condition, outputs, updates = get_updates_and_outputs(fn(*args))
with collect_new_shareds() as new_shareds:
raw_inner_outputs = fn(*args)
condition, outputs, updates = get_updates_and_outputs(raw_inner_outputs)
if condition is not None:
as_while = True
else:
......@@ -974,13 +977,36 @@ def scan(
shared_inner_inputs = []
shared_inner_outputs = []
sit_sot_shared = []
no_update_shared_inputs = []
for input in dummy_inputs:
if isinstance(input.variable, SharedVariable) and input.update:
if not isinstance(input.variable, SharedVariable):
continue
is_local = input.variable in new_shareds
# We only want to add shared variable updates that were either
# user-specified within the inner-function (e.g. by returning an update
# `dict`) or the `SharedVariable.default_update`s of a shared variable
# created in the inner-function.
if input.update and (is_local or input.variable in updates):
# We need to remove the `default_update`s on the shared
# variables created within the context of the loop function
# (e.g. via use of `RandomStream`); otherwise, they'll get
# picked up during compilation and produce errors when the
# updates include inner-graph variables.
# We also don't want to remove a default update that applies to
# the scope/context containing this `Scan`, so we only remove
# default updates on "local" variables.
if is_local and hasattr(input.variable, "default_update"):
del input.variable.default_update
new_var = safe_new(input.variable)
if getattr(input.variable, "name", None) is not None:
new_var.name = input.variable.name + "_copy"
inner_replacements[input.variable] = new_var
if isinstance(new_var.type, TensorType):
sit_sot_inner_inputs.append(new_var)
sit_sot_scan_inputs.append(
......@@ -989,6 +1015,7 @@ def scan(
actual_n_steps,
)
)
tensor_update = at.as_tensor_variable(input.update)
sit_sot_inner_outputs.append(tensor_update)
# Note that `pos` is not a negative index. The sign of `pos` is used
......@@ -1000,14 +1027,14 @@ def scan(
# refers to the update rule with index `-1 - pos`.
sit_sot_rightOrder.append(-1 - len(sit_sot_shared))
sit_sot_shared.append(input.variable)
inner_replacements[input.variable] = new_var
else:
shared_inner_inputs.append(new_var)
shared_scan_inputs.append(input.variable)
shared_inner_outputs.append(input.update)
inner_replacements[input.variable] = new_var
n_shared_outs += 1
else:
no_update_shared_inputs.append(input)
n_sit_sot = len(sit_sot_inner_inputs)
......@@ -1048,33 +1075,20 @@ def scan(
other_shared_scan_args = [
arg.variable
for arg in dummy_inputs
if (
isinstance(arg.variable, SharedVariable)
and not arg.update
and arg.variable in non_seqs_set
)
for arg in no_update_shared_inputs
if arg.variable in non_seqs_set
]
other_shared_inner_args = [
safe_new(arg.variable, "_copy")
for arg in dummy_inputs
if (
isinstance(arg.variable, SharedVariable)
and not arg.update
and arg.variable in non_seqs_set
)
for arg in no_update_shared_inputs
if arg.variable in non_seqs_set
]
else:
other_shared_scan_args = [
arg.variable
for arg in dummy_inputs
if (isinstance(arg.variable, SharedVariable) and not arg.update)
]
other_shared_scan_args = [arg.variable for arg in no_update_shared_inputs]
other_shared_inner_args = [
safe_new(arg.variable, "_copy")
for arg in dummy_inputs
if (isinstance(arg.variable, SharedVariable) and not arg.update)
safe_new(arg.variable, "_copy") for arg in no_update_shared_inputs
]
inner_replacements.update(
dict(zip(other_shared_scan_args, other_shared_inner_args))
)
......
......@@ -42,6 +42,7 @@ from aesara.tensor.math import dot, mean, sigmoid
from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tanh
from aesara.tensor.nnet import categorical_crossentropy
from aesara.tensor.random import normal
from aesara.tensor.random.utils import RandomStream
from aesara.tensor.shape import Shape_i, reshape, specify_shape
from aesara.tensor.sharedvar import SharedVariable
......@@ -240,6 +241,54 @@ def scan_nodes_from_fct(fct):
class TestScan:
# Run the test once per NumPy RNG flavor, since shared RNG variables can wrap
# either a `Generator` or the legacy `RandomState`.
@pytest.mark.parametrize(
    "rng_type",
    [
        np.random.default_rng,
        np.random.RandomState,
    ],
)
def test_inner_graph_cloning(self, rng_type):
    r"""Scan should remove the updates-providing special properties on `RandomType`\s.

    Three kinds of shared variables are wired up with ``default_update``s:

    * ``inner_rng`` / ``inner_inner_rng`` — mutually referencing updates,
      with ``inner_rng`` created *inside* the scanned function, so `scan`
      is expected to strip ``inner_rng``'s ``default_update``;
    * ``y`` and ``z_rng`` — created *outside* the scanned function, whose
      ``default_update``s must be left untouched.
    """
    inner_inner_rng = shared(rng_type(), name="inner_inner_rng")

    # Outer shared variable with a self-referencing default update;
    # `scan` must not remove this, since `y` is not local to `inner_fn`.
    y = shared(np.array(1.0, dtype=config.floatX), name="y")
    y.default_update = y + 1

    # Outer RNG whose default update is the next-RNG-state output of a
    # `normal` draw; also must survive the `scan` call untouched.
    z_rng = shared(rng_type(), name="z_rng")

    z = normal(0, 1, rng=z_rng, name="z")

    z_rng_update = z.owner.outputs[0]
    z_rng_update.name = "z_rng_update"
    z_rng.default_update = z_rng_update

    inner_rng = None

    def inner_fn(x):
        # NOTE(review): this assignment shadows the outer ``inner_rng``
        # (there is no ``nonlocal``), so the outer name stays ``None`` and
        # the first ``hasattr`` assertion below checks ``None`` — confirm
        # this is intended.
        inner_rng = shared(rng_type(), name="inner_rng")
        # Cyclic default updates between the local and outer RNGs; the
        # local side should be stripped by `scan`, the outer side kept.
        inner_rng.default_update = inner_inner_rng
        inner_inner_rng.default_update = inner_rng

        r = normal(x, rng=inner_rng)

        return r + y + z, z

    out, out_updates = scan(
        inner_fn,
        outputs_info=[at.as_tensor(0.0, dtype=config.floatX), None],
        n_steps=4,
    )

    # Only the shared variable created inside `inner_fn` loses its
    # `default_update`; everything defined in this (outer) scope keeps it.
    assert not hasattr(inner_rng, "default_update")
    assert hasattr(inner_inner_rng, "default_update")
    assert hasattr(y, "default_update")
    assert hasattr(z_rng, "default_update")

    out_fn = function([], out, mode=Mode(optimizer=None))
    res, z_res = out_fn()
    # The inner RNG advances each of the 4 iterations → 4 distinct draws.
    assert len(set(res)) == 4
    # `z` is computed once outside the loop, so it is constant across steps.
    assert len(set(z_res)) == 1
@pytest.mark.skipif(
isinstance(get_default_mode(), DebugMode),
reason="This test fails in DebugMode, because it is not yet picklable.",
......
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论