Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
5d4b0c4b
提交
5d4b0c4b
authored
5月 09, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
5月 29, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove RandomState type in remaining backends
上级
14da898c
显示空白字符变更
内嵌
并排
正在显示
15 个修改的文件
包含
50 行增加
和
392 行删除
+50
-392
random.py
pytensor/link/jax/dispatch/random.py
+1
-11
random.py
pytensor/link/numba/dispatch/random.py
+0
-4
__init__.py
pytensor/tensor/random/__init__.py
+1
-1
basic.py
pytensor/tensor/random/basic.py
+1
-67
op.py
pytensor/tensor/random/op.py
+4
-13
type.py
pytensor/tensor/random/type.py
+0
-91
utils.py
pytensor/tensor/random/utils.py
+1
-9
var.py
pytensor/tensor/random/var.py
+6
-10
test_random.py
tests/link/jax/test_random.py
+22
-42
test_basic.py
tests/scan/test_basic.py
+1
-4
test_basic.py
tests/tensor/random/test_basic.py
+0
-22
test_op.py
tests/tensor/random/test_op.py
+1
-6
test_type.py
tests/tensor/random/test_type.py
+2
-96
test_utils.py
tests/tensor/random/test_utils.py
+6
-6
test_var.py
tests/tensor/random/test_var.py
+4
-10
没有找到文件。
pytensor/link/jax/dispatch/random.py
浏览文件 @
5d4b0c4b
...
@@ -2,7 +2,7 @@ from functools import singledispatch
...
@@ -2,7 +2,7 @@ from functools import singledispatch
import
jax
import
jax
import
numpy
as
np
import
numpy
as
np
from
numpy.random
import
Generator
,
RandomState
from
numpy.random
import
Generator
from
numpy.random.bit_generator
import
(
# type: ignore[attr-defined]
from
numpy.random.bit_generator
import
(
# type: ignore[attr-defined]
_coerce_to_uint32_array
,
_coerce_to_uint32_array
,
)
)
...
@@ -54,15 +54,6 @@ def assert_size_argument_jax_compatible(node):
...
@@ -54,15 +54,6 @@ def assert_size_argument_jax_compatible(node):
raise
NotImplementedError
(
SIZE_NOT_COMPATIBLE
)
raise
NotImplementedError
(
SIZE_NOT_COMPATIBLE
)
@jax_typify.register
(
RandomState
)
def
jax_typify_RandomState
(
state
,
**
kwargs
):
state
=
state
.
get_state
(
legacy
=
False
)
state
[
"bit_generator"
]
=
numpy_bit_gens
[
state
[
"bit_generator"
]]
# XXX: Is this a reasonable approach?
state
[
"jax_state"
]
=
state
[
"state"
][
"key"
][
0
:
2
]
return
state
@jax_typify.register
(
Generator
)
@jax_typify.register
(
Generator
)
def
jax_typify_Generator
(
rng
,
**
kwargs
):
def
jax_typify_Generator
(
rng
,
**
kwargs
):
state
=
rng
.
__getstate__
()
state
=
rng
.
__getstate__
()
...
@@ -214,7 +205,6 @@ def jax_sample_fn_categorical(op, node):
...
@@ -214,7 +205,6 @@ def jax_sample_fn_categorical(op, node):
return
sample_fn
return
sample_fn
@jax_sample_fn.register
(
ptr
.
RandIntRV
)
@jax_sample_fn.register
(
ptr
.
IntegersRV
)
@jax_sample_fn.register
(
ptr
.
IntegersRV
)
@jax_sample_fn.register
(
ptr
.
UniformRV
)
@jax_sample_fn.register
(
ptr
.
UniformRV
)
def
jax_sample_fn_uniform
(
op
,
node
):
def
jax_sample_fn_uniform
(
op
,
node
):
...
...
pytensor/link/numba/dispatch/random.py
浏览文件 @
5d4b0c4b
...
@@ -25,7 +25,6 @@ from pytensor.link.utils import (
...
@@ -25,7 +25,6 @@ from pytensor.link.utils import (
)
)
from
pytensor.tensor
import
get_vector_length
from
pytensor.tensor
import
get_vector_length
from
pytensor.tensor.random.op
import
RandomVariable
,
RandomVariableWithCoreShape
from
pytensor.tensor.random.op
import
RandomVariable
,
RandomVariableWithCoreShape
from
pytensor.tensor.random.type
import
RandomStateType
from
pytensor.tensor.type_other
import
NoneTypeT
from
pytensor.tensor.type_other
import
NoneTypeT
from
pytensor.tensor.utils
import
_parse_gufunc_signature
from
pytensor.tensor.utils
import
_parse_gufunc_signature
...
@@ -348,9 +347,6 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
...
@@ -348,9 +347,6 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
[
rv_node
]
=
op
.
fgraph
.
apply_nodes
[
rv_node
]
=
op
.
fgraph
.
apply_nodes
rv_op
:
RandomVariable
=
rv_node
.
op
rv_op
:
RandomVariable
=
rv_node
.
op
rng_param
=
rv_op
.
rng_param
(
rv_node
)
if
isinstance
(
rng_param
.
type
,
RandomStateType
):
raise
TypeError
(
"Numba does not support NumPy `RandomStateType`s"
)
size
=
rv_op
.
size_param
(
rv_node
)
size
=
rv_op
.
size_param
(
rv_node
)
dist_params
=
rv_op
.
dist_params
(
rv_node
)
dist_params
=
rv_op
.
dist_params
(
rv_node
)
size_len
=
None
if
isinstance
(
size
.
type
,
NoneTypeT
)
else
get_vector_length
(
size
)
size_len
=
None
if
isinstance
(
size
.
type
,
NoneTypeT
)
else
get_vector_length
(
size
)
...
...
pytensor/tensor/random/__init__.py
浏览文件 @
5d4b0c4b
...
@@ -2,5 +2,5 @@
...
@@ -2,5 +2,5 @@
import
pytensor.tensor.random.rewriting
import
pytensor.tensor.random.rewriting
import
pytensor.tensor.random.utils
import
pytensor.tensor.random.utils
from
pytensor.tensor.random.basic
import
*
from
pytensor.tensor.random.basic
import
*
from
pytensor.tensor.random.op
import
RandomState
,
default_rng
from
pytensor.tensor.random.op
import
default_rng
from
pytensor.tensor.random.utils
import
RandomStream
from
pytensor.tensor.random.utils
import
RandomStream
pytensor/tensor/random/basic.py
浏览文件 @
5d4b0c4b
...
@@ -9,15 +9,10 @@ from pytensor.tensor import get_vector_length, specify_shape
...
@@ -9,15 +9,10 @@ from pytensor.tensor import get_vector_length, specify_shape
from
pytensor.tensor.basic
import
as_tensor_variable
from
pytensor.tensor.basic
import
as_tensor_variable
from
pytensor.tensor.math
import
sqrt
from
pytensor.tensor.math
import
sqrt
from
pytensor.tensor.random.op
import
RandomVariable
from
pytensor.tensor.random.op
import
RandomVariable
from
pytensor.tensor.random.type
import
RandomGeneratorType
,
RandomStateType
from
pytensor.tensor.random.utils
import
(
from
pytensor.tensor.random.utils
import
(
broadcast_params
,
broadcast_params
,
normalize_size_param
,
normalize_size_param
,
)
)
from
pytensor.tensor.random.var
import
(
RandomGeneratorSharedVariable
,
RandomStateSharedVariable
,
)
try
:
try
:
...
@@ -645,7 +640,7 @@ class GumbelRV(ScipyRandomVariable):
...
@@ -645,7 +640,7 @@ class GumbelRV(ScipyRandomVariable):
@classmethod
@classmethod
def
rng_fn_scipy
(
def
rng_fn_scipy
(
cls
,
cls
,
rng
:
np
.
random
.
Generator
|
np
.
random
.
RandomState
,
rng
:
np
.
random
.
Generator
,
loc
:
np
.
ndarray
|
float
,
loc
:
np
.
ndarray
|
float
,
scale
:
np
.
ndarray
|
float
,
scale
:
np
.
ndarray
|
float
,
size
:
list
[
int
]
|
int
|
None
,
size
:
list
[
int
]
|
int
|
None
,
...
@@ -1880,58 +1875,6 @@ class CategoricalRV(RandomVariable):
...
@@ -1880,58 +1875,6 @@ class CategoricalRV(RandomVariable):
categorical
=
CategoricalRV
()
categorical
=
CategoricalRV
()
class
RandIntRV
(
RandomVariable
):
r"""A discrete uniform random variable.
Only available for `RandomStateType`. Use `integers` with `RandomGeneratorType`\s.
"""
name
=
"randint"
signature
=
"(),()->()"
dtype
=
"int64"
_print_name
=
(
"randint"
,
"
\\
operatorname{randint}"
)
def
__call__
(
self
,
low
,
high
=
None
,
size
=
None
,
**
kwargs
):
r"""Draw samples from a discrete uniform distribution.
Signature
---------
`() -> ()`
Parameters
----------
low
Lower boundary of the output interval. All values generated will
be greater than or equal to `low`, unless `high=None`, in which case
all values generated are greater than or equal to `0` and
smaller than `low` (exclusive).
high
Upper boundary of the output interval. All values generated
will be smaller than `high` (exclusive).
size
Sample shape. If the given size is `(m, n, k)`, then `m * n * k`
independent, identically distributed samples are
returned. Default is `None`, in which case a single
sample is returned.
"""
if
high
is
None
:
low
,
high
=
0
,
low
return
super
()
.
__call__
(
low
,
high
,
size
=
size
,
**
kwargs
)
def
make_node
(
self
,
rng
,
*
args
,
**
kwargs
):
if
not
isinstance
(
getattr
(
rng
,
"type"
,
None
),
RandomStateType
|
RandomStateSharedVariable
):
raise
TypeError
(
"`randint` is only available for `RandomStateType`s"
)
return
super
()
.
make_node
(
rng
,
*
args
,
**
kwargs
)
randint
=
RandIntRV
()
class
IntegersRV
(
RandomVariable
):
class
IntegersRV
(
RandomVariable
):
r"""A discrete uniform random variable.
r"""A discrete uniform random variable.
...
@@ -1971,14 +1914,6 @@ class IntegersRV(RandomVariable):
...
@@ -1971,14 +1914,6 @@ class IntegersRV(RandomVariable):
low
,
high
=
0
,
low
low
,
high
=
0
,
low
return
super
()
.
__call__
(
low
,
high
,
size
=
size
,
**
kwargs
)
return
super
()
.
__call__
(
low
,
high
,
size
=
size
,
**
kwargs
)
def
make_node
(
self
,
rng
,
*
args
,
**
kwargs
):
if
not
isinstance
(
getattr
(
rng
,
"type"
,
None
),
RandomGeneratorType
|
RandomGeneratorSharedVariable
,
):
raise
TypeError
(
"`integers` is only available for `RandomGeneratorType`s"
)
return
super
()
.
make_node
(
rng
,
*
args
,
**
kwargs
)
integers
=
IntegersRV
()
integers
=
IntegersRV
()
...
@@ -2201,7 +2136,6 @@ __all__ = [
...
@@ -2201,7 +2136,6 @@ __all__ = [
"permutation"
,
"permutation"
,
"choice"
,
"choice"
,
"integers"
,
"integers"
,
"randint"
,
"categorical"
,
"categorical"
,
"multinomial"
,
"multinomial"
,
"betabinom"
,
"betabinom"
,
...
...
pytensor/tensor/random/op.py
浏览文件 @
5d4b0c4b
...
@@ -20,7 +20,7 @@ from pytensor.tensor.basic import (
...
@@ -20,7 +20,7 @@ from pytensor.tensor.basic import (
infer_static_shape
,
infer_static_shape
,
)
)
from
pytensor.tensor.blockwise
import
OpWithCoreShape
from
pytensor.tensor.blockwise
import
OpWithCoreShape
from
pytensor.tensor.random.type
import
RandomGeneratorType
,
Random
StateType
,
Random
Type
from
pytensor.tensor.random.type
import
RandomGeneratorType
,
RandomType
from
pytensor.tensor.random.utils
import
(
from
pytensor.tensor.random.utils
import
(
compute_batch_shape
,
compute_batch_shape
,
explicit_expand_dims
,
explicit_expand_dims
,
...
@@ -324,9 +324,8 @@ class RandomVariable(Op):
...
@@ -324,9 +324,8 @@ class RandomVariable(Op):
Parameters
Parameters
----------
----------
rng: RandomGeneratorType or RandomStateType
rng: RandomGeneratorType
Existing PyTensor `Generator` or `RandomState` object to be used. Creates a
Existing PyTensor `Generator` object to be used. Creates a new one, if `None`.
new one, if `None`.
size: int or Sequence
size: int or Sequence
NumPy-like size parameter.
NumPy-like size parameter.
dtype: str
dtype: str
...
@@ -354,7 +353,7 @@ class RandomVariable(Op):
...
@@ -354,7 +353,7 @@ class RandomVariable(Op):
rng
=
pytensor
.
shared
(
np
.
random
.
default_rng
())
rng
=
pytensor
.
shared
(
np
.
random
.
default_rng
())
elif
not
isinstance
(
rng
.
type
,
RandomType
):
elif
not
isinstance
(
rng
.
type
,
RandomType
):
raise
TypeError
(
raise
TypeError
(
"The type of rng should be an instance of
either RandomGeneratorType or RandomStateType
"
"The type of rng should be an instance of
RandomGeneratorType
"
)
)
inferred_shape
=
self
.
_infer_shape
(
size
,
dist_params
)
inferred_shape
=
self
.
_infer_shape
(
size
,
dist_params
)
...
@@ -436,14 +435,6 @@ class AbstractRNGConstructor(Op):
...
@@ -436,14 +435,6 @@ class AbstractRNGConstructor(Op):
output_storage
[
0
][
0
]
=
getattr
(
np
.
random
,
self
.
random_constructor
)(
seed
=
seed
)
output_storage
[
0
][
0
]
=
getattr
(
np
.
random
,
self
.
random_constructor
)(
seed
=
seed
)
class
RandomStateConstructor
(
AbstractRNGConstructor
):
random_type
=
RandomStateType
()
random_constructor
=
"RandomState"
RandomState
=
RandomStateConstructor
()
class
DefaultGeneratorMakerOp
(
AbstractRNGConstructor
):
class
DefaultGeneratorMakerOp
(
AbstractRNGConstructor
):
random_type
=
RandomGeneratorType
()
random_type
=
RandomGeneratorType
()
random_constructor
=
"default_rng"
random_constructor
=
"default_rng"
...
...
pytensor/tensor/random/type.py
浏览文件 @
5d4b0c4b
...
@@ -31,97 +31,6 @@ class RandomType(Type[T]):
...
@@ -31,97 +31,6 @@ class RandomType(Type[T]):
return
a
.
_bit_generator
is
b
.
_bit_generator
# type: ignore[attr-defined]
return
a
.
_bit_generator
is
b
.
_bit_generator
# type: ignore[attr-defined]
class
RandomStateType
(
RandomType
[
np
.
random
.
RandomState
]):
r"""A Type wrapper for `numpy.random.RandomState`.
The reason this exists (and `Generic` doesn't suffice) is that
`RandomState` objects that would appear to be equal do not compare equal
with the ``==`` operator.
This `Type` also works with a ``dict`` derived from
`RandomState.get_state(legacy=False)`, unless the ``strict`` argument to `Type.filter`
is explicitly set to ``True``.
"""
def
__repr__
(
self
):
return
"RandomStateType"
def
filter
(
self
,
data
,
strict
:
bool
=
False
,
allow_downcast
=
None
):
"""
XXX: This doesn't convert `data` to the same type of underlying RNG type
as `self`. It really only checks that `data` is of the appropriate type
to be a valid `RandomStateType`.
In other words, it serves as a `Type.is_valid_value` implementation,
but, because the default `Type.is_valid_value` depends on
`Type.filter`, we need to have it here to avoid surprising circular
dependencies in sub-classes.
"""
if
isinstance
(
data
,
np
.
random
.
RandomState
):
return
data
if
not
strict
and
isinstance
(
data
,
dict
):
gen_keys
=
[
"bit_generator"
,
"gauss"
,
"has_gauss"
,
"state"
]
state_keys
=
[
"key"
,
"pos"
]
for
key
in
gen_keys
:
if
key
not
in
data
:
raise
TypeError
()
for
key
in
state_keys
:
if
key
not
in
data
[
"state"
]:
raise
TypeError
()
state_key
=
data
[
"state"
][
"key"
]
if
state_key
.
shape
==
(
624
,)
and
state_key
.
dtype
==
np
.
uint32
:
# TODO: Add an option to convert to a `RandomState` instance?
return
data
raise
TypeError
()
@staticmethod
def
values_eq
(
a
,
b
):
sa
=
a
if
isinstance
(
a
,
dict
)
else
a
.
get_state
(
legacy
=
False
)
sb
=
b
if
isinstance
(
b
,
dict
)
else
b
.
get_state
(
legacy
=
False
)
def
_eq
(
sa
,
sb
):
for
key
in
sa
:
if
isinstance
(
sa
[
key
],
dict
):
if
not
_eq
(
sa
[
key
],
sb
[
key
]):
return
False
elif
isinstance
(
sa
[
key
],
np
.
ndarray
):
if
not
np
.
array_equal
(
sa
[
key
],
sb
[
key
]):
return
False
else
:
if
sa
[
key
]
!=
sb
[
key
]:
return
False
return
True
return
_eq
(
sa
,
sb
)
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
def
__hash__
(
self
):
return
hash
(
type
(
self
))
# Register `RandomStateType`'s C code for `ViewOp`.
pytensor
.
compile
.
register_view_op_c_code
(
RandomStateType
,
"""
Py_XDECREF(
%(oname)
s);
%(oname)
s =
%(iname)
s;
Py_XINCREF(
%(oname)
s);
"""
,
1
,
)
random_state_type
=
RandomStateType
()
class
RandomGeneratorType
(
RandomType
[
np
.
random
.
Generator
]):
class
RandomGeneratorType
(
RandomType
[
np
.
random
.
Generator
]):
r"""A Type wrapper for `numpy.random.Generator`.
r"""A Type wrapper for `numpy.random.Generator`.
...
...
pytensor/tensor/random/utils.py
浏览文件 @
5d4b0c4b
...
@@ -209,9 +209,7 @@ class RandomStream:
...
@@ -209,9 +209,7 @@ class RandomStream:
self
,
self
,
seed
:
int
|
None
=
None
,
seed
:
int
|
None
=
None
,
namespace
:
ModuleType
|
None
=
None
,
namespace
:
ModuleType
|
None
=
None
,
rng_ctor
:
Literal
[
rng_ctor
:
Literal
[
np
.
random
.
Generator
]
=
np
.
random
.
default_rng
,
np
.
random
.
RandomState
,
np
.
random
.
Generator
]
=
np
.
random
.
default_rng
,
):
):
if
namespace
is
None
:
if
namespace
is
None
:
from
pytensor.tensor.random
import
basic
# pylint: disable=import-self
from
pytensor.tensor.random
import
basic
# pylint: disable=import-self
...
@@ -223,12 +221,6 @@ class RandomStream:
...
@@ -223,12 +221,6 @@ class RandomStream:
self
.
default_instance_seed
=
seed
self
.
default_instance_seed
=
seed
self
.
state_updates
=
[]
self
.
state_updates
=
[]
self
.
gen_seedgen
=
np
.
random
.
SeedSequence
(
seed
)
self
.
gen_seedgen
=
np
.
random
.
SeedSequence
(
seed
)
if
isinstance
(
rng_ctor
,
type
)
and
issubclass
(
rng_ctor
,
np
.
random
.
RandomState
):
# The legacy state does not accept `SeedSequence`s directly
def
rng_ctor
(
seed
):
return
np
.
random
.
RandomState
(
np
.
random
.
MT19937
(
seed
))
self
.
rng_ctor
=
rng_ctor
self
.
rng_ctor
=
rng_ctor
def
__getattr__
(
self
,
obj
):
def
__getattr__
(
self
,
obj
):
...
...
pytensor/tensor/random/var.py
浏览文件 @
5d4b0c4b
...
@@ -3,17 +3,12 @@ import copy
...
@@ -3,17 +3,12 @@ import copy
import
numpy
as
np
import
numpy
as
np
from
pytensor.compile.sharedvalue
import
SharedVariable
,
shared_constructor
from
pytensor.compile.sharedvalue
import
SharedVariable
,
shared_constructor
from
pytensor.tensor.random.type
import
random_generator_type
,
random_state_type
from
pytensor.tensor.random.type
import
random_generator_type
class
RandomStateSharedVariable
(
SharedVariable
):
def
__str__
(
self
):
return
self
.
name
or
f
"RandomStateSharedVariable({self.container!r})"
class
RandomGeneratorSharedVariable
(
SharedVariable
):
class
RandomGeneratorSharedVariable
(
SharedVariable
):
def
__str__
(
self
):
def
__str__
(
self
):
return
self
.
name
or
f
"R
andomGeneratorSharedVariable
({self.container!r})"
return
self
.
name
or
f
"R
NG
({self.container!r})"
@shared_constructor.register
(
np
.
random
.
RandomState
)
@shared_constructor.register
(
np
.
random
.
RandomState
)
...
@@ -23,9 +18,10 @@ def randomgen_constructor(
...
@@ -23,9 +18,10 @@ def randomgen_constructor(
):
):
r"""`SharedVariable` constructor for NumPy's `Generator` and/or `RandomState`."""
r"""`SharedVariable` constructor for NumPy's `Generator` and/or `RandomState`."""
if
isinstance
(
value
,
np
.
random
.
RandomState
):
if
isinstance
(
value
,
np
.
random
.
RandomState
):
rng_sv_type
=
RandomStateSharedVariable
raise
TypeError
(
rng_type
=
random_state_type
"`np.RandomState` is no longer supported in PyTensor. Use `np.random.Generator` instead."
elif
isinstance
(
value
,
np
.
random
.
Generator
):
)
rng_sv_type
=
RandomGeneratorSharedVariable
rng_sv_type
=
RandomGeneratorSharedVariable
rng_type
=
random_generator_type
rng_type
=
random_generator_type
...
...
tests/link/jax/test_random.py
浏览文件 @
5d4b0c4b
...
@@ -49,7 +49,7 @@ def test_random_RandomStream():
...
@@ -49,7 +49,7 @@ def test_random_RandomStream():
assert
not
np
.
array_equal
(
jax_res_1
,
jax_res_2
)
assert
not
np
.
array_equal
(
jax_res_1
,
jax_res_2
)
@pytest.mark.parametrize
(
"rng_ctor"
,
(
np
.
random
.
RandomState
,
np
.
random
.
default_rng
))
@pytest.mark.parametrize
(
"rng_ctor"
,
(
np
.
random
.
default_rng
,
))
def
test_random_updates
(
rng_ctor
):
def
test_random_updates
(
rng_ctor
):
original_value
=
rng_ctor
(
seed
=
98
)
original_value
=
rng_ctor
(
seed
=
98
)
rng
=
shared
(
original_value
,
name
=
"original_rng"
,
borrow
=
False
)
rng
=
shared
(
original_value
,
name
=
"original_rng"
,
borrow
=
False
)
...
@@ -299,22 +299,6 @@ def test_replaced_shared_rng_storage_ordering_equality():
...
@@ -299,22 +299,6 @@ def test_replaced_shared_rng_storage_ordering_equality():
"poisson"
,
"poisson"
,
lambda
*
args
:
args
,
lambda
*
args
:
args
,
),
),
(
ptr
.
randint
,
[
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
0
,
dtype
=
np
.
int64
),
),
set_test_value
(
# high-value necessary since test on cdf
pt
.
lscalar
(),
np
.
array
(
1000
,
dtype
=
np
.
int64
),
),
],
(),
"randint"
,
lambda
*
args
:
args
,
),
(
(
ptr
.
integers
,
ptr
.
integers
,
[
[
...
@@ -489,11 +473,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
...
@@ -489,11 +473,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
The parameters passed to the op.
The parameters passed to the op.
"""
"""
if
rv_op
is
ptr
.
integers
:
rng
=
shared
(
np
.
random
.
default_rng
(
29403
))
# Integers only accepts Generator, not RandomState
rng
=
shared
(
np
.
random
.
default_rng
(
29402
))
else
:
rng
=
shared
(
np
.
random
.
RandomState
(
29402
))
g
=
rv_op
(
*
dist_params
,
size
=
(
10000
,
*
base_size
),
rng
=
rng
)
g
=
rv_op
(
*
dist_params
,
size
=
(
10000
,
*
base_size
),
rng
=
rng
)
g_fn
=
compile_random_function
(
dist_params
,
g
,
mode
=
jax_mode
)
g_fn
=
compile_random_function
(
dist_params
,
g
,
mode
=
jax_mode
)
samples
=
g_fn
(
samples
=
g_fn
(
...
@@ -545,7 +525,7 @@ def test_size_implied_by_broadcasted_parameters(rv_fn):
...
@@ -545,7 +525,7 @@ def test_size_implied_by_broadcasted_parameters(rv_fn):
@pytest.mark.parametrize
(
"size"
,
[(),
(
4
,)])
@pytest.mark.parametrize
(
"size"
,
[(),
(
4
,)])
def
test_random_bernoulli
(
size
):
def
test_random_bernoulli
(
size
):
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
g
=
pt
.
random
.
bernoulli
(
0.5
,
size
=
(
1000
,
*
size
),
rng
=
rng
)
g
=
pt
.
random
.
bernoulli
(
0.5
,
size
=
(
1000
,
*
size
),
rng
=
rng
)
g_fn
=
compile_random_function
([],
g
,
mode
=
jax_mode
)
g_fn
=
compile_random_function
([],
g
,
mode
=
jax_mode
)
samples
=
g_fn
()
samples
=
g_fn
()
...
@@ -553,7 +533,7 @@ def test_random_bernoulli(size):
...
@@ -553,7 +533,7 @@ def test_random_bernoulli(size):
def
test_random_mvnormal
():
def
test_random_mvnormal
():
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
mu
=
np
.
ones
(
4
)
mu
=
np
.
ones
(
4
)
cov
=
np
.
eye
(
4
)
cov
=
np
.
eye
(
4
)
...
@@ -571,7 +551,7 @@ def test_random_mvnormal():
...
@@ -571,7 +551,7 @@ def test_random_mvnormal():
],
],
)
)
def
test_random_dirichlet
(
parameter
,
size
):
def
test_random_dirichlet
(
parameter
,
size
):
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
g
=
pt
.
random
.
dirichlet
(
parameter
,
size
=
(
1000
,
*
size
),
rng
=
rng
)
g
=
pt
.
random
.
dirichlet
(
parameter
,
size
=
(
1000
,
*
size
),
rng
=
rng
)
g_fn
=
compile_random_function
([],
g
,
mode
=
jax_mode
)
g_fn
=
compile_random_function
([],
g
,
mode
=
jax_mode
)
samples
=
g_fn
()
samples
=
g_fn
()
...
@@ -598,7 +578,7 @@ def test_random_choice():
...
@@ -598,7 +578,7 @@ def test_random_choice():
assert
np
.
all
(
samples
%
2
==
1
)
assert
np
.
all
(
samples
%
2
==
1
)
# `replace=False` and `p is None`
# `replace=False` and `p is None`
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
g
=
pt
.
random
.
choice
(
np
.
arange
(
100
),
replace
=
False
,
size
=
(
2
,
49
),
rng
=
rng
)
g
=
pt
.
random
.
choice
(
np
.
arange
(
100
),
replace
=
False
,
size
=
(
2
,
49
),
rng
=
rng
)
g_fn
=
compile_random_function
([],
g
,
mode
=
jax_mode
)
g_fn
=
compile_random_function
([],
g
,
mode
=
jax_mode
)
samples
=
g_fn
()
samples
=
g_fn
()
...
@@ -607,7 +587,7 @@ def test_random_choice():
...
@@ -607,7 +587,7 @@ def test_random_choice():
assert
len
(
np
.
unique
(
samples
))
==
98
assert
len
(
np
.
unique
(
samples
))
==
98
# `replace=False` and `p is not None`
# `replace=False` and `p is not None`
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
g
=
pt
.
random
.
choice
(
g
=
pt
.
random
.
choice
(
8
,
8
,
p
=
np
.
array
([
0.25
,
0
,
0.25
,
0
,
0.25
,
0
,
0.25
,
0
]),
p
=
np
.
array
([
0.25
,
0
,
0.25
,
0
,
0.25
,
0
,
0.25
,
0
]),
...
@@ -625,7 +605,7 @@ def test_random_choice():
...
@@ -625,7 +605,7 @@ def test_random_choice():
def
test_random_categorical
():
def
test_random_categorical
():
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
g
=
pt
.
random
.
categorical
(
0.25
*
np
.
ones
(
4
),
size
=
(
10000
,
4
),
rng
=
rng
)
g
=
pt
.
random
.
categorical
(
0.25
*
np
.
ones
(
4
),
size
=
(
10000
,
4
),
rng
=
rng
)
g_fn
=
compile_random_function
([],
g
,
mode
=
jax_mode
)
g_fn
=
compile_random_function
([],
g
,
mode
=
jax_mode
)
samples
=
g_fn
()
samples
=
g_fn
()
...
@@ -642,7 +622,7 @@ def test_random_categorical():
...
@@ -642,7 +622,7 @@ def test_random_categorical():
def
test_random_permutation
():
def
test_random_permutation
():
array
=
np
.
arange
(
4
)
array
=
np
.
arange
(
4
)
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
g
=
pt
.
random
.
permutation
(
array
,
rng
=
rng
)
g
=
pt
.
random
.
permutation
(
array
,
rng
=
rng
)
g_fn
=
compile_random_function
([],
g
,
mode
=
jax_mode
)
g_fn
=
compile_random_function
([],
g
,
mode
=
jax_mode
)
permuted
=
g_fn
()
permuted
=
g_fn
()
...
@@ -664,7 +644,7 @@ def test_unnatural_batched_dims(batch_dims_tester):
...
@@ -664,7 +644,7 @@ def test_unnatural_batched_dims(batch_dims_tester):
def
test_random_geometric
():
def
test_random_geometric
():
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
p
=
np
.
array
([
0.3
,
0.7
])
p
=
np
.
array
([
0.3
,
0.7
])
g
=
pt
.
random
.
geometric
(
p
,
size
=
(
10
_000
,
2
),
rng
=
rng
)
g
=
pt
.
random
.
geometric
(
p
,
size
=
(
10
_000
,
2
),
rng
=
rng
)
g_fn
=
compile_random_function
([],
g
,
mode
=
jax_mode
)
g_fn
=
compile_random_function
([],
g
,
mode
=
jax_mode
)
...
@@ -674,7 +654,7 @@ def test_random_geometric():
...
@@ -674,7 +654,7 @@ def test_random_geometric():
def
test_negative_binomial
():
def
test_negative_binomial
():
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
n
=
np
.
array
([
10
,
40
])
n
=
np
.
array
([
10
,
40
])
p
=
np
.
array
([
0.3
,
0.7
])
p
=
np
.
array
([
0.3
,
0.7
])
g
=
pt
.
random
.
negative_binomial
(
n
,
p
,
size
=
(
10
_000
,
2
),
rng
=
rng
)
g
=
pt
.
random
.
negative_binomial
(
n
,
p
,
size
=
(
10
_000
,
2
),
rng
=
rng
)
...
@@ -688,7 +668,7 @@ def test_negative_binomial():
...
@@ -688,7 +668,7 @@ def test_negative_binomial():
@pytest.mark.skipif
(
not
numpyro_available
,
reason
=
"Binomial dispatch requires numpyro"
)
@pytest.mark.skipif
(
not
numpyro_available
,
reason
=
"Binomial dispatch requires numpyro"
)
def
test_binomial
():
def
test_binomial
():
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
n
=
np
.
array
([
10
,
40
])
n
=
np
.
array
([
10
,
40
])
p
=
np
.
array
([
0.3
,
0.7
])
p
=
np
.
array
([
0.3
,
0.7
])
g
=
pt
.
random
.
binomial
(
n
,
p
,
size
=
(
10
_000
,
2
),
rng
=
rng
)
g
=
pt
.
random
.
binomial
(
n
,
p
,
size
=
(
10
_000
,
2
),
rng
=
rng
)
...
@@ -702,7 +682,7 @@ def test_binomial():
...
@@ -702,7 +682,7 @@ def test_binomial():
not
numpyro_available
,
reason
=
"BetaBinomial dispatch requires numpyro"
not
numpyro_available
,
reason
=
"BetaBinomial dispatch requires numpyro"
)
)
def
test_beta_binomial
():
def
test_beta_binomial
():
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
n
=
np
.
array
([
10
,
40
])
n
=
np
.
array
([
10
,
40
])
a
=
np
.
array
([
1.5
,
13
])
a
=
np
.
array
([
1.5
,
13
])
b
=
np
.
array
([
0.5
,
9
])
b
=
np
.
array
([
0.5
,
9
])
...
@@ -721,7 +701,7 @@ def test_beta_binomial():
...
@@ -721,7 +701,7 @@ def test_beta_binomial():
not
numpyro_available
,
reason
=
"Multinomial dispatch requires numpyro"
not
numpyro_available
,
reason
=
"Multinomial dispatch requires numpyro"
)
)
def
test_multinomial
():
def
test_multinomial
():
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
n
=
np
.
array
([
10
,
40
])
n
=
np
.
array
([
10
,
40
])
p
=
np
.
array
([[
0.3
,
0.7
,
0.0
],
[
0.1
,
0.4
,
0.5
]])
p
=
np
.
array
([[
0.3
,
0.7
,
0.0
],
[
0.1
,
0.4
,
0.5
]])
g
=
pt
.
random
.
multinomial
(
n
,
p
,
size
=
(
10
_000
,
2
),
rng
=
rng
)
g
=
pt
.
random
.
multinomial
(
n
,
p
,
size
=
(
10
_000
,
2
),
rng
=
rng
)
...
@@ -737,7 +717,7 @@ def test_multinomial():
...
@@ -737,7 +717,7 @@ def test_multinomial():
def
test_vonmises_mu_outside_circle
():
def
test_vonmises_mu_outside_circle
():
# Scipy implementation does not behave as PyTensor/NumPy for mu outside the unit circle
# Scipy implementation does not behave as PyTensor/NumPy for mu outside the unit circle
# We test that the random draws from the JAX dispatch work as expected in these cases
# We test that the random draws from the JAX dispatch work as expected in these cases
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
mu
=
np
.
array
([
-
30
,
40
])
mu
=
np
.
array
([
-
30
,
40
])
kappa
=
np
.
array
([
100
,
10
])
kappa
=
np
.
array
([
100
,
10
])
g
=
pt
.
random
.
vonmises
(
mu
,
kappa
,
size
=
(
10
_000
,
2
),
rng
=
rng
)
g
=
pt
.
random
.
vonmises
(
mu
,
kappa
,
size
=
(
10
_000
,
2
),
rng
=
rng
)
...
@@ -781,7 +761,7 @@ def test_random_unimplemented():
...
@@ -781,7 +761,7 @@ def test_random_unimplemented():
return
0
return
0
nonexistentrv
=
NonExistentRV
()
nonexistentrv
=
NonExistentRV
()
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
out
=
nonexistentrv
(
rng
=
rng
)
out
=
nonexistentrv
(
rng
=
rng
)
fgraph
=
FunctionGraph
([
out
.
owner
.
inputs
[
0
]],
[
out
],
clone
=
False
)
fgraph
=
FunctionGraph
([
out
.
owner
.
inputs
[
0
]],
[
out
],
clone
=
False
)
...
@@ -816,7 +796,7 @@ def test_random_custom_implementation():
...
@@ -816,7 +796,7 @@ def test_random_custom_implementation():
return
sample_fn
return
sample_fn
nonexistentrv
=
CustomRV
()
nonexistentrv
=
CustomRV
()
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
out
=
nonexistentrv
(
rng
=
rng
)
out
=
nonexistentrv
(
rng
=
rng
)
fgraph
=
FunctionGraph
([
out
.
owner
.
inputs
[
0
]],
[
out
],
clone
=
False
)
fgraph
=
FunctionGraph
([
out
.
owner
.
inputs
[
0
]],
[
out
],
clone
=
False
)
with
pytest
.
warns
(
with
pytest
.
warns
(
...
@@ -836,7 +816,7 @@ def test_random_concrete_shape():
...
@@ -836,7 +816,7 @@ def test_random_concrete_shape():
`size` parameter satisfies either of these criteria.
`size` parameter satisfies either of these criteria.
"""
"""
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
x_pt
=
pt
.
dmatrix
()
x_pt
=
pt
.
dmatrix
()
out
=
pt
.
random
.
normal
(
0
,
1
,
size
=
x_pt
.
shape
,
rng
=
rng
)
out
=
pt
.
random
.
normal
(
0
,
1
,
size
=
x_pt
.
shape
,
rng
=
rng
)
jax_fn
=
compile_random_function
([
x_pt
],
out
,
mode
=
jax_mode
)
jax_fn
=
compile_random_function
([
x_pt
],
out
,
mode
=
jax_mode
)
...
@@ -844,7 +824,7 @@ def test_random_concrete_shape():
...
@@ -844,7 +824,7 @@ def test_random_concrete_shape():
def
test_random_concrete_shape_from_param
():
def
test_random_concrete_shape_from_param
():
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
x_pt
=
pt
.
dmatrix
()
x_pt
=
pt
.
dmatrix
()
out
=
pt
.
random
.
normal
(
x_pt
,
1
,
rng
=
rng
)
out
=
pt
.
random
.
normal
(
x_pt
,
1
,
rng
=
rng
)
jax_fn
=
compile_random_function
([
x_pt
],
out
,
mode
=
jax_mode
)
jax_fn
=
compile_random_function
([
x_pt
],
out
,
mode
=
jax_mode
)
...
@@ -863,7 +843,7 @@ def test_random_concrete_shape_subtensor():
...
@@ -863,7 +843,7 @@ def test_random_concrete_shape_subtensor():
slight improvement over their API.
slight improvement over their API.
"""
"""
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
x_pt
=
pt
.
dmatrix
()
x_pt
=
pt
.
dmatrix
()
out
=
pt
.
random
.
normal
(
0
,
1
,
size
=
x_pt
.
shape
[
1
],
rng
=
rng
)
out
=
pt
.
random
.
normal
(
0
,
1
,
size
=
x_pt
.
shape
[
1
],
rng
=
rng
)
jax_fn
=
compile_random_function
([
x_pt
],
out
,
mode
=
jax_mode
)
jax_fn
=
compile_random_function
([
x_pt
],
out
,
mode
=
jax_mode
)
...
@@ -879,7 +859,7 @@ def test_random_concrete_shape_subtensor_tuple():
...
@@ -879,7 +859,7 @@ def test_random_concrete_shape_subtensor_tuple():
`jax_size_parameter_as_tuple` rewrite.
`jax_size_parameter_as_tuple` rewrite.
"""
"""
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
x_pt
=
pt
.
dmatrix
()
x_pt
=
pt
.
dmatrix
()
out
=
pt
.
random
.
normal
(
0
,
1
,
size
=
(
x_pt
.
shape
[
0
],),
rng
=
rng
)
out
=
pt
.
random
.
normal
(
0
,
1
,
size
=
(
x_pt
.
shape
[
0
],),
rng
=
rng
)
jax_fn
=
compile_random_function
([
x_pt
],
out
,
mode
=
jax_mode
)
jax_fn
=
compile_random_function
([
x_pt
],
out
,
mode
=
jax_mode
)
...
@@ -890,7 +870,7 @@ def test_random_concrete_shape_subtensor_tuple():
...
@@ -890,7 +870,7 @@ def test_random_concrete_shape_subtensor_tuple():
reason
=
"`size_pt` should be specified as a static argument"
,
strict
=
True
reason
=
"`size_pt` should be specified as a static argument"
,
strict
=
True
)
)
def
test_random_concrete_shape_graph_input
():
def
test_random_concrete_shape_graph_input
():
rng
=
shared
(
np
.
random
.
RandomState
(
123
))
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
size_pt
=
pt
.
scalar
()
size_pt
=
pt
.
scalar
()
out
=
pt
.
random
.
normal
(
0
,
1
,
size
=
size_pt
,
rng
=
rng
)
out
=
pt
.
random
.
normal
(
0
,
1
,
size
=
size_pt
,
rng
=
rng
)
jax_fn
=
compile_random_function
([
size_pt
],
out
,
mode
=
jax_mode
)
jax_fn
=
compile_random_function
([
size_pt
],
out
,
mode
=
jax_mode
)
...
...
tests/scan/test_basic.py
浏览文件 @
5d4b0c4b
...
@@ -244,10 +244,7 @@ def scan_nodes_from_fct(fct):
...
@@ -244,10 +244,7 @@ def scan_nodes_from_fct(fct):
class
TestScan
:
class
TestScan
:
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"rng_type"
,
"rng_type"
,
[
[
np
.
random
.
default_rng
],
np
.
random
.
default_rng
,
np
.
random
.
RandomState
,
],
)
)
def
test_inner_graph_cloning
(
self
,
rng_type
):
def
test_inner_graph_cloning
(
self
,
rng_type
):
r"""Scan should remove the updates-providing special properties on `RandomType`\s."""
r"""Scan should remove the updates-providing special properties on `RandomType`\s."""
...
...
tests/tensor/random/test_basic.py
浏览文件 @
5d4b0c4b
...
@@ -51,7 +51,6 @@ from pytensor.tensor.random.basic import (
...
@@ -51,7 +51,6 @@ from pytensor.tensor.random.basic import (
pareto
,
pareto
,
permutation
,
permutation
,
poisson
,
poisson
,
randint
,
rayleigh
,
rayleigh
,
standard_normal
,
standard_normal
,
t
,
t
,
...
@@ -1355,27 +1354,6 @@ def test_categorical_basic():
...
@@ -1355,27 +1354,6 @@ def test_categorical_basic():
categorical
.
rng_fn
(
rng
,
p
[
None
],
size
=
(
3
,))
categorical
.
rng_fn
(
rng
,
p
[
None
],
size
=
(
3
,))
def
test_randint_samples
():
with
pytest
.
raises
(
TypeError
):
randint
(
10
,
rng
=
shared
(
np
.
random
.
default_rng
()))
rng
=
np
.
random
.
RandomState
(
2313
)
compare_sample_values
(
randint
,
10
,
None
,
rng
=
rng
)
compare_sample_values
(
randint
,
0
,
1
,
rng
=
rng
)
compare_sample_values
(
randint
,
0
,
1
,
size
=
[
3
],
rng
=
rng
)
compare_sample_values
(
randint
,
[
0
,
1
,
2
],
5
,
rng
=
rng
)
compare_sample_values
(
randint
,
[
0
,
1
,
2
],
5
,
size
=
[
3
,
3
],
rng
=
rng
)
compare_sample_values
(
randint
,
[
0
],
[
5
],
size
=
[
1
],
rng
=
rng
)
compare_sample_values
(
randint
,
pt
.
as_tensor_variable
([
-
1
]),
[
1
],
size
=
[
1
],
rng
=
rng
)
compare_sample_values
(
randint
,
pt
.
as_tensor_variable
([
-
1
]),
[
1
],
size
=
pt
.
as_tensor_variable
([
1
]),
rng
=
rng
,
)
def
test_integers_samples
():
def
test_integers_samples
():
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
):
integers
(
10
,
rng
=
shared
(
np
.
random
.
RandomState
()))
integers
(
10
,
rng
=
shared
(
np
.
random
.
RandomState
()))
...
...
tests/tensor/random/test_op.py
浏览文件 @
5d4b0c4b
...
@@ -8,7 +8,7 @@ from pytensor.raise_op import Assert
...
@@ -8,7 +8,7 @@ from pytensor.raise_op import Assert
from
pytensor.tensor.math
import
eq
from
pytensor.tensor.math
import
eq
from
pytensor.tensor.random
import
normal
from
pytensor.tensor.random
import
normal
from
pytensor.tensor.random.basic
import
NormalRV
from
pytensor.tensor.random.basic
import
NormalRV
from
pytensor.tensor.random.op
import
Random
State
,
Random
Variable
,
default_rng
from
pytensor.tensor.random.op
import
RandomVariable
,
default_rng
from
pytensor.tensor.shape
import
specify_shape
from
pytensor.tensor.shape
import
specify_shape
from
pytensor.tensor.type
import
iscalar
,
tensor
from
pytensor.tensor.type
import
iscalar
,
tensor
...
@@ -159,7 +159,6 @@ def test_RandomVariable_floatX(strict_test_value_flags):
...
@@ -159,7 +159,6 @@ def test_RandomVariable_floatX(strict_test_value_flags):
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"seed, maker_op, numpy_res"
,
"seed, maker_op, numpy_res"
,
[
[
(
3
,
RandomState
,
np
.
random
.
RandomState
(
3
)),
(
3
,
default_rng
,
np
.
random
.
default_rng
(
3
)),
(
3
,
default_rng
,
np
.
random
.
default_rng
(
3
)),
],
],
)
)
...
@@ -174,10 +173,6 @@ def test_random_maker_ops_no_seed(strict_test_value_flags):
...
@@ -174,10 +173,6 @@ def test_random_maker_ops_no_seed(strict_test_value_flags):
# Testing the initialization when seed=None
# Testing the initialization when seed=None
# Since internal states randomly generated,
# Since internal states randomly generated,
# we just check the output classes
# we just check the output classes
z
=
function
(
inputs
=
[],
outputs
=
[
RandomState
()])()
aes_res
=
z
[
0
]
assert
isinstance
(
aes_res
,
np
.
random
.
RandomState
)
z
=
function
(
inputs
=
[],
outputs
=
[
default_rng
()])()
z
=
function
(
inputs
=
[],
outputs
=
[
default_rng
()])()
aes_res
=
z
[
0
]
aes_res
=
z
[
0
]
assert
isinstance
(
aes_res
,
np
.
random
.
Generator
)
assert
isinstance
(
aes_res
,
np
.
random
.
Generator
)
...
...
tests/tensor/random/test_type.py
浏览文件 @
5d4b0c4b
...
@@ -7,9 +7,7 @@ from pytensor import shared
...
@@ -7,9 +7,7 @@ from pytensor import shared
from
pytensor.compile.ops
import
ViewOp
from
pytensor.compile.ops
import
ViewOp
from
pytensor.tensor.random.type
import
(
from
pytensor.tensor.random.type
import
(
RandomGeneratorType
,
RandomGeneratorType
,
RandomStateType
,
random_generator_type
,
random_generator_type
,
random_state_type
,
)
)
...
@@ -28,101 +26,9 @@ def test_view_op_c_code():
...
@@ -28,101 +26,9 @@ def test_view_op_c_code():
# rng_view,
# rng_view,
# mode=Mode(optimizer=None, linker=CLinker()),
# mode=Mode(optimizer=None, linker=CLinker()),
# )
# )
assert
ViewOp
.
c_code_and_version
[
RandomStateType
]
assert
ViewOp
.
c_code_and_version
[
RandomGeneratorType
]
assert
ViewOp
.
c_code_and_version
[
RandomGeneratorType
]
class
TestRandomStateType
:
def
test_pickle
(
self
):
rng_r
=
random_state_type
()
rng_pkl
=
pickle
.
dumps
(
rng_r
)
rng_unpkl
=
pickle
.
loads
(
rng_pkl
)
assert
rng_r
!=
rng_unpkl
assert
rng_r
.
type
==
rng_unpkl
.
type
assert
hash
(
rng_r
.
type
)
==
hash
(
rng_unpkl
.
type
)
def
test_repr
(
self
):
assert
repr
(
random_state_type
)
==
"RandomStateType"
def
test_filter
(
self
):
rng_type
=
random_state_type
rng
=
np
.
random
.
RandomState
()
assert
rng_type
.
filter
(
rng
)
is
rng
with
pytest
.
raises
(
TypeError
):
rng_type
.
filter
(
1
)
rng_dict
=
rng
.
get_state
(
legacy
=
False
)
assert
rng_type
.
is_valid_value
(
rng_dict
)
is
False
assert
rng_type
.
is_valid_value
(
rng_dict
,
strict
=
False
)
rng_dict
[
"state"
]
=
{}
assert
rng_type
.
is_valid_value
(
rng_dict
,
strict
=
False
)
is
False
rng_dict
=
{}
assert
rng_type
.
is_valid_value
(
rng_dict
,
strict
=
False
)
is
False
def
test_values_eq
(
self
):
rng_type
=
random_state_type
rng_a
=
np
.
random
.
RandomState
(
12
)
rng_b
=
np
.
random
.
RandomState
(
12
)
rng_c
=
np
.
random
.
RandomState
(
123
)
bg
=
np
.
random
.
PCG64
()
rng_d
=
np
.
random
.
RandomState
(
bg
)
rng_e
=
np
.
random
.
RandomState
(
bg
)
bg_2
=
np
.
random
.
Philox
()
rng_f
=
np
.
random
.
RandomState
(
bg_2
)
rng_g
=
np
.
random
.
RandomState
(
bg_2
)
assert
rng_type
.
values_eq
(
rng_a
,
rng_b
)
assert
not
rng_type
.
values_eq
(
rng_a
,
rng_c
)
assert
not
rng_type
.
values_eq
(
rng_a
,
rng_d
)
assert
not
rng_type
.
values_eq
(
rng_d
,
rng_a
)
assert
not
rng_type
.
values_eq
(
rng_a
,
rng_d
)
assert
rng_type
.
values_eq
(
rng_d
,
rng_e
)
assert
rng_type
.
values_eq
(
rng_f
,
rng_g
)
assert
not
rng_type
.
values_eq
(
rng_g
,
rng_a
)
assert
not
rng_type
.
values_eq
(
rng_e
,
rng_g
)
def
test_may_share_memory
(
self
):
bg1
=
np
.
random
.
MT19937
()
bg2
=
np
.
random
.
MT19937
()
rng_a
=
np
.
random
.
RandomState
(
bg1
)
rng_b
=
np
.
random
.
RandomState
(
bg2
)
rng_var_a
=
shared
(
rng_a
,
borrow
=
True
)
rng_var_b
=
shared
(
rng_b
,
borrow
=
True
)
assert
(
random_state_type
.
may_share_memory
(
rng_var_a
.
get_value
(
borrow
=
True
),
rng_var_b
.
get_value
(
borrow
=
True
)
)
is
False
)
rng_c
=
np
.
random
.
RandomState
(
bg2
)
rng_var_c
=
shared
(
rng_c
,
borrow
=
True
)
assert
(
random_state_type
.
may_share_memory
(
rng_var_b
.
get_value
(
borrow
=
True
),
rng_var_c
.
get_value
(
borrow
=
True
)
)
is
True
)
class
TestRandomGeneratorType
:
class
TestRandomGeneratorType
:
def
test_pickle
(
self
):
def
test_pickle
(
self
):
rng_r
=
random_generator_type
()
rng_r
=
random_generator_type
()
...
@@ -200,7 +106,7 @@ class TestRandomGeneratorType:
...
@@ -200,7 +106,7 @@ class TestRandomGeneratorType:
rng_var_b
=
shared
(
rng_b
,
borrow
=
True
)
rng_var_b
=
shared
(
rng_b
,
borrow
=
True
)
assert
(
assert
(
random_
state
_type
.
may_share_memory
(
random_
generator
_type
.
may_share_memory
(
rng_var_a
.
get_value
(
borrow
=
True
),
rng_var_b
.
get_value
(
borrow
=
True
)
rng_var_a
.
get_value
(
borrow
=
True
),
rng_var_b
.
get_value
(
borrow
=
True
)
)
)
is
False
is
False
...
@@ -210,7 +116,7 @@ class TestRandomGeneratorType:
...
@@ -210,7 +116,7 @@ class TestRandomGeneratorType:
rng_var_c
=
shared
(
rng_c
,
borrow
=
True
)
rng_var_c
=
shared
(
rng_c
,
borrow
=
True
)
assert
(
assert
(
random_
state
_type
.
may_share_memory
(
random_
generator
_type
.
may_share_memory
(
rng_var_b
.
get_value
(
borrow
=
True
),
rng_var_c
.
get_value
(
borrow
=
True
)
rng_var_b
.
get_value
(
borrow
=
True
),
rng_var_c
.
get_value
(
borrow
=
True
)
)
)
is
True
is
True
...
...
tests/tensor/random/test_utils.py
浏览文件 @
5d4b0c4b
...
@@ -101,7 +101,7 @@ class TestSharedRandomStream:
...
@@ -101,7 +101,7 @@ class TestSharedRandomStream:
assert
np
.
all
(
g
()
==
g
())
assert
np
.
all
(
g
()
==
g
())
assert
np
.
all
(
abs
(
nearly_zeros
())
<
1e-5
)
assert
np
.
all
(
abs
(
nearly_zeros
())
<
1e-5
)
@pytest.mark.parametrize
(
"rng_ctor"
,
[
np
.
random
.
RandomState
,
np
.
random
.
default_rng
])
@pytest.mark.parametrize
(
"rng_ctor"
,
[
np
.
random
.
default_rng
])
def
test_basics
(
self
,
rng_ctor
):
def
test_basics
(
self
,
rng_ctor
):
random
=
RandomStream
(
seed
=
utt
.
fetch_seed
(),
rng_ctor
=
rng_ctor
)
random
=
RandomStream
(
seed
=
utt
.
fetch_seed
(),
rng_ctor
=
rng_ctor
)
...
@@ -132,7 +132,7 @@ class TestSharedRandomStream:
...
@@ -132,7 +132,7 @@ class TestSharedRandomStream:
assert
np
.
allclose
(
fn_val0
,
numpy_val0
)
assert
np
.
allclose
(
fn_val0
,
numpy_val0
)
assert
np
.
allclose
(
fn_val1
,
numpy_val1
)
assert
np
.
allclose
(
fn_val1
,
numpy_val1
)
@pytest.mark.parametrize
(
"rng_ctor"
,
[
np
.
random
.
RandomState
,
np
.
random
.
default_rng
])
@pytest.mark.parametrize
(
"rng_ctor"
,
[
np
.
random
.
default_rng
])
def
test_seed
(
self
,
rng_ctor
):
def
test_seed
(
self
,
rng_ctor
):
init_seed
=
234
init_seed
=
234
random
=
RandomStream
(
init_seed
,
rng_ctor
=
rng_ctor
)
random
=
RandomStream
(
init_seed
,
rng_ctor
=
rng_ctor
)
...
@@ -176,7 +176,7 @@ class TestSharedRandomStream:
...
@@ -176,7 +176,7 @@ class TestSharedRandomStream:
assert
random_state
[
"bit_generator"
]
==
ref_state
[
"bit_generator"
]
assert
random_state
[
"bit_generator"
]
==
ref_state
[
"bit_generator"
]
assert
random_state
[
"state"
]
==
ref_state
[
"state"
]
assert
random_state
[
"state"
]
==
ref_state
[
"state"
]
@pytest.mark.parametrize
(
"rng_ctor"
,
[
np
.
random
.
RandomState
,
np
.
random
.
default_rng
])
@pytest.mark.parametrize
(
"rng_ctor"
,
[
np
.
random
.
default_rng
])
def
test_uniform
(
self
,
rng_ctor
):
def
test_uniform
(
self
,
rng_ctor
):
# Test that RandomStream.uniform generates the same results as numpy
# Test that RandomStream.uniform generates the same results as numpy
# Check over two calls to see if the random state is correctly updated.
# Check over two calls to see if the random state is correctly updated.
...
@@ -195,7 +195,7 @@ class TestSharedRandomStream:
...
@@ -195,7 +195,7 @@ class TestSharedRandomStream:
assert
np
.
allclose
(
fn_val0
,
numpy_val0
)
assert
np
.
allclose
(
fn_val0
,
numpy_val0
)
assert
np
.
allclose
(
fn_val1
,
numpy_val1
)
assert
np
.
allclose
(
fn_val1
,
numpy_val1
)
@pytest.mark.parametrize
(
"rng_ctor"
,
[
np
.
random
.
RandomState
,
np
.
random
.
default_rng
])
@pytest.mark.parametrize
(
"rng_ctor"
,
[
np
.
random
.
default_rng
])
def
test_default_updates
(
self
,
rng_ctor
):
def
test_default_updates
(
self
,
rng_ctor
):
# Basic case: default_updates
# Basic case: default_updates
random_a
=
RandomStream
(
utt
.
fetch_seed
(),
rng_ctor
=
rng_ctor
)
random_a
=
RandomStream
(
utt
.
fetch_seed
(),
rng_ctor
=
rng_ctor
)
...
@@ -244,7 +244,7 @@ class TestSharedRandomStream:
...
@@ -244,7 +244,7 @@ class TestSharedRandomStream:
assert
np
.
all
(
fn_e_val0
==
fn_a_val0
)
assert
np
.
all
(
fn_e_val0
==
fn_a_val0
)
assert
np
.
all
(
fn_e_val1
==
fn_e_val0
)
assert
np
.
all
(
fn_e_val1
==
fn_e_val0
)
@pytest.mark.parametrize
(
"rng_ctor"
,
[
np
.
random
.
RandomState
,
np
.
random
.
default_rng
])
@pytest.mark.parametrize
(
"rng_ctor"
,
[
np
.
random
.
default_rng
])
def
test_multiple_rng_aliasing
(
self
,
rng_ctor
):
def
test_multiple_rng_aliasing
(
self
,
rng_ctor
):
# Test that when we have multiple random number generators, we do not alias
# Test that when we have multiple random number generators, we do not alias
# the state_updates member. `state_updates` can be useful when attempting to
# the state_updates member. `state_updates` can be useful when attempting to
...
@@ -257,7 +257,7 @@ class TestSharedRandomStream:
...
@@ -257,7 +257,7 @@ class TestSharedRandomStream:
assert
rng1
.
state_updates
is
not
rng2
.
state_updates
assert
rng1
.
state_updates
is
not
rng2
.
state_updates
assert
rng1
.
gen_seedgen
is
not
rng2
.
gen_seedgen
assert
rng1
.
gen_seedgen
is
not
rng2
.
gen_seedgen
@pytest.mark.parametrize
(
"rng_ctor"
,
[
np
.
random
.
RandomState
,
np
.
random
.
default_rng
])
@pytest.mark.parametrize
(
"rng_ctor"
,
[
np
.
random
.
default_rng
])
def
test_random_state_transfer
(
self
,
rng_ctor
):
def
test_random_state_transfer
(
self
,
rng_ctor
):
# Test that random state can be transferred from one pytensor graph to another.
# Test that random state can be transferred from one pytensor graph to another.
...
...
tests/tensor/random/test_var.py
浏览文件 @
5d4b0c4b
...
@@ -4,9 +4,7 @@ import pytest
...
@@ -4,9 +4,7 @@ import pytest
from
pytensor
import
shared
from
pytensor
import
shared
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"rng"
,
[
np
.
random
.
default_rng
(
123
)])
"rng"
,
[
np
.
random
.
RandomState
(
123
),
np
.
random
.
default_rng
(
123
)]
)
def
test_GeneratorSharedVariable
(
rng
):
def
test_GeneratorSharedVariable
(
rng
):
s_rng_default
=
shared
(
rng
)
s_rng_default
=
shared
(
rng
)
s_rng_True
=
shared
(
rng
,
borrow
=
True
)
s_rng_True
=
shared
(
rng
,
borrow
=
True
)
...
@@ -32,9 +30,7 @@ def test_GeneratorSharedVariable(rng):
...
@@ -32,9 +30,7 @@ def test_GeneratorSharedVariable(rng):
assert
v
==
v0
==
v1
assert
v
==
v0
==
v1
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"rng"
,
[
np
.
random
.
default_rng
(
123
)])
"rng"
,
[
np
.
random
.
RandomState
(
123
),
np
.
random
.
default_rng
(
123
)]
)
def
test_get_value_borrow
(
rng
):
def
test_get_value_borrow
(
rng
):
s_rng
=
shared
(
rng
)
s_rng
=
shared
(
rng
)
...
@@ -55,9 +51,7 @@ def test_get_value_borrow(rng):
...
@@ -55,9 +51,7 @@ def test_get_value_borrow(rng):
assert
r_
.
standard_normal
()
==
r_F
.
standard_normal
()
assert
r_
.
standard_normal
()
==
r_F
.
standard_normal
()
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"rng"
,
[
np
.
random
.
default_rng
(
123
)])
"rng"
,
[
np
.
random
.
RandomState
(
123
),
np
.
random
.
default_rng
(
123
)]
)
def
test_get_value_internal_type
(
rng
):
def
test_get_value_internal_type
(
rng
):
s_rng
=
shared
(
rng
)
s_rng
=
shared
(
rng
)
...
@@ -81,7 +75,7 @@ def test_get_value_internal_type(rng):
...
@@ -81,7 +75,7 @@ def test_get_value_internal_type(rng):
assert
r_
.
standard_normal
()
==
r_F
.
standard_normal
()
assert
r_
.
standard_normal
()
==
r_F
.
standard_normal
()
@pytest.mark.parametrize
(
"rng_ctor"
,
[
np
.
random
.
RandomState
,
np
.
random
.
default_rng
])
@pytest.mark.parametrize
(
"rng_ctor"
,
[
np
.
random
.
default_rng
])
def
test_set_value_borrow
(
rng_ctor
):
def
test_set_value_borrow
(
rng_ctor
):
s_rng
=
shared
(
rng_ctor
(
123
))
s_rng
=
shared
(
rng_ctor
(
123
))
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论