提交 ebe518fa authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Tweak RandomGenerator tests

上级 88835278
...@@ -447,6 +447,8 @@ class AbstractRNGConstructor(Op): ...@@ -447,6 +447,8 @@ class AbstractRNGConstructor(Op):
def make_node(self, seed=None): def make_node(self, seed=None):
if seed is None: if seed is None:
seed = NoneConst seed = NoneConst
elif isinstance(seed, Variable) and isinstance(seed.type, NoneTypeT):
pass
else: else:
seed = as_tensor_variable(seed) seed = as_tensor_variable(seed)
inputs = [seed] inputs = [seed]
......
...@@ -3,7 +3,9 @@ import pytest ...@@ -3,7 +3,9 @@ import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config, function from pytensor import config, function
from pytensor.compile import get_default_mode
from pytensor.graph.replace import vectorize_graph from pytensor.graph.replace import vectorize_graph
from pytensor.link.numba import NumbaLinker
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
from pytensor.tensor.math import eq from pytensor.tensor.math import eq
from pytensor.tensor.random import normal from pytensor.tensor.random import normal
...@@ -157,26 +159,30 @@ def test_RandomVariable_floatX(strict_test_value_flags): ...@@ -157,26 +159,30 @@ def test_RandomVariable_floatX(strict_test_value_flags):
assert test_rv_op(0, 1).dtype == new_floatX assert test_rv_op(0, 1).dtype == new_floatX
@pytest.mark.parametrize( def test_default_rng_op():
"seed, maker_op, numpy_res", seed = pt.scalar(dtype="int64")
[ res = function(inputs=[seed], outputs=default_rng(seed))(3)
(3, default_rng, np.random.default_rng(3)), expected_res = np.random.default_rng(3)
], assert default_rng.random_type.values_eq(res, expected_res)
)
def test_random_maker_op(strict_test_value_flags, seed, maker_op, numpy_res):
seed = pt.as_tensor_variable(seed)
z = function(inputs=[], outputs=[maker_op(seed)])()
aes_res = z[0]
assert maker_op.random_type.values_eq(aes_res, numpy_res)
def test_random_maker_ops_no_seed(strict_test_value_flags): def test_random_maker_ops_none_seed():
# Testing the initialization when seed=None # Testing the initialization when seed=None
# Since internal states randomly generated, # Since internal states randomly generated,
# we just check the output classes # we just check the output classes
z = function(inputs=[], outputs=[default_rng()])() seed = none_type_t()
aes_res = z[0] res = function(inputs=[seed], outputs=default_rng(seed))(None)
assert isinstance(aes_res, np.random.Generator) assert isinstance(res, np.random.Generator)
@pytest.mark.xfail(
condition=isinstance(get_default_mode().linker, NumbaLinker),
reason="Numba cannot lower default_rng as a literal",
)
def test_constant_rng_op():
res = function(inputs=[], outputs=default_rng(3))()
expected_res = np.random.default_rng(3)
assert default_rng.random_type.values_eq(res, expected_res)
def test_RandomVariable_incompatible_size(strict_test_value_flags): def test_RandomVariable_incompatible_size(strict_test_value_flags):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论