提交 2b61312b authored 作者: zoj's avatar zoj 提交者: Thomas Wiecki

Expose standard_normal to RandomStrea

上级 57a1eb72
import abc import abc
import functools
from typing import List, Optional, Union from typing import List, Optional, Union
import numpy as np import numpy as np
...@@ -117,6 +118,11 @@ class NormalRV(RandomVariable): ...@@ -117,6 +118,11 @@ class NormalRV(RandomVariable):
normal = NormalRV() normal = NormalRV()
# used as an alias of normal(loc=0, scale=1) in order to be consistent with np.random.RandomState
standard_normal = functools.update_wrapper(
functools.partial(normal, loc=0.0, scale=1.0),
normal
)
class HalfNormalRV(ScipyRandomVariable): class HalfNormalRV(ScipyRandomVariable):
...@@ -815,4 +821,5 @@ __all__ = [ ...@@ -815,4 +821,5 @@ __all__ = [
"beta", "beta",
"triangular", "triangular",
"uniform", "uniform",
"standard_normal",
] ]
...@@ -47,6 +47,7 @@ from aesara.tensor.random.basic import ( ...@@ -47,6 +47,7 @@ from aesara.tensor.random.basic import (
poisson, poisson,
polyagamma, polyagamma,
randint, randint,
standard_normal,
triangular, triangular,
truncexpon, truncexpon,
uniform, uniform,
...@@ -290,7 +291,7 @@ def test_normal_samples(mean, sigma, size): ...@@ -290,7 +291,7 @@ def test_normal_samples(mean, sigma, size):
def test_normal_default_args(): def test_normal_default_args():
rv_numpy_tester(normal) rv_numpy_tester(standard_normal)
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -109,6 +109,9 @@ class TestSharedRandomStream: ...@@ -109,6 +109,9 @@ class TestSharedRandomStream:
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
random.blah random.blah
# test if standard_normal is available in the namespace, See: GH issue #528
random.standard_normal
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
np_random = RandomStream(namespace=np, rng_ctor=rng_ctor) np_random = RandomStream(namespace=np, rng_ctor=rng_ctor)
np_random.ndarray np_random.ndarray
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论