提交 33eaccac authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove unused/deprecated variable attributes added by RandomStream

上级 75b4dd18
...@@ -30,12 +30,13 @@ def _is_numeric_value(arr, var): ...@@ -30,12 +30,13 @@ def _is_numeric_value(arr, var):
""" """
from aesara.link.c.type import _cdata_type from aesara.link.c.type import _cdata_type
from aesara.tensor.random.type import RandomType
if isinstance(arr, _cdata_type): if isinstance(arr, _cdata_type):
return False return False
elif isinstance(arr, (np.random.mtrand.RandomState, np.random.Generator)): elif isinstance(arr, (np.random.mtrand.RandomState, np.random.Generator)):
return False return False
elif var and getattr(var.tag, "is_rng", False): elif var and isinstance(var.type, RandomType):
return False return False
elif isinstance(arr, slice): elif isinstance(arr, slice):
return False return False
......
...@@ -926,8 +926,6 @@ class MRG_RandomStream: ...@@ -926,8 +926,6 @@ class MRG_RandomStream:
size=size, size=size,
nstreams=orig_nstreams, nstreams=orig_nstreams,
) )
# Add a reference to distinguish from other shared variables
node_rstate.tag.is_rng = True
r = u * (high - low) + low r = u * (high - low) + low
if u.type.broadcastable != r.type.broadcastable: if u.type.broadcastable != r.type.broadcastable:
......
...@@ -251,20 +251,19 @@ class RandomStream: ...@@ -251,20 +251,19 @@ class RandomStream:
# Generate a new random state # Generate a new random state
seed = int(self.gen_seedgen.integers(2**30)) seed = int(self.gen_seedgen.integers(2**30))
random_state_variable = shared(self.rng_ctor(seed)) rng = shared(self.rng_ctor(seed), borrow=True)
# Distinguish it from other shared variables (why?)
random_state_variable.tag.is_rng = True
# Generate the sample # Generate the sample
out = op(*args, **kwargs, rng=random_state_variable) out = op(*args, **kwargs, rng=rng)
out.rng = random_state_variable
# This is the value that should be used to replace the old state
# (i.e. `rng`) after `out` is sampled/evaluated.
# The updates mechanism in `aesara.function` is supposed to perform
# this replace action.
new_rng = out.owner.outputs[0]
# Update the tracked states self.state_updates.append((rng, new_rng))
new_r = out.owner.outputs[0]
out.update = (random_state_variable, new_r)
self.state_updates.append(out.update)
random_state_variable.default_update = new_r rng.default_update = new_rng
return out return out
...@@ -97,7 +97,6 @@ class TestSharedRandomStream: ...@@ -97,7 +97,6 @@ class TestSharedRandomStream:
assert np.all(f() != f()) assert np.all(f() != f())
assert np.all(g() == g()) assert np.all(g() == g())
assert np.all(abs(nearly_zeros()) < 1e-5) assert np.all(abs(nearly_zeros()) < 1e-5)
assert isinstance(rv_u.rng.get_value(borrow=True), np.random.Generator)
@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng]) @pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng])
def test_basics(self, rng_ctor): def test_basics(self, rng_ctor):
...@@ -109,8 +108,7 @@ class TestSharedRandomStream: ...@@ -109,8 +108,7 @@ class TestSharedRandomStream:
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
random.blah random.blah
# test if standard_normal is available in the namespace, See: GH issue #528 assert hasattr(random, "standard_normal")
random.standard_normal
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
np_random = RandomStream(namespace=np, rng_ctor=rng_ctor) np_random = RandomStream(namespace=np, rng_ctor=rng_ctor)
...@@ -223,7 +221,7 @@ class TestSharedRandomStream: ...@@ -223,7 +221,7 @@ class TestSharedRandomStream:
# Explicit updates #2 # Explicit updates #2
random_c = RandomStream(utt.fetch_seed(), rng_ctor=rng_ctor) random_c = RandomStream(utt.fetch_seed(), rng_ctor=rng_ctor)
out_c = random_c.uniform(0, 1, size=(2, 2)) out_c = random_c.uniform(0, 1, size=(2, 2))
fn_c = function([], out_c, updates=[out_c.update]) fn_c = function([], out_c, updates=random_c.state_updates)
fn_c_val0 = fn_c() fn_c_val0 = fn_c()
fn_c_val1 = fn_c() fn_c_val1 = fn_c()
assert np.all(fn_c_val0 == fn_a_val0) assert np.all(fn_c_val0 == fn_a_val0)
...@@ -241,7 +239,7 @@ class TestSharedRandomStream: ...@@ -241,7 +239,7 @@ class TestSharedRandomStream:
# No updates for out # No updates for out
random_e = RandomStream(utt.fetch_seed(), rng_ctor=rng_ctor) random_e = RandomStream(utt.fetch_seed(), rng_ctor=rng_ctor)
out_e = random_e.uniform(0, 1, size=(2, 2)) out_e = random_e.uniform(0, 1, size=(2, 2))
fn_e = function([], out_e, no_default_updates=[out_e.rng]) fn_e = function([], out_e, no_default_updates=[random_e.state_updates[0][0]])
fn_e_val0 = fn_e() fn_e_val0 = fn_e()
fn_e_val1 = fn_e() fn_e_val1 = fn_e()
assert np.all(fn_e_val0 == fn_a_val0) assert np.all(fn_e_val0 == fn_a_val0)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论