提交 e7ce9fc8 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

StandandardNormalRV is now just a helper function

上级 aa89e44b
......@@ -162,7 +162,6 @@ def jax_sample_fn_generic(op, node):
@jax_sample_fn.register(ptr.LaplaceRV)
@jax_sample_fn.register(ptr.LogisticRV)
@jax_sample_fn.register(ptr.NormalRV)
@jax_sample_fn.register(ptr.StandardNormalRV)
def jax_sample_fn_loc_scale(op, node):
"""JAX implementation of random variables in the loc-scale families.
......
......@@ -281,38 +281,24 @@ class NormalRV(RandomVariable):
normal = NormalRV()
class StandardNormalRV(NormalRV):
r"""A standard normal continuous random variable.
def standard_normal(*, size=None, rng=None, dtype=None):
"""Draw samples from a standard normal distribution.
The probability density function for `standard_normal` is:
Signature
---------
.. math::
`nil -> ()`
f(x) = \frac{1}{\sqrt{2 \pi}} e^{-\frac{x^2}{2}}
Parameters
----------
size
Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k`
independent, identically distributed random variables are
returned. Default is `None` in which case a single random variable
is returned.
"""
def __call__(self, size=None, **kwargs):
"""Draw samples from a standard normal distribution.
Signature
---------
`nil -> ()`
Parameters
----------
size
Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k`
independent, identically distributed random variables are
returned. Default is `None` in which case a single random variable
is returned.
"""
return super().__call__(loc=0.0, scale=1.0, size=size, **kwargs)
standard_normal = StandardNormalRV()
return normal(0.0, 1.0, size=size, rng=rng, dtype=dtype)
class HalfNormalRV(ScipyRandomVariable):
......
......@@ -218,9 +218,9 @@ class RandomStream:
if namespace is None:
from pytensor.tensor.random import basic # pylint: disable=import-self
self.namespaces = [basic]
self.namespaces = [(basic, set(basic.__all__))]
else:
self.namespaces = [namespace]
self.namespaces = [(namespace, set(namespace.__all__))]
self.default_instance_seed = seed
self.state_updates = []
......@@ -235,22 +235,20 @@ class RandomStream:
def __getattr__(self, obj):
ns_obj = next(
(getattr(ns, obj) for ns in self.namespaces if hasattr(ns, obj)), None
(
getattr(ns, obj)
for ns, all_ in self.namespaces
if obj in all_ and hasattr(ns, obj)
),
None,
)
if ns_obj is None:
raise AttributeError(f"No attribute {obj}.")
from pytensor.tensor.random.op import RandomVariable
if isinstance(ns_obj, RandomVariable):
@wraps(ns_obj)
def meta_obj(*args, **kwargs):
return self.gen(ns_obj, *args, **kwargs)
else:
raise AttributeError(f"No attribute {obj}.")
@wraps(ns_obj)
def meta_obj(*args, **kwargs):
return self.gen(ns_obj, *args, **kwargs)
setattr(self, obj, meta_obj)
return getattr(self, obj)
......
......@@ -114,7 +114,7 @@ class TestSharedRandomStream:
assert hasattr(random, "standard_normal")
with pytest.raises(AttributeError):
np_random = RandomStream(namespace=np, rng_ctor=rng_ctor)
np_random = RandomStream(namespace=np.random, rng_ctor=rng_ctor)
np_random.ndarray
fn = function([], random.uniform(0, 1, size=(2, 2)), updates=random.updates())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论