提交 33eaccac authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove unused/deprecated variable attributes added by RandomStream

上级 75b4dd18
......@@ -30,12 +30,13 @@ def _is_numeric_value(arr, var):
"""
from aesara.link.c.type import _cdata_type
from aesara.tensor.random.type import RandomType
if isinstance(arr, _cdata_type):
return False
elif isinstance(arr, (np.random.mtrand.RandomState, np.random.Generator)):
return False
elif var and getattr(var.tag, "is_rng", False):
elif var and isinstance(var.type, RandomType):
return False
elif isinstance(arr, slice):
return False
......
......@@ -926,8 +926,6 @@ class MRG_RandomStream:
size=size,
nstreams=orig_nstreams,
)
# Add a reference to distinguish from other shared variables
node_rstate.tag.is_rng = True
r = u * (high - low) + low
if u.type.broadcastable != r.type.broadcastable:
......
......@@ -251,20 +251,19 @@ class RandomStream:
# Generate a new random state
seed = int(self.gen_seedgen.integers(2**30))
random_state_variable = shared(self.rng_ctor(seed))
# Distinguish it from other shared variables (why?)
random_state_variable.tag.is_rng = True
rng = shared(self.rng_ctor(seed), borrow=True)
# Generate the sample
out = op(*args, **kwargs, rng=random_state_variable)
out.rng = random_state_variable
out = op(*args, **kwargs, rng=rng)
# This is the value that should be used to replace the old state
# (i.e. `rng`) after `out` is sampled/evaluated.
# The updates mechanism in `aesara.function` is supposed to perform
# this replace action.
new_rng = out.owner.outputs[0]
# Update the tracked states
new_r = out.owner.outputs[0]
out.update = (random_state_variable, new_r)
self.state_updates.append(out.update)
self.state_updates.append((rng, new_rng))
random_state_variable.default_update = new_r
rng.default_update = new_rng
return out
......@@ -97,7 +97,6 @@ class TestSharedRandomStream:
assert np.all(f() != f())
assert np.all(g() == g())
assert np.all(abs(nearly_zeros()) < 1e-5)
assert isinstance(rv_u.rng.get_value(borrow=True), np.random.Generator)
@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng])
def test_basics(self, rng_ctor):
......@@ -109,8 +108,7 @@ class TestSharedRandomStream:
with pytest.raises(AttributeError):
random.blah
# test if standard_normal is available in the namespace, See: GH issue #528
random.standard_normal
assert hasattr(random, "standard_normal")
with pytest.raises(AttributeError):
np_random = RandomStream(namespace=np, rng_ctor=rng_ctor)
......@@ -223,7 +221,7 @@ class TestSharedRandomStream:
# Explicit updates #2
random_c = RandomStream(utt.fetch_seed(), rng_ctor=rng_ctor)
out_c = random_c.uniform(0, 1, size=(2, 2))
fn_c = function([], out_c, updates=[out_c.update])
fn_c = function([], out_c, updates=random_c.state_updates)
fn_c_val0 = fn_c()
fn_c_val1 = fn_c()
assert np.all(fn_c_val0 == fn_a_val0)
......@@ -241,7 +239,7 @@ class TestSharedRandomStream:
# No updates for out
random_e = RandomStream(utt.fetch_seed(), rng_ctor=rng_ctor)
out_e = random_e.uniform(0, 1, size=(2, 2))
fn_e = function([], out_e, no_default_updates=[out_e.rng])
fn_e = function([], out_e, no_default_updates=[random_e.state_updates[0][0]])
fn_e_val0 = fn_e()
fn_e_val1 = fn_e()
assert np.all(fn_e_val0 == fn_a_val0)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论