提交 a0604e52 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Added RandomMakerOps: Constructors for RandomType Variables

上级 ca995ae2
...@@ -2,4 +2,5 @@ ...@@ -2,4 +2,5 @@
import aesara.tensor.random.opt import aesara.tensor.random.opt
import aesara.tensor.random.utils import aesara.tensor.random.utils
from aesara.tensor.random.basic import * from aesara.tensor.random.basic import *
from aesara.tensor.random.op import RandomState, default_rng
from aesara.tensor.random.utils import RandomStream from aesara.tensor.random.utils import RandomStream
...@@ -17,10 +17,11 @@ from aesara.tensor.basic import ( ...@@ -17,10 +17,11 @@ from aesara.tensor.basic import (
get_vector_length, get_vector_length,
infer_broadcastable, infer_broadcastable,
) )
from aesara.tensor.random.type import RandomType from aesara.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes
from aesara.tensor.shape import shape_tuple from aesara.tensor.shape import shape_tuple
from aesara.tensor.type import TensorType, all_dtypes from aesara.tensor.type import TensorType, all_dtypes
from aesara.tensor.type_other import NoneConst
from aesara.tensor.var import TensorVariable from aesara.tensor.var import TensorVariable
...@@ -399,3 +400,36 @@ class RandomVariable(Op): ...@@ -399,3 +400,36 @@ class RandomVariable(Op):
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
return [None for i in eval_points] return [None for i in eval_points]
class AbstractRNGConstructor(Op):
def make_node(self, seed=None):
if seed is None:
seed = NoneConst
else:
seed = as_tensor_variable(seed)
inputs = [seed]
outputs = [self.random_type()]
return Apply(self, inputs, outputs)
def perform(self, node, inputs, output_storage):
(seed,) = inputs
if seed is not None and seed.size == 1:
seed = int(seed)
output_storage[0][0] = getattr(np.random, self.random_constructor)(seed=seed)
class RandomStateConstructor(AbstractRNGConstructor):
random_type = RandomStateType()
random_constructor = "RandomState"
RandomState = RandomStateConstructor()
class DefaultGeneratorMakerOp(AbstractRNGConstructor):
random_type = RandomGeneratorType()
random_constructor = "default_rng"
default_rng = DefaultGeneratorMakerOp()
import numpy as np import numpy as np
from pytest import fixture, raises import pytest
import aesara.tensor as at import aesara.tensor as at
from aesara import config from aesara import config, function
from aesara.gradient import NullTypeGradError, grad from aesara.gradient import NullTypeGradError, grad
from aesara.raise_op import Assert from aesara.raise_op import Assert
from aesara.tensor.math import eq from aesara.tensor.math import eq
from aesara.tensor.random.op import RandomVariable, default_shape_from_params from aesara.tensor.random.op import (
RandomState,
RandomVariable,
default_rng,
default_shape_from_params,
)
from aesara.tensor.shape import specify_shape from aesara.tensor.shape import specify_shape
from aesara.tensor.type import all_dtypes, iscalar, tensor from aesara.tensor.type import all_dtypes, iscalar, tensor
@fixture(scope="module", autouse=True) @pytest.fixture(scope="module", autouse=True)
def set_aesara_flags(): def set_aesara_flags():
with config.change_flags(cxx="", compute_test_value="raise"): with config.change_flags(cxx="", compute_test_value="raise"):
yield yield
def test_default_shape_from_params(): def test_default_shape_from_params():
with raises(ValueError, match="^ndim_supp*"): with pytest.raises(ValueError, match="^ndim_supp*"):
default_shape_from_params(0, (np.array([1, 2]), 0)) default_shape_from_params(0, (np.array([1, 2]), 0))
res = default_shape_from_params(1, (np.array([1, 2]), np.eye(2)), rep_param_idx=0) res = default_shape_from_params(1, (np.array([1, 2]), np.eye(2)), rep_param_idx=0)
...@@ -27,7 +32,7 @@ def test_default_shape_from_params(): ...@@ -27,7 +32,7 @@ def test_default_shape_from_params():
res = default_shape_from_params(1, (np.array([1, 2]), 0), param_shapes=((2,), ())) res = default_shape_from_params(1, (np.array([1, 2]), 0), param_shapes=((2,), ()))
assert res == (2,) assert res == (2,)
with raises(ValueError, match="^Reference parameter*"): with pytest.raises(ValueError, match="^Reference parameter*"):
default_shape_from_params(1, (np.array(1),), rep_param_idx=0) default_shape_from_params(1, (np.array(1),), rep_param_idx=0)
res = default_shape_from_params( res = default_shape_from_params(
...@@ -51,7 +56,7 @@ def test_RandomVariable_basics(): ...@@ -51,7 +56,7 @@ def test_RandomVariable_basics():
assert str_res == "normal_rv{0, (0, 0), float32, True}" assert str_res == "normal_rv{0, (0, 0), float32, True}"
# `ndims_params` should be a `Sequence` type # `ndims_params` should be a `Sequence` type
with raises(TypeError, match="^Parameter ndims_params*"): with pytest.raises(TypeError, match="^Parameter ndims_params*"):
RandomVariable( RandomVariable(
"normal", "normal",
0, 0,
...@@ -61,7 +66,7 @@ def test_RandomVariable_basics(): ...@@ -61,7 +66,7 @@ def test_RandomVariable_basics():
) )
# `size` should be a `Sequence` type # `size` should be a `Sequence` type
with raises(TypeError, match="^Parameter size*"): with pytest.raises(TypeError, match="^Parameter size*"):
RandomVariable( RandomVariable(
"normal", "normal",
0, 0,
...@@ -71,7 +76,7 @@ def test_RandomVariable_basics(): ...@@ -71,7 +76,7 @@ def test_RandomVariable_basics():
)(0, 1, size={1, 2}) )(0, 1, size={1, 2})
# No dtype # No dtype
with raises(TypeError, match="^dtype*"): with pytest.raises(TypeError, match="^dtype*"):
RandomVariable( RandomVariable(
"normal", "normal",
0, 0,
...@@ -94,7 +99,7 @@ def test_RandomVariable_basics(): ...@@ -94,7 +99,7 @@ def test_RandomVariable_basics():
# A no-params `RandomVariable` # A no-params `RandomVariable`
rv = RandomVariable(name="test_rv", ndim_supp=0, ndims_params=()) rv = RandomVariable(name="test_rv", ndim_supp=0, ndims_params=())
with raises(TypeError): with pytest.raises(TypeError):
rv.make_node(rng=1) rv.make_node(rng=1)
# `RandomVariable._infer_shape` should handle no parameters # `RandomVariable._infer_shape` should handle no parameters
...@@ -109,7 +114,7 @@ def test_RandomVariable_basics(): ...@@ -109,7 +114,7 @@ def test_RandomVariable_basics():
assert rv_out.dtype == dtype_1 assert rv_out.dtype == dtype_1
with raises(NullTypeGradError): with pytest.raises(NullTypeGradError):
grad(rv_out, [rv_node.inputs[0]]) grad(rv_out, [rv_node.inputs[0]])
...@@ -178,3 +183,30 @@ def test_RandomVariable_floatX(): ...@@ -178,3 +183,30 @@ def test_RandomVariable_floatX():
with config.change_flags(floatX=new_floatX): with config.change_flags(floatX=new_floatX):
assert test_rv_op(0, 1).dtype == new_floatX assert test_rv_op(0, 1).dtype == new_floatX
@pytest.mark.parametrize(
"seed, maker_op, numpy_res",
[
(3, RandomState, np.random.RandomState(3)),
(3, default_rng, np.random.default_rng(3)),
],
)
def test_random_maker_op(seed, maker_op, numpy_res):
seed = at.as_tensor_variable(seed)
z = function(inputs=[], outputs=[maker_op(seed)])()
aes_res = z[0]
assert maker_op.random_type.values_eq(aes_res, numpy_res)
def test_random_maker_ops_no_seed():
# Testing the initialization when seed=None
# Since internal states randomly generated,
# we just check the output classes
z = function(inputs=[], outputs=[RandomState()])()
aes_res = z[0]
assert isinstance(aes_res, np.random.RandomState)
z = function(inputs=[], outputs=[default_rng()])()
aes_res = z[0]
assert isinstance(aes_res, np.random.Generator)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论