提交 ebe518fa authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Tweak RandomGenerator tests

上级 88835278
......@@ -447,6 +447,8 @@ class AbstractRNGConstructor(Op):
def make_node(self, seed=None):
if seed is None:
seed = NoneConst
elif isinstance(seed, Variable) and isinstance(seed.type, NoneTypeT):
pass
else:
seed = as_tensor_variable(seed)
inputs = [seed]
......
......@@ -3,7 +3,9 @@ import pytest
import pytensor.tensor as pt
from pytensor import config, function
from pytensor.compile import get_default_mode
from pytensor.graph.replace import vectorize_graph
from pytensor.link.numba import NumbaLinker
from pytensor.raise_op import Assert
from pytensor.tensor.math import eq
from pytensor.tensor.random import normal
......@@ -157,26 +159,30 @@ def test_RandomVariable_floatX(strict_test_value_flags):
assert test_rv_op(0, 1).dtype == new_floatX
@pytest.mark.parametrize(
"seed, maker_op, numpy_res",
[
(3, default_rng, np.random.default_rng(3)),
],
)
def test_random_maker_op(strict_test_value_flags, seed, maker_op, numpy_res):
seed = pt.as_tensor_variable(seed)
z = function(inputs=[], outputs=[maker_op(seed)])()
aes_res = z[0]
assert maker_op.random_type.values_eq(aes_res, numpy_res)
def test_default_rng_op():
seed = pt.scalar(dtype="int64")
res = function(inputs=[seed], outputs=default_rng(seed))(3)
expected_res = np.random.default_rng(3)
assert default_rng.random_type.values_eq(res, expected_res)
def test_random_maker_ops_no_seed(strict_test_value_flags):
def test_random_maker_ops_none_seed():
# Testing the initialization when seed=None
# Since internal states randomly generated,
# we just check the output classes
z = function(inputs=[], outputs=[default_rng()])()
aes_res = z[0]
assert isinstance(aes_res, np.random.Generator)
seed = none_type_t()
res = function(inputs=[seed], outputs=default_rng(seed))(None)
assert isinstance(res, np.random.Generator)
@pytest.mark.xfail(
condition=isinstance(get_default_mode().linker, NumbaLinker),
reason="Numba cannot lower default_rng as a literal",
)
def test_constant_rng_op():
res = function(inputs=[], outputs=default_rng(3))()
expected_res = np.random.default_rng(3)
assert default_rng.random_type.values_eq(res, expected_res)
def test_RandomVariable_incompatible_size(strict_test_value_flags):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论