Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
e0eea331
提交
e0eea331
authored
6月 02, 2021
作者:
kc611
提交者:
Brandon T. Willard
6月 25, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add NumPy Generator support for RandomVariables
上级
adf83fa5
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
15 个修改的文件
包含
415 行增加
和
96 行删除
+415
-96
monitormode.py
aesara/compile/monitormode.py
+1
-1
nanguardmode.py
aesara/compile/nanguardmode.py
+1
-1
gradient.py
aesara/gradient.py
+1
-1
rng_mrg.py
aesara/sandbox/rng_mrg.py
+1
-1
basic.py
aesara/tensor/basic.py
+3
-1
basic.py
aesara/tensor/random/basic.py
+40
-2
op.py
aesara/tensor/random/op.py
+9
-7
type.py
aesara/tensor/random/type.py
+124
-20
utils.py
aesara/tensor/random/utils.py
+19
-15
var.py
aesara/tensor/random/var.py
+19
-9
setup.py
setup.py
+1
-1
test_basic.py
tests/tensor/random/test_basic.py
+43
-12
test_type.py
tests/tensor/random/test_type.py
+114
-6
test_utils.py
tests/tensor/random/test_utils.py
+0
-0
test_var.py
tests/tensor/random/test_var.py
+39
-19
没有找到文件。
aesara/compile/monitormode.py
浏览文件 @
e0eea331
...
...
@@ -105,7 +105,7 @@ def detect_nan(fgraph, i, node, fn):
for
output
in
fn
.
outputs
:
if
(
not
isinstance
(
output
[
0
],
np
.
random
.
RandomState
)
not
isinstance
(
output
[
0
],
(
np
.
random
.
RandomState
,
np
.
random
.
Generator
)
)
and
np
.
isnan
(
output
[
0
])
.
any
()
):
print
(
"*** NaN detected ***"
)
...
...
aesara/compile/nanguardmode.py
浏览文件 @
e0eea331
...
...
@@ -44,7 +44,7 @@ def _is_numeric_value(arr, var):
"""
if
isinstance
(
arr
,
aesara
.
graph
.
type
.
_cdata_type
):
return
False
elif
isinstance
(
arr
,
np
.
random
.
mtrand
.
RandomState
):
elif
isinstance
(
arr
,
(
np
.
random
.
mtrand
.
RandomState
,
np
.
random
.
Generator
)
):
return
False
elif
var
and
getattr
(
var
.
tag
,
"is_rng"
,
False
):
return
False
...
...
aesara/gradient.py
浏览文件 @
e0eea331
...
...
@@ -1841,7 +1841,7 @@ def verify_grad(
# random_projection should not have elements too small,
# otherwise too much precision is lost in numerical gradient
def
random_projection
():
plain
=
rng
.
rand
(
*
o_fn_out
.
shape
)
+
0.5
plain
=
rng
.
rand
om
(
o_fn_out
.
shape
)
+
0.5
if
cast_to_output_type
and
o_output
.
dtype
==
"float32"
:
return
np
.
array
(
plain
,
o_output
.
dtype
)
return
plain
...
...
aesara/sandbox/rng_mrg.py
浏览文件 @
e0eea331
...
...
@@ -736,7 +736,7 @@ class MRG_RandomStream:
def
set_rstate
(
self
,
seed
):
# TODO : need description for method, parameter
if
isinstance
(
seed
,
int
):
if
isinstance
(
seed
,
(
int
,
np
.
int32
,
np
.
int64
)
):
if
seed
==
0
:
raise
ValueError
(
"seed should not be 0"
,
seed
)
elif
seed
>=
M2
:
...
...
aesara/tensor/basic.py
浏览文件 @
e0eea331
...
...
@@ -72,7 +72,9 @@ def check_equal_numpy(x, y):
"""
if
isinstance
(
x
,
np
.
ndarray
)
and
isinstance
(
y
,
np
.
ndarray
):
return
x
.
dtype
==
y
.
dtype
and
x
.
shape
==
y
.
shape
and
np
.
all
(
abs
(
x
-
y
)
<
1e-10
)
elif
isinstance
(
x
,
np
.
random
.
RandomState
)
and
isinstance
(
y
,
np
.
random
.
RandomState
):
elif
isinstance
(
x
,
(
np
.
random
.
Generator
,
np
.
random
.
RandomState
))
and
isinstance
(
y
,
(
np
.
random
.
Generator
,
np
.
random
.
RandomState
)
):
return
builtins
.
all
(
np
.
all
(
a
==
b
)
for
a
,
b
in
zip
(
x
.
__getstate__
(),
y
.
__getstate__
())
)
...
...
aesara/tensor/random/basic.py
浏览文件 @
e0eea331
...
...
@@ -6,7 +6,12 @@ import scipy.stats as stats
import
aesara
from
aesara.tensor.basic
import
as_tensor_variable
from
aesara.tensor.random.op
import
RandomVariable
,
default_shape_from_params
from
aesara.tensor.random.type
import
RandomGeneratorType
,
RandomStateType
from
aesara.tensor.random.utils
import
broadcast_params
from
aesara.tensor.random.var
import
(
RandomGeneratorSharedVariable
,
RandomStateSharedVariable
,
)
try
:
...
...
@@ -165,7 +170,7 @@ class GumbelRV(RandomVariable):
@classmethod
def
rng_fn
(
cls
,
rng
:
np
.
random
.
RandomState
,
rng
:
Union
[
np
.
random
.
Generator
,
np
.
random
.
RandomState
]
,
loc
:
Union
[
np
.
ndarray
,
float
],
scale
:
Union
[
np
.
ndarray
,
float
],
size
:
Optional
[
Union
[
List
[
int
],
int
]],
...
...
@@ -590,7 +595,8 @@ class PolyaGammaRV(RandomVariable):
@classmethod
def
rng_fn
(
cls
,
rng
,
b
,
c
,
size
):
pg
=
PyPolyaGamma
(
rng
.
randint
(
2
**
16
))
rand_method
=
rng
.
integers
if
hasattr
(
rng
,
"integers"
)
else
rng
.
randint
pg
=
PyPolyaGamma
(
rand_method
(
2
**
16
))
if
not
size
and
b
.
shape
==
c
.
shape
==
():
return
pg
.
pgdraw
(
b
,
c
)
...
...
@@ -627,10 +633,41 @@ class RandIntRV(RandomVariable):
low
,
high
=
0
,
low
return
super
()
.
__call__
(
low
,
high
,
size
=
size
,
**
kwargs
)
def
make_node
(
self
,
rng
,
*
args
,
**
kwargs
):
if
not
isinstance
(
getattr
(
rng
,
"type"
,
None
),
(
RandomStateType
,
RandomStateSharedVariable
)
):
raise
TypeError
(
"`randint` is only available for `RandomStateType`s"
)
return
super
()
.
make_node
(
rng
,
*
args
,
**
kwargs
)
randint
=
RandIntRV
()
class
IntegersRV
(
RandomVariable
):
name
=
"integers"
ndim_supp
=
0
ndims_params
=
[
0
,
0
]
dtype
=
"int64"
_print_name
=
(
"integers"
,
"
\\
operatorname{integers}"
)
def
__call__
(
self
,
low
,
high
=
None
,
size
=
None
,
**
kwargs
):
if
high
is
None
:
low
,
high
=
0
,
low
return
super
()
.
__call__
(
low
,
high
,
size
=
size
,
**
kwargs
)
def
make_node
(
self
,
rng
,
*
args
,
**
kwargs
):
if
not
isinstance
(
getattr
(
rng
,
"type"
,
None
),
(
RandomGeneratorType
,
RandomGeneratorSharedVariable
),
):
raise
TypeError
(
"`integers` is only available for `RandomGeneratorType`s"
)
return
super
()
.
make_node
(
rng
,
*
args
,
**
kwargs
)
integers
=
IntegersRV
()
class
ChoiceRV
(
RandomVariable
):
name
=
"choice"
ndim_supp
=
0
...
...
@@ -698,6 +735,7 @@ permutation = PermutationRV()
__all__
=
[
"permutation"
,
"choice"
,
"integers"
,
"randint"
,
"categorical"
,
"multinomial"
,
...
...
aesara/tensor/random/op.py
浏览文件 @
e0eea331
...
...
@@ -18,7 +18,7 @@ from aesara.tensor.basic import (
)
from
aesara.tensor.elemwise
import
Elemwise
from
aesara.tensor.exceptions
import
NotScalarConstantError
from
aesara.tensor.random.type
import
Random
State
Type
from
aesara.tensor.random.type
import
RandomType
from
aesara.tensor.random.utils
import
normalize_size_param
,
params_broadcast_shapes
from
aesara.tensor.type
import
TensorType
,
all_dtypes
...
...
@@ -158,7 +158,7 @@ class RandomVariable(Op):
def
rng_fn
(
self
,
rng
,
*
args
,
**
kwargs
):
"""Sample a numeric random variate."""
return
getattr
(
np
.
random
.
RandomState
,
self
.
name
)(
rng
,
*
args
,
**
kwargs
)
return
getattr
(
rng
,
self
.
name
)(
*
args
,
**
kwargs
)
def
__str__
(
self
):
props_str
=
", "
.
join
((
f
"{getattr(self, prop)}"
for
prop
in
self
.
__props__
[
1
:]))
...
...
@@ -336,8 +336,8 @@ class RandomVariable(Op):
Parameters
----------
rng: RandomStateType
Existing Aesara `RandomState` object to be used. Creates a
rng: Random
GeneratorType or Random
StateType
Existing Aesara `
Generator` or `
RandomState` object to be used. Creates a
new one, if `None`.
size: int or Sequence
Numpy-like size of the output (i.e. replications).
...
...
@@ -363,9 +363,11 @@ class RandomVariable(Op):
)
if
rng
is
None
:
rng
=
aesara
.
shared
(
np
.
random
.
RandomState
())
elif
not
isinstance
(
rng
.
type
,
RandomStateType
):
raise
TypeError
(
"The type of rng should be an instance of RandomStateType"
)
rng
=
aesara
.
shared
(
np
.
random
.
default_rng
())
elif
not
isinstance
(
rng
.
type
,
RandomType
):
raise
TypeError
(
"The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
)
bcast
=
self
.
compute_bcast
(
dist_params
,
size
)
dtype
=
self
.
dtype
or
dtype
...
...
aesara/tensor/random/type.py
浏览文件 @
e0eea331
...
...
@@ -6,21 +6,22 @@ import aesara
from
aesara.graph.type
import
Type
class
RandomStateType
(
Type
):
"""A Type wrapper for `numpy.random.RandomState`.
gen_states_keys
=
{
"MT19937"
:
([
"state"
],
[
"key"
,
"pos"
]),
"PCG64"
:
([
"state"
,
"has_uint32"
,
"uinteger"
],
[
"state"
,
"inc"
]),
"Philox"
:
(
[
"state"
,
"buffer"
,
"buffer_pos"
,
"has_uint32"
,
"uinteger"
],
[
"counter"
,
"key"
],
),
"SFC64"
:
([
"state"
,
"has_uint32"
,
"uinteger"
],
[
"state"
]),
}
The reason this exists (and `Generic` doesn't suffice) is that
`RandomState` objects that would appear to be equal do not compare equal
with the `==` operator. This `Type` exists to provide an equals function
that is used by `DebugMode`.
Also works with a `dict` derived from RandomState.get_state() unless
the `strict` argument is explicitly set to `True`.
# We map bit generators to an integer index so that we can avoid using strings
numpy_bit_gens
=
{
0
:
"MT19937"
,
1
:
"PCG64"
,
2
:
"Philox"
,
3
:
"SFC64"
}
"""
def
__repr__
(
self
):
return
"RandomStateType
"
class
RandomType
(
Type
):
r"""A Type wrapper for `numpy.random.Generator` and `numpy.random.RandomState`.""
"
@classmethod
def
filter
(
cls
,
data
,
strict
=
False
,
allow_downcast
=
None
):
...
...
@@ -29,6 +30,31 @@ class RandomStateType(Type):
else
:
raise
TypeError
()
@staticmethod
def
get_shape_info
(
obj
):
return
obj
.
get_value
(
borrow
=
True
)
@staticmethod
def
may_share_memory
(
a
,
b
):
return
a
.
_bit_generator
is
b
.
_bit_generator
class
RandomStateType
(
RandomType
):
r"""A Type wrapper for `numpy.random.RandomState`.
The reason this exists (and `Generic` doesn't suffice) is that
`RandomState` objects that would appear to be equal do not compare equal
with the ``==`` operator.
This `Type` also works with a ``dict`` derived from
`RandomState.get_state(legacy=False)`, unless the ``strict`` argument to `Type.filter`
is explicitly set to ``True``.
"""
def
__repr__
(
self
):
return
"RandomStateType"
@staticmethod
def
is_valid_value
(
a
,
strict
):
if
isinstance
(
a
,
np
.
random
.
RandomState
):
...
...
@@ -73,18 +99,10 @@ class RandomStateType(Type):
return
_eq
(
sa
,
sb
)
@staticmethod
def
get_shape_info
(
obj
):
return
obj
.
get_value
(
borrow
=
True
)
@staticmethod
def
get_size
(
shape_info
):
return
sys
.
getsizeof
(
shape_info
.
get_state
(
legacy
=
False
))
@staticmethod
def
may_share_memory
(
a
,
b
):
return
a
.
_bit_generator
is
b
.
_bit_generator
# Register `RandomStateType`'s C code for `ViewOp`.
aesara
.
compile
.
register_view_op_c_code
(
...
...
@@ -98,3 +116,89 @@ aesara.compile.register_view_op_c_code(
)
random_state_type
=
RandomStateType
()
class
RandomGeneratorType
(
RandomType
):
r"""A Type wrapper for `numpy.random.Generator`.
The reason this exists (and `Generic` doesn't suffice) is that
`Generator` objects that would appear to be equal do not compare equal
with the ``==`` operator.
This `Type` also works with a ``dict`` derived from
`Generator.__get_state__`, unless the ``strict`` argument to `Type.filter`
is explicitly set to ``True``.
"""
def
__repr__
(
self
):
return
"RandomGeneratorType"
@staticmethod
def
is_valid_value
(
a
,
strict
):
if
isinstance
(
a
,
np
.
random
.
Generator
):
return
True
if
not
strict
and
isinstance
(
a
,
dict
):
if
"bit_generator"
not
in
a
:
return
False
else
:
bit_gen_key
=
a
[
"bit_generator"
]
if
hasattr
(
bit_gen_key
,
"_value"
):
bit_gen_key
=
int
(
bit_gen_key
.
_value
)
bit_gen_key
=
numpy_bit_gens
[
bit_gen_key
]
gen_keys
,
state_keys
=
gen_states_keys
[
bit_gen_key
]
for
key
in
gen_keys
:
if
key
not
in
a
:
return
False
for
key
in
state_keys
:
if
key
not
in
a
[
"state"
]:
return
False
return
True
return
False
@staticmethod
def
values_eq
(
a
,
b
):
sa
=
a
if
isinstance
(
a
,
dict
)
else
a
.
__getstate__
()
sb
=
b
if
isinstance
(
b
,
dict
)
else
b
.
__getstate__
()
def
_eq
(
sa
,
sb
):
for
key
in
sa
:
if
isinstance
(
sa
[
key
],
dict
):
if
not
_eq
(
sa
[
key
],
sb
[
key
]):
return
False
elif
isinstance
(
sa
[
key
],
np
.
ndarray
):
if
not
np
.
array_equal
(
sa
[
key
],
sb
[
key
]):
return
False
else
:
if
sa
[
key
]
!=
sb
[
key
]:
return
False
return
True
return
_eq
(
sa
,
sb
)
@staticmethod
def
get_size
(
shape_info
):
state
=
shape_info
.
__getstate__
()
return
sys
.
getsizeof
(
state
)
# Register `RandomGeneratorType`'s C code for `ViewOp`.
aesara
.
compile
.
register_view_op_c_code
(
RandomGeneratorType
,
"""
Py_XDECREF(
%(oname)
s);
%(oname)
s =
%(iname)
s;
Py_XINCREF(
%(oname)
s);
"""
,
1
,
)
random_generator_type
=
RandomGeneratorType
()
aesara/tensor/random/utils.py
浏览文件 @
e0eea331
...
...
@@ -117,25 +117,28 @@ def normalize_size_param(size):
class
RandomStream
:
"""Module component with similar interface to `numpy.random.
RandomState
`.
"""Module component with similar interface to `numpy.random.
Generator
`.
Attributes
----------
seed: None or int
A default seed to initialize the
RandomState
instances after build.
A default seed to initialize the
`Generator`
instances after build.
state_updates: list
A list of pairs of the form `
(input_r, output_r)
`. This will be
A list of pairs of the form `
`(input_r, output_r)`
`. This will be
over-ridden by the module instance to contain stream generators.
default_instance_seed: int
Instance variable should take None or integer value. Used to seed the
random number generator that provides seeds for member streams.
gen_seedgen: numpy.random.
RandomState
`
RandomState
` instance that `RandomStream.gen` uses to seed new
gen_seedgen: numpy.random.
Generator
`
Generator
` instance that `RandomStream.gen` uses to seed new
streams.
rng_ctor: type
Constructor used to create the underlying RNG objects. The default
is `np.random.default_rng`.
"""
def
__init__
(
self
,
seed
=
None
,
namespace
=
None
):
def
__init__
(
self
,
seed
=
None
,
namespace
=
None
,
rng_ctor
=
np
.
random
.
default_rng
):
if
namespace
is
None
:
from
aesara.tensor.random
import
basic
# pylint: disable=import-self
...
...
@@ -145,7 +148,8 @@ class RandomStream:
self
.
default_instance_seed
=
seed
self
.
state_updates
=
[]
self
.
gen_seedgen
=
np
.
random
.
RandomState
(
seed
)
self
.
gen_seedgen
=
np
.
random
.
default_rng
(
seed
)
self
.
rng_ctor
=
rng_ctor
def
__getattr__
(
self
,
obj
):
...
...
@@ -191,11 +195,11 @@ class RandomStream:
if
seed
is
None
:
seed
=
self
.
default_instance_seed
self
.
gen_seedgen
.
seed
(
seed
)
self
.
gen_seedgen
=
np
.
random
.
default_rng
(
seed
)
for
old_r
,
new_r
in
self
.
state_updates
:
old_r_seed
=
self
.
gen_seedgen
.
randint
(
2
**
30
)
old_r
.
set_value
(
np
.
random
.
RandomState
(
int
(
old_r_seed
)),
borrow
=
True
)
old_r_seed
=
self
.
gen_seedgen
.
integers
(
2
**
30
)
old_r
.
set_value
(
self
.
rng_ctor
(
int
(
old_r_seed
)),
borrow
=
True
)
def
gen
(
self
,
op
,
*
args
,
**
kwargs
):
"""Create a new random stream in this container.
...
...
@@ -213,18 +217,18 @@ class RandomStream:
-------
TensorVariable
The symbolic random draw part of op()'s return value.
This function stores the updated `Random
State
Type` variable
This function stores the updated `Random
Generator
Type` variable
for use at `build` time.
"""
if
"rng"
in
kwargs
:
raise
Typ
eError
(
"The
rng option cannot be used with a variate in a RandomStream
"
raise
Valu
eError
(
"The
`rng` option cannot be used with a variate in a `RandomStream`
"
)
# Generate a new random state
seed
=
int
(
self
.
gen_seedgen
.
randint
(
2
**
30
))
random_state_variable
=
shared
(
np
.
random
.
RandomState
(
seed
))
seed
=
int
(
self
.
gen_seedgen
.
integers
(
2
**
30
))
random_state_variable
=
shared
(
self
.
rng_ctor
(
seed
))
# Distinguish it from other shared variables (why?)
random_state_variable
.
tag
.
is_rng
=
True
...
...
aesara/tensor/random/var.py
浏览文件 @
e0eea331
...
...
@@ -3,7 +3,7 @@ import copy
import
numpy
as
np
from
aesara.compile.sharedvalue
import
SharedVariable
,
shared_constructor
from
aesara.tensor.random.type
import
random_state_type
from
aesara.tensor.random.type
import
random_
generator_type
,
random_
state_type
class
RandomStateSharedVariable
(
SharedVariable
):
...
...
@@ -11,20 +11,30 @@ class RandomStateSharedVariable(SharedVariable):
return
"RandomStateSharedVariable({})"
.
format
(
repr
(
self
.
container
))
class
RandomGeneratorSharedVariable
(
SharedVariable
):
def
__str__
(
self
):
return
"RandomGeneratorSharedVariable({})"
.
format
(
repr
(
self
.
container
))
@shared_constructor
def
random
state
_constructor
(
def
random
gen
_constructor
(
value
,
name
=
None
,
strict
=
False
,
allow_downcast
=
None
,
borrow
=
False
):
"""
SharedVariable Constructor for RandomState.
r"""`SharedVariable` Constructor for NumPy's `Generator` and/or `RandomState`."""
if
isinstance
(
value
,
np
.
random
.
RandomState
):
rng_sv_type
=
RandomStateSharedVariable
rng_type
=
random_state_type
elif
isinstance
(
value
,
np
.
random
.
Generator
):
rng_sv_type
=
RandomGeneratorSharedVariable
rng_type
=
random_generator_type
else
:
raise
TypeError
()
"""
if
not
isinstance
(
value
,
np
.
random
.
RandomState
):
raise
TypeError
if
not
borrow
:
value
=
copy
.
deepcopy
(
value
)
return
RandomStateSharedVariable
(
type
=
random_state_type
,
return
rng_sv_type
(
type
=
rng_type
,
value
=
value
,
name
=
name
,
strict
=
strict
,
...
...
setup.py
浏览文件 @
e0eea331
...
...
@@ -57,7 +57,7 @@ if __name__ == "__main__":
license
=
LICENSE
,
platforms
=
PLATFORMS
,
packages
=
find_packages
(
exclude
=
[
"tests"
,
"tests.*"
]),
install_requires
=
[
"numpy>=1.
9.1
"
,
"scipy>=0.14"
,
"filelock"
],
install_requires
=
[
"numpy>=1.
17.0
"
,
"scipy>=0.14"
,
"filelock"
],
package_data
=
{
""
:
[
"*.txt"
,
...
...
tests/tensor/random/test_basic.py
浏览文件 @
e0eea331
...
...
@@ -29,6 +29,7 @@ from aesara.tensor.random.basic import (
halfcauchy
,
halfnormal
,
hypergeometric
,
integers
,
invgamma
,
laplace
,
logistic
,
...
...
@@ -58,7 +59,7 @@ def set_aesara_flags():
yield
def
rv_numpy_tester
(
rv
,
*
params
,
**
kwargs
):
def
rv_numpy_tester
(
rv
,
*
params
,
rng
=
None
,
**
kwargs
):
"""Test for correspondence between `RandomVariable` and NumPy shape and
broadcast dimensions.
"""
...
...
@@ -70,9 +71,9 @@ def rv_numpy_tester(rv, *params, **kwargs):
if
name
is
None
:
name
=
rv
.
__name__
test_fn
=
getattr
(
np
.
random
,
name
)
test_fn
=
getattr
(
rng
or
np
.
random
,
name
)
aesara_res
=
rv
(
*
params
,
**
kwargs
)
aesara_res
=
rv
(
*
params
,
rng
=
shared
(
rng
)
if
rng
else
None
,
**
kwargs
)
param_vals
=
[
get_test_value
(
p
)
if
isinstance
(
p
,
Variable
)
else
p
for
p
in
params
]
kwargs_vals
=
{
...
...
@@ -738,17 +739,47 @@ def test_polyagamma_samples():
assert
np
.
all
(
np
.
abs
(
np
.
diff
(
bcast_smpl
.
flat
))
>
0.0
)
def
test_rand
om_integer
_samples
():
def
test_rand
int
_samples
():
rv_numpy_tester
(
randint
,
10
,
None
)
rv_numpy_tester
(
randint
,
0
,
1
)
rv_numpy_tester
(
randint
,
0
,
1
,
size
=
[
3
])
rv_numpy_tester
(
randint
,
[
0
,
1
,
2
],
5
)
rv_numpy_tester
(
randint
,
[
0
,
1
,
2
],
5
,
size
=
[
3
,
3
])
rv_numpy_tester
(
randint
,
[
0
],
[
5
],
size
=
[
1
])
rv_numpy_tester
(
randint
,
aet
.
as_tensor_variable
([
-
1
]),
[
1
],
size
=
[
1
])
with
raises
(
TypeError
):
randint
(
10
,
rng
=
shared
(
np
.
random
.
default_rng
()))
rng
=
np
.
random
.
RandomState
(
2313
)
rv_numpy_tester
(
randint
,
10
,
None
,
rng
=
rng
)
rv_numpy_tester
(
randint
,
0
,
1
,
rng
=
rng
)
rv_numpy_tester
(
randint
,
0
,
1
,
size
=
[
3
],
rng
=
rng
)
rv_numpy_tester
(
randint
,
[
0
,
1
,
2
],
5
,
rng
=
rng
)
rv_numpy_tester
(
randint
,
[
0
,
1
,
2
],
5
,
size
=
[
3
,
3
],
rng
=
rng
)
rv_numpy_tester
(
randint
,
[
0
],
[
5
],
size
=
[
1
],
rng
=
rng
)
rv_numpy_tester
(
randint
,
aet
.
as_tensor_variable
([
-
1
]),
[
1
],
size
=
[
1
],
rng
=
rng
)
rv_numpy_tester
(
randint
,
aet
.
as_tensor_variable
([
-
1
]),
[
1
],
size
=
aet
.
as_tensor_variable
([
1
])
randint
,
aet
.
as_tensor_variable
([
-
1
]),
[
1
],
size
=
aet
.
as_tensor_variable
([
1
]),
rng
=
rng
,
)
def
test_integers_samples
():
with
raises
(
TypeError
):
integers
(
10
,
rng
=
shared
(
np
.
random
.
RandomState
()))
rng
=
np
.
random
.
default_rng
(
2313
)
rv_numpy_tester
(
integers
,
10
,
None
,
rng
=
rng
)
rv_numpy_tester
(
integers
,
0
,
1
,
rng
=
rng
)
rv_numpy_tester
(
integers
,
0
,
1
,
size
=
[
3
],
rng
=
rng
)
rv_numpy_tester
(
integers
,
[
0
,
1
,
2
],
5
,
rng
=
rng
)
rv_numpy_tester
(
integers
,
[
0
,
1
,
2
],
5
,
size
=
[
3
,
3
],
rng
=
rng
)
rv_numpy_tester
(
integers
,
[
0
],
[
5
],
size
=
[
1
],
rng
=
rng
)
rv_numpy_tester
(
integers
,
aet
.
as_tensor_variable
([
-
1
]),
[
1
],
size
=
[
1
],
rng
=
rng
)
rv_numpy_tester
(
integers
,
aet
.
as_tensor_variable
([
-
1
]),
[
1
],
size
=
aet
.
as_tensor_variable
([
1
]),
rng
=
rng
,
)
...
...
tests/tensor/random/test_type.py
浏览文件 @
e0eea331
...
...
@@ -6,7 +6,12 @@ import pytest
from
aesara
import
shared
from
aesara.compile.ops
import
ViewOp
from
aesara.tensor.random.type
import
RandomStateType
,
random_state_type
from
aesara.tensor.random.type
import
(
RandomGeneratorType
,
RandomStateType
,
random_generator_type
,
random_state_type
,
)
# @pytest.mark.skipif(
...
...
@@ -24,8 +29,8 @@ def test_view_op_c_code():
# rng_view,
# mode=Mode(optimizer=None, linker=CLinker()),
# )
assert
ViewOp
.
c_code_and_version
[
RandomStateType
]
assert
ViewOp
.
c_code_and_version
[
RandomGeneratorType
]
class
TestRandomStateType
:
...
...
@@ -106,9 +111,112 @@ class TestRandomStateType:
assert
size
==
sys
.
getsizeof
(
rng
.
get_state
(
legacy
=
False
))
def
test_may_share_memory
(
self
):
rng_a
=
np
.
random
.
RandomState
(
12
)
bg
=
np
.
random
.
PCG64
()
rng_b
=
np
.
random
.
RandomState
(
bg
)
bg1
=
np
.
random
.
MT19937
()
bg2
=
np
.
random
.
MT19937
()
rng_a
=
np
.
random
.
RandomState
(
bg1
)
rng_b
=
np
.
random
.
RandomState
(
bg2
)
rng_var_a
=
shared
(
rng_a
,
borrow
=
True
)
rng_var_b
=
shared
(
rng_b
,
borrow
=
True
)
shape_info_a
=
random_state_type
.
get_shape_info
(
rng_var_a
)
shape_info_b
=
random_state_type
.
get_shape_info
(
rng_var_b
)
assert
random_state_type
.
may_share_memory
(
shape_info_a
,
shape_info_b
)
is
False
rng_c
=
np
.
random
.
RandomState
(
bg2
)
rng_var_c
=
shared
(
rng_c
,
borrow
=
True
)
shape_info_c
=
random_state_type
.
get_shape_info
(
rng_var_c
)
assert
random_state_type
.
may_share_memory
(
shape_info_b
,
shape_info_c
)
is
True
class
TestRandomGeneratorType
:
def
test_pickle
(
self
):
rng_r
=
random_generator_type
()
rng_pkl
=
pickle
.
dumps
(
rng_r
)
rng_unpkl
=
pickle
.
loads
(
rng_pkl
)
assert
isinstance
(
rng_unpkl
,
type
(
rng_r
))
assert
isinstance
(
rng_unpkl
.
type
,
type
(
rng_r
.
type
))
def
test_repr
(
self
):
assert
repr
(
random_generator_type
)
==
"RandomGeneratorType"
def
test_filter
(
self
):
rng_type
=
random_generator_type
rng
=
np
.
random
.
default_rng
()
assert
rng_type
.
filter
(
rng
)
is
rng
with
pytest
.
raises
(
TypeError
):
rng_type
.
filter
(
1
)
rng
=
rng
.
__getstate__
()
assert
rng_type
.
is_valid_value
(
rng
,
strict
=
False
)
rng
[
"state"
]
=
{}
assert
rng_type
.
is_valid_value
(
rng
,
strict
=
False
)
is
False
rng
=
{}
assert
rng_type
.
is_valid_value
(
rng
,
strict
=
False
)
is
False
def
test_values_eq
(
self
):
rng_type
=
random_generator_type
bg_1
=
np
.
random
.
PCG64
()
bg_2
=
np
.
random
.
Philox
()
bg_3
=
np
.
random
.
MT19937
()
bg_4
=
np
.
random
.
SFC64
()
bitgen_a
=
np
.
random
.
Generator
(
bg_1
)
bitgen_b
=
np
.
random
.
Generator
(
bg_1
)
assert
rng_type
.
values_eq
(
bitgen_a
,
bitgen_b
)
bitgen_c
=
np
.
random
.
Generator
(
bg_2
)
bitgen_d
=
np
.
random
.
Generator
(
bg_2
)
assert
rng_type
.
values_eq
(
bitgen_c
,
bitgen_d
)
bitgen_e
=
np
.
random
.
Generator
(
bg_3
)
bitgen_f
=
np
.
random
.
Generator
(
bg_3
)
assert
rng_type
.
values_eq
(
bitgen_e
,
bitgen_f
)
bitgen_g
=
np
.
random
.
Generator
(
bg_4
)
bitgen_h
=
np
.
random
.
Generator
(
bg_4
)
assert
rng_type
.
values_eq
(
bitgen_g
,
bitgen_h
)
assert
rng_type
.
is_valid_value
(
bitgen_a
,
strict
=
True
)
assert
rng_type
.
is_valid_value
(
bitgen_b
.
__getstate__
(),
strict
=
False
)
assert
rng_type
.
is_valid_value
(
bitgen_c
,
strict
=
True
)
assert
rng_type
.
is_valid_value
(
bitgen_d
.
__getstate__
(),
strict
=
False
)
assert
rng_type
.
is_valid_value
(
bitgen_e
,
strict
=
True
)
assert
rng_type
.
is_valid_value
(
bitgen_f
.
__getstate__
(),
strict
=
False
)
assert
rng_type
.
is_valid_value
(
bitgen_g
,
strict
=
True
)
assert
rng_type
.
is_valid_value
(
bitgen_h
.
__getstate__
(),
strict
=
False
)
def
test_get_shape_info
(
self
):
rng
=
np
.
random
.
default_rng
(
12
)
rng_a
=
shared
(
rng
)
assert
isinstance
(
random_generator_type
.
get_shape_info
(
rng_a
),
np
.
random
.
Generator
)
def
test_get_size
(
self
):
rng
=
np
.
random
.
Generator
(
np
.
random
.
PCG64
(
12
))
rng_a
=
shared
(
rng
)
shape_info
=
random_generator_type
.
get_shape_info
(
rng_a
)
size
=
random_generator_type
.
get_size
(
shape_info
)
assert
size
==
sys
.
getsizeof
(
rng
.
__getstate__
())
def
test_may_share_memory
(
self
):
bg_a
=
np
.
random
.
PCG64
()
bg_b
=
np
.
random
.
PCG64
()
rng_a
=
np
.
random
.
Generator
(
bg_a
)
rng_b
=
np
.
random
.
Generator
(
bg_b
)
rng_var_a
=
shared
(
rng_a
,
borrow
=
True
)
rng_var_b
=
shared
(
rng_b
,
borrow
=
True
)
...
...
@@ -117,7 +225,7 @@ class TestRandomStateType:
assert
random_state_type
.
may_share_memory
(
shape_info_a
,
shape_info_b
)
is
False
rng_c
=
np
.
random
.
RandomState
(
bg
)
rng_c
=
np
.
random
.
Generator
(
bg_b
)
rng_var_c
=
shared
(
rng_c
,
borrow
=
True
)
shape_info_c
=
random_state_type
.
get_shape_info
(
rng_var_c
)
...
...
tests/tensor/random/test_utils.py
浏览文件 @
e0eea331
差异被折叠。
点击展开。
tests/tensor/random/test_var.py
浏览文件 @
e0eea331
import
numpy
as
np
import
pytest
from
aesara
import
shared
def
test_RandomStateSharedVariable
():
rng
=
np
.
random
.
RandomState
(
123
)
@pytest.mark.parametrize
(
"rng"
,
[
np
.
random
.
RandomState
(
123
),
np
.
random
.
default_rng
(
123
)]
)
def
test_GeneratorSharedVariable
(
rng
):
s_rng_default
=
shared
(
rng
)
s_rng_True
=
shared
(
rng
,
borrow
=
True
)
s_rng_False
=
shared
(
rng
,
borrow
=
False
)
...
...
@@ -17,15 +20,22 @@ def test_RandomStateSharedVariable():
assert
s_rng_True
.
container
.
storage
[
0
]
is
rng
# ensure that all the random number generators are in the same state
v
=
rng
.
randn
()
v0
=
s_rng_default
.
container
.
storage
[
0
]
.
randn
()
v1
=
s_rng_False
.
container
.
storage
[
0
]
.
randn
()
assert
v
==
v0
==
v1
if
hasattr
(
rng
,
"randn"
):
v
=
rng
.
randn
()
v0
=
s_rng_default
.
container
.
storage
[
0
]
.
randn
()
v1
=
s_rng_False
.
container
.
storage
[
0
]
.
randn
()
else
:
v
=
rng
.
standard_normal
()
v0
=
s_rng_default
.
container
.
storage
[
0
]
.
standard_normal
()
v1
=
s_rng_False
.
container
.
storage
[
0
]
.
standard_normal
()
assert
v
==
v0
==
v1
def
test_get_value_borrow
():
rng
=
np
.
random
.
RandomState
(
123
)
@pytest.mark.parametrize
(
"rng"
,
[
np
.
random
.
RandomState
(
123
),
np
.
random
.
default_rng
(
123
)]
)
def
test_get_value_borrow
(
rng
):
s_rng
=
shared
(
rng
)
r_
=
s_rng
.
container
.
storage
[
0
]
...
...
@@ -39,11 +49,16 @@ def test_get_value_borrow():
assert
r_
is
r_T
# either way, the rngs should all be in the same state
assert
r_
.
rand
()
==
r_F
.
rand
()
if
hasattr
(
rng
,
"rand"
):
assert
r_
.
rand
()
==
r_F
.
rand
()
else
:
assert
r_
.
standard_normal
()
==
r_F
.
standard_normal
()
def
test_get_value_internal_type
():
rng
=
np
.
random
.
RandomState
(
123
)
@pytest.mark.parametrize
(
"rng"
,
[
np
.
random
.
RandomState
(
123
),
np
.
random
.
default_rng
(
123
)]
)
def
test_get_value_internal_type
(
rng
):
s_rng
=
shared
(
rng
)
# there is no special behaviour required of return_internal_type
...
...
@@ -60,23 +75,28 @@ def test_get_value_internal_type():
assert
r_
is
r_T
# either way, the rngs should all be in the same state
assert
r_
.
rand
()
==
r_F
.
rand
()
if
hasattr
(
rng
,
"rand"
):
assert
r_
.
rand
()
==
r_F
.
rand
()
else
:
assert
r_
.
standard_normal
()
==
r_F
.
standard_normal
()
def
test_set_value_borrow
():
rng
=
np
.
random
.
RandomState
(
123
)
s_rng
=
shared
(
rng
)
@pytest.mark.parametrize
(
"rng_ctor"
,
[
np
.
random
.
RandomState
,
np
.
random
.
default_rng
])
def
test_set_value_borrow
(
rng_ctor
):
s_rng
=
shared
(
rng_ctor
(
123
))
new_rng
=
np
.
random
.
RandomState
(
234234
)
new_rng
=
rng_ctor
(
234234
)
# Test the borrow contract is respected:
# assigning with borrow=False makes a copy
s_rng
.
set_value
(
new_rng
,
borrow
=
False
)
assert
new_rng
is
not
s_rng
.
container
.
storage
[
0
]
assert
new_rng
.
randn
()
==
s_rng
.
container
.
storage
[
0
]
.
randn
()
if
hasattr
(
new_rng
,
"randn"
):
assert
new_rng
.
randn
()
==
s_rng
.
container
.
storage
[
0
]
.
randn
()
else
:
assert
new_rng
.
standard_normal
()
==
s_rng
.
container
.
storage
[
0
]
.
standard_normal
()
# Test that the current implementation is actually borrowing when it can.
rr
=
np
.
random
.
RandomState
(
33
)
rr
=
rng_ctor
(
33
)
s_rng
.
set_value
(
rr
,
borrow
=
True
)
assert
rr
is
s_rng
.
container
.
storage
[
0
]
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论