Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
591c47e6
提交
591c47e6
authored
5月 09, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
5月 29, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Distinguish between size=None and size=() in RandomVariables
上级
98d73d78
显示空白字符变更
内嵌
并排
正在显示
10 个修改的文件
包含
136 行增加
和
125 行删除
+136
-125
random.py
pytensor/link/jax/dispatch/random.py
+4
-8
random.py
pytensor/link/numba/dispatch/random.py
+32
-15
basic.py
pytensor/tensor/random/basic.py
+7
-10
op.py
pytensor/tensor/random/op.py
+13
-29
basic.py
pytensor/tensor/random/rewriting/basic.py
+8
-9
jax.py
pytensor/tensor/random/rewriting/jax.py
+2
-2
utils.py
pytensor/tensor/random/utils.py
+16
-18
test_basic.py
tests/tensor/random/rewriting/test_basic.py
+26
-20
test_basic.py
tests/tensor/random/test_basic.py
+14
-13
test_op.py
tests/tensor/random/test_op.py
+14
-1
没有找到文件。
pytensor/link/jax/dispatch/random.py
浏览文件 @
591c47e6
...
...
@@ -12,6 +12,7 @@ from pytensor.graph import Constant
from
pytensor.link.jax.dispatch.basic
import
jax_funcify
,
jax_typify
from
pytensor.link.jax.dispatch.shape
import
JAXShapeTuple
from
pytensor.tensor.shape
import
Shape
,
Shape_i
from
pytensor.tensor.type_other
import
NoneTypeT
try
:
...
...
@@ -93,7 +94,6 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
rv
=
node
.
outputs
[
1
]
out_dtype
=
rv
.
type
.
dtype
static_shape
=
rv
.
type
.
shape
batch_ndim
=
op
.
batch_ndim
(
node
)
# Try to pass static size directly to JAX
...
...
@@ -102,10 +102,9 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
# Sometimes size can be constant folded during rewrites,
# without the RandomVariable node being updated with new static types
size_param
=
op
.
size_param
(
node
)
if
isinstance
(
size_param
,
Constant
):
size_tuple
=
tuple
(
size_param
.
data
)
# PyTensor uses empty size to represent size = None
if
len
(
size_tuple
):
if
isinstance
(
size_param
,
Constant
)
and
not
isinstance
(
size_param
.
type
,
NoneTypeT
):
static_size
=
tuple
(
size_param
.
data
)
# If one dimension has unknown size, either the size is determined
...
...
@@ -115,9 +114,6 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
assert_size_argument_jax_compatible
(
node
)
def
sample_fn
(
rng
,
size
,
*
parameters
):
# PyTensor uses empty size to represent size = None
if
jax
.
numpy
.
asarray
(
size
)
.
shape
==
(
0
,):
size
=
None
return
jax_sample_fn
(
op
,
node
=
node
)(
rng
,
size
,
out_dtype
,
*
parameters
)
else
:
...
...
pytensor/link/numba/dispatch/random.py
浏览文件 @
591c47e6
...
...
@@ -21,6 +21,7 @@ from pytensor.link.utils import (
)
from
pytensor.tensor.basic
import
get_vector_length
from
pytensor.tensor.random.type
import
RandomStateType
from
pytensor.tensor.type_other
import
NoneTypeT
class
RandomStateNumbaType
(
types
.
Type
):
...
...
@@ -101,9 +102,13 @@ def make_numba_random_fn(node, np_random_func):
if
not
isinstance
(
rng_param
.
type
,
RandomStateType
):
raise
TypeError
(
"Numba does not support NumPy `Generator`s"
)
tuple_size
=
int
(
get_vector_length
(
op
.
size_param
(
node
)))
size_param
=
op
.
size_param
(
node
)
size_len
=
(
None
if
isinstance
(
size_param
.
type
,
NoneTypeT
)
else
int
(
get_vector_length
(
size_param
))
)
dist_params
=
op
.
dist_params
(
node
)
size_dims
=
tuple_size
-
max
(
i
.
ndim
for
i
in
dist_params
)
# Make a broadcast-capable version of the Numba supported scalar sampling
# function
...
...
@@ -119,7 +124,7 @@ def make_numba_random_fn(node, np_random_func):
"np_random_func"
,
"numba_vectorize"
,
"to_fixed_tuple"
,
"
tuple_size
"
,
"
size_len
"
,
"size_dims"
,
"rng"
,
"size"
,
...
...
@@ -155,10 +160,12 @@ def {bcast_fn_name}({bcast_fn_input_names}):
"out_dtype"
:
out_dtype
,
}
if
tuple_size
>
0
:
if
size_len
is
not
None
:
size_dims
=
size_len
-
max
(
i
.
ndim
for
i
in
dist_params
)
random_fn_body
=
dedent
(
f
"""
size = to_fixed_tuple(size,
tuple_size
)
size = to_fixed_tuple(size,
size_len
)
data = np.empty(size, dtype=out_dtype)
for i in np.ndindex(size[:size_dims]):
...
...
@@ -170,7 +177,7 @@ def {bcast_fn_name}({bcast_fn_input_names}):
{
"np"
:
np
,
"to_fixed_tuple"
:
numba_ndarray
.
to_fixed_tuple
,
"
tuple_size"
:
tuple_size
,
"
size_len"
:
size_len
,
"size_dims"
:
size_dims
,
}
)
...
...
@@ -305,19 +312,24 @@ def numba_funcify_BernoulliRV(op, node, **kwargs):
@numba_funcify.register
(
ptr
.
CategoricalRV
)
def
numba_funcify_CategoricalRV
(
op
:
ptr
.
CategoricalRV
,
node
,
**
kwargs
):
out_dtype
=
node
.
outputs
[
1
]
.
type
.
numpy_dtype
size_len
=
int
(
get_vector_length
(
op
.
size_param
(
node
)))
size_param
=
op
.
size_param
(
node
)
size_len
=
(
None
if
isinstance
(
size_param
.
type
,
NoneTypeT
)
else
int
(
get_vector_length
(
size_param
))
)
p_ndim
=
node
.
inputs
[
-
1
]
.
ndim
@numba_basic.numba_njit
def
categorical_rv
(
rng
,
size
,
p
):
if
not
size_len
:
if
size_len
is
None
:
size_tpl
=
p
.
shape
[:
-
1
]
else
:
size_tpl
=
numba_ndarray
.
to_fixed_tuple
(
size
,
size_len
)
p
=
np
.
broadcast_to
(
p
,
size_tpl
+
p
.
shape
[
-
1
:])
# Workaround https://github.com/numba/numba/issues/8975
if
not
size_len
and
p_ndim
==
1
:
if
size_len
is
None
and
p_ndim
==
1
:
unif_samples
=
np
.
asarray
(
np
.
random
.
uniform
(
0
,
1
))
else
:
unif_samples
=
np
.
random
.
uniform
(
0
,
1
,
size_tpl
)
...
...
@@ -336,13 +348,20 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
out_dtype
=
node
.
outputs
[
1
]
.
type
.
numpy_dtype
alphas_ndim
=
op
.
dist_params
(
node
)[
0
]
.
type
.
ndim
neg_ind_shape_len
=
-
alphas_ndim
+
1
size_len
=
int
(
get_vector_length
(
op
.
size_param
(
node
)))
size_param
=
op
.
size_param
(
node
)
size_len
=
(
None
if
isinstance
(
size_param
.
type
,
NoneTypeT
)
else
int
(
get_vector_length
(
size_param
))
)
if
alphas_ndim
>
1
:
@numba_basic.numba_njit
def
dirichlet_rv
(
rng
,
size
,
alphas
):
if
size_len
>
0
:
if
size_len
is
None
:
samples_shape
=
alphas
.
shape
else
:
size_tpl
=
numba_ndarray
.
to_fixed_tuple
(
size
,
size_len
)
if
(
0
<
alphas
.
ndim
-
1
<=
len
(
size_tpl
)
...
...
@@ -350,8 +369,6 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
):
raise
ValueError
(
"Parameters shape and size do not match."
)
samples_shape
=
size_tpl
+
alphas
.
shape
[
-
1
:]
else
:
samples_shape
=
alphas
.
shape
res
=
np
.
empty
(
samples_shape
,
dtype
=
out_dtype
)
alphas_bcast
=
np
.
broadcast_to
(
alphas
,
samples_shape
)
...
...
@@ -365,6 +382,7 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
@numba_basic.numba_njit
def
dirichlet_rv
(
rng
,
size
,
alphas
):
if
size_len
is
not
None
:
size
=
numba_ndarray
.
to_fixed_tuple
(
size
,
size_len
)
return
(
rng
,
np
.
random
.
dirichlet
(
alphas
,
size
))
...
...
@@ -404,8 +422,7 @@ def numba_funcify_choice_without_replacement(op, node, **kwargs):
@numba_funcify.register
(
ptr
.
PermutationRV
)
def
numba_funcify_permutation
(
op
:
ptr
.
PermutationRV
,
node
,
**
kwargs
):
# PyTensor uses size=() to represent size=None
size_is_none
=
op
.
size_param
(
node
)
.
type
.
shape
==
(
0
,)
size_is_none
=
isinstance
(
op
.
size_param
(
node
)
.
type
,
NoneTypeT
)
batch_ndim
=
op
.
batch_ndim
(
node
)
x_batch_ndim
=
node
.
inputs
[
-
1
]
.
type
.
ndim
-
op
.
ndims_params
[
0
]
...
...
pytensor/tensor/random/basic.py
浏览文件 @
591c47e6
...
...
@@ -914,12 +914,11 @@ class MvNormalRV(RandomVariable):
# multivariate normals (or any other multivariate distributions),
# so we need to implement that here
size
=
tuple
(
size
or
())
if
size
:
if
size
is
None
:
mean
,
cov
=
broadcast_params
([
mean
,
cov
],
[
1
,
2
])
else
:
mean
=
np
.
broadcast_to
(
mean
,
size
+
mean
.
shape
[
-
1
:])
cov
=
np
.
broadcast_to
(
cov
,
size
+
cov
.
shape
[
-
2
:])
else
:
mean
,
cov
=
broadcast_params
([
mean
,
cov
],
[
1
,
2
])
res
=
np
.
empty
(
mean
.
shape
)
for
idx
in
np
.
ndindex
(
mean
.
shape
[:
-
1
]):
...
...
@@ -1800,13 +1799,11 @@ class MultinomialRV(RandomVariable):
@classmethod
def
rng_fn
(
cls
,
rng
,
n
,
p
,
size
):
if
n
.
ndim
>
0
or
p
.
ndim
>
1
:
size
=
tuple
(
size
or
())
if
siz
e
:
if
size
is
None
:
n
,
p
=
broadcast_params
([
n
,
p
],
[
0
,
1
])
els
e
:
n
=
np
.
broadcast_to
(
n
,
size
)
p
=
np
.
broadcast_to
(
p
,
size
+
p
.
shape
[
-
1
:])
else
:
n
,
p
=
broadcast_params
([
n
,
p
],
[
0
,
1
])
res
=
np
.
empty
(
p
.
shape
,
dtype
=
cls
.
dtype
)
for
idx
in
np
.
ndindex
(
p
.
shape
[:
-
1
]):
...
...
@@ -2155,7 +2152,7 @@ class PermutationRV(RandomVariable):
def
rng_fn
(
self
,
rng
,
x
,
size
):
# We don't have access to the node in rng_fn :(
x_batch_ndim
=
x
.
ndim
-
self
.
ndims_params
[
0
]
batch_ndim
=
max
(
x_batch_ndim
,
len
(
size
or
()
))
batch_ndim
=
max
(
x_batch_ndim
,
0
if
size
is
None
else
len
(
size
))
if
batch_ndim
:
# rng.permutation has no concept of batch dims
...
...
pytensor/tensor/random/op.py
浏览文件 @
591c47e6
...
...
@@ -16,7 +16,6 @@ from pytensor.tensor.basic import (
as_tensor_variable
,
concatenate
,
constant
,
get_underlying_scalar_constant_value
,
get_vector_length
,
infer_static_shape
,
)
...
...
@@ -28,7 +27,7 @@ from pytensor.tensor.random.utils import (
)
from
pytensor.tensor.shape
import
shape_tuple
from
pytensor.tensor.type
import
TensorType
from
pytensor.tensor.type_other
import
NoneConst
from
pytensor.tensor.type_other
import
NoneConst
,
NoneTypeT
from
pytensor.tensor.utils
import
_parse_gufunc_signature
,
safe_signature
from
pytensor.tensor.variable
import
TensorVariable
...
...
@@ -196,10 +195,10 @@ class RandomVariable(Op):
def
_infer_shape
(
self
,
size
:
TensorVariable
,
size
:
TensorVariable
|
Variable
,
dist_params
:
Sequence
[
TensorVariable
],
param_shapes
:
Sequence
[
tuple
[
Variable
,
...
]]
|
None
=
None
,
)
->
TensorVariable
|
tuple
[
Scala
rVariable
,
...
]:
)
->
tuple
[
ScalarVariable
|
Tenso
rVariable
,
...
]:
"""Compute the output shape given the size and distribution parameters.
Parameters
...
...
@@ -225,9 +224,9 @@ class RandomVariable(Op):
self
.
_supp_shape_from_params
(
dist_params
,
param_shapes
=
param_shapes
)
)
if
not
isinstance
(
size
.
type
,
NoneTypeT
):
size_len
=
get_vector_length
(
size
)
if
size_len
>
0
:
# Fail early when size is incompatible with parameters
for
i
,
(
param
,
param_ndim_supp
)
in
enumerate
(
zip
(
dist_params
,
self
.
ndims_params
)
...
...
@@ -281,21 +280,11 @@ class RandomVariable(Op):
shape
=
batch_shape
+
supp_shape
if
not
shape
:
shape
=
constant
([],
dtype
=
"int64"
)
return
shape
def
infer_shape
(
self
,
fgraph
,
node
,
input_shapes
):
_
,
size
,
*
dist_params
=
node
.
inputs
_
,
size_shape
,
*
param_shapes
=
input_shapes
try
:
size_len
=
get_vector_length
(
size
)
except
ValueError
:
size_len
=
get_underlying_scalar_constant_value
(
size_shape
[
0
])
size
=
tuple
(
size
[
n
]
for
n
in
range
(
size_len
))
_
,
_
,
*
param_shapes
=
input_shapes
shape
=
self
.
_infer_shape
(
size
,
dist_params
,
param_shapes
=
param_shapes
)
...
...
@@ -367,8 +356,8 @@ class RandomVariable(Op):
"The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
)
shape
=
self
.
_infer_shape
(
size
,
dist_params
)
_
,
static_shape
=
infer_static_shape
(
shape
)
inferred_
shape
=
self
.
_infer_shape
(
size
,
dist_params
)
_
,
static_shape
=
infer_static_shape
(
inferred_
shape
)
inputs
=
(
rng
,
size
,
*
dist_params
)
out_type
=
TensorType
(
dtype
=
self
.
dtype
,
shape
=
static_shape
)
...
...
@@ -396,21 +385,14 @@ class RandomVariable(Op):
rng
,
size
,
*
args
=
inputs
# If `size == []`, that means no size is enforced, and NumPy is trusted
# to draw the appropriate number of samples, NumPy uses `size=None` to
# represent that. Otherwise, NumPy expects a tuple.
if
np
.
size
(
size
)
==
0
:
size
=
None
else
:
size
=
tuple
(
size
)
# Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng`
# otherwise.
# Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise.
if
not
self
.
inplace
:
rng
=
copy
(
rng
)
rng_var_out
[
0
]
=
rng
if
size
is
not
None
:
size
=
tuple
(
size
)
smpl_val
=
self
.
rng_fn
(
rng
,
*
([
*
args
,
size
]))
if
not
isinstance
(
smpl_val
,
np
.
ndarray
)
or
str
(
smpl_val
.
dtype
)
!=
self
.
dtype
:
...
...
@@ -473,7 +455,9 @@ def vectorize_random_variable(
original_dist_params
=
op
.
dist_params
(
node
)
old_size
=
op
.
size_param
(
node
)
len_old_size
=
get_vector_length
(
old_size
)
len_old_size
=
(
None
if
isinstance
(
old_size
.
type
,
NoneTypeT
)
else
get_vector_length
(
old_size
)
)
original_expanded_dist_params
=
explicit_expand_dims
(
original_dist_params
,
op
.
ndims_params
,
len_old_size
...
...
pytensor/tensor/random/rewriting/basic.py
浏览文件 @
591c47e6
...
...
@@ -7,7 +7,7 @@ from pytensor.graph.op import compute_test_value
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
in2out
,
node_rewriter
from
pytensor.scalar
import
integer_types
from
pytensor.tensor
import
NoneConst
from
pytensor.tensor.basic
import
constant
,
get_vector_length
from
pytensor.tensor.basic
import
constant
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.extra_ops
import
broadcast_to
from
pytensor.tensor.random.op
import
RandomVariable
...
...
@@ -20,7 +20,7 @@ from pytensor.tensor.subtensor import (
as_index_variable
,
get_idx_list
,
)
from
pytensor.tensor.type_other
import
SliceType
from
pytensor.tensor.type_other
import
NoneTypeT
,
SliceType
def
is_rv_used_in_graph
(
base_rv
,
node
,
fgraph
):
...
...
@@ -83,9 +83,11 @@ def local_rv_size_lift(fgraph, node):
rng
,
size
,
*
dist_params
=
node
.
inputs
if
isinstance
(
size
.
type
,
NoneTypeT
):
return
dist_params
=
broadcast_params
(
dist_params
,
node
.
op
.
ndims_params
)
if
get_vector_length
(
size
)
>
0
:
dist_params
=
[
broadcast_to
(
p
,
...
...
@@ -102,8 +104,6 @@ def local_rv_size_lift(fgraph, node):
)
for
i
,
p
in
enumerate
(
dist_params
)
]
else
:
return
new_node
=
node
.
op
.
make_node
(
rng
,
None
,
*
dist_params
)
...
...
@@ -159,11 +159,10 @@ def local_dimshuffle_rv_lift(fgraph, node):
batched_dims
=
rv
.
ndim
-
rv_op
.
ndim_supp
batched_dims_ds_order
=
tuple
(
o
for
o
in
ds_op
.
new_order
if
o
not
in
supp_dims
)
if
isinstance
(
size
.
type
,
NoneTypeT
):
# Make size explicit
missing_size_dims
=
batched_dims
-
get_vector_length
(
size
)
if
missing_size_dims
>
0
:
full_size
=
tuple
(
broadcast_params
(
dist_params
,
rv_op
.
ndims_params
)[
0
]
.
shape
)
size
=
full_size
[:
missing_size_dims
]
+
tuple
(
size
)
shape
=
tuple
(
broadcast_params
(
dist_params
,
rv_op
.
ndims_params
)[
0
]
.
shape
)
size
=
shape
[:
batched_dims
]
# Update the size to reflect the DimShuffled dimensions
new_size
=
[
...
...
pytensor/tensor/random/rewriting/jax.py
浏览文件 @
591c47e6
...
...
@@ -158,7 +158,7 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node):
# No need to materialize arange
return
None
rng
,
size
,
dtype
,
a_scalar_param
,
*
other_params
=
node
.
inputs
rng
,
size
,
a_scalar_param
,
*
other_params
=
node
.
inputs
if
a_scalar_param
.
type
.
ndim
>
0
:
# Automatic vectorization could have made this parameter batched,
# there is no nice way to materialize a batched arange
...
...
@@ -170,7 +170,7 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node):
# I.e., we substitute the first `()` by `(a)`
new_props_dict
[
"signature"
]
=
re
.
sub
(
r"\(\)"
,
"(a)"
,
op
.
signature
,
1
)
new_op
=
type
(
op
)(
**
new_props_dict
)
return
new_op
.
make_node
(
rng
,
size
,
dtype
,
a_vector_param
,
*
other_params
)
.
outputs
return
new_op
.
make_node
(
rng
,
size
,
a_vector_param
,
*
other_params
)
.
outputs
random_vars_opt
=
SequenceDB
()
...
...
pytensor/tensor/random/utils.py
浏览文件 @
591c47e6
...
...
@@ -9,8 +9,8 @@ import numpy as np
from
pytensor.compile.sharedvalue
import
shared
from
pytensor.graph.basic
import
Constant
,
Variable
from
pytensor.scalar
import
ScalarVariable
from
pytensor.tensor
import
get_vector_length
from
pytensor.tensor.basic
import
as_tensor_variable
,
cast
,
constant
from
pytensor.tensor
import
NoneConst
,
get_vector_length
from
pytensor.tensor.basic
import
as_tensor_variable
,
cast
from
pytensor.tensor.extra_ops
import
broadcast_arrays
,
broadcast_to
from
pytensor.tensor.math
import
maximum
from
pytensor.tensor.shape
import
shape_padleft
,
specify_shape
...
...
@@ -124,7 +124,7 @@ def broadcast_params(params, ndims_params):
def
explicit_expand_dims
(
params
:
Sequence
[
TensorVariable
],
ndim_params
:
Sequence
[
int
],
size_length
:
int
=
0
,
size_length
:
int
|
None
=
None
,
)
->
list
[
TensorVariable
]:
"""Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size."""
...
...
@@ -132,9 +132,7 @@ def explicit_expand_dims(
param
.
type
.
ndim
-
ndim_param
for
param
,
ndim_param
in
zip
(
params
,
ndim_params
)
]
if
size_length
:
# NOTE: PyTensor is currently treating zero-length size as size=None, which is not what Numpy does
# See: https://github.com/pymc-devs/pytensor/issues/568
if
size_length
is
not
None
:
max_batch_dims
=
size_length
else
:
max_batch_dims
=
max
(
batch_dims
,
default
=
0
)
...
...
@@ -159,30 +157,30 @@ def compute_batch_shape(params, ndims_params: Sequence[int]) -> TensorVariable:
def
normalize_size_param
(
s
iz
e
:
int
|
np
.
ndarray
|
Variable
|
Sequence
|
None
,
s
hap
e
:
int
|
np
.
ndarray
|
Variable
|
Sequence
|
None
,
)
->
Variable
:
"""Create an PyTensor value for a ``RandomVariable`` ``size`` parameter."""
if
s
ize
is
None
:
size
=
constant
([],
dtype
=
"int64"
)
elif
isinstance
(
s
iz
e
,
int
):
s
ize
=
as_tensor_variable
([
siz
e
],
ndim
=
1
)
elif
not
isinstance
(
s
iz
e
,
np
.
ndarray
|
Variable
|
Sequence
):
if
s
hape
is
None
or
NoneConst
.
equals
(
shape
)
:
return
NoneConst
elif
isinstance
(
s
hap
e
,
int
):
s
hape
=
as_tensor_variable
([
shap
e
],
ndim
=
1
)
elif
not
isinstance
(
s
hap
e
,
np
.
ndarray
|
Variable
|
Sequence
):
raise
TypeError
(
"Parameter size must be None, an integer, or a sequence with integers."
)
else
:
s
ize
=
cast
(
as_tensor_variable
(
siz
e
,
ndim
=
1
,
dtype
=
"int64"
),
"int64"
)
s
hape
=
cast
(
as_tensor_variable
(
shap
e
,
ndim
=
1
,
dtype
=
"int64"
),
"int64"
)
if
not
isinstance
(
s
iz
e
,
Constant
):
if
not
isinstance
(
s
hap
e
,
Constant
):
# This should help ensure that the length of non-constant `size`s
# will be available after certain types of cloning (e.g. the kind
# `Scan` performs)
s
ize
=
specify_shape
(
size
,
(
get_vector_length
(
siz
e
),))
s
hape
=
specify_shape
(
shape
,
(
get_vector_length
(
shap
e
),))
assert
not
any
(
s
is
None
for
s
in
s
iz
e
.
type
.
shape
)
assert
s
iz
e
.
dtype
in
int_dtypes
assert
not
any
(
s
is
None
for
s
in
s
hap
e
.
type
.
shape
)
assert
s
hap
e
.
dtype
in
int_dtypes
return
s
iz
e
return
s
hap
e
class
RandomStream
:
...
...
tests/tensor/random/rewriting/test_basic.py
浏览文件 @
591c47e6
...
...
@@ -30,6 +30,7 @@ from pytensor.tensor.random.rewriting import (
from
pytensor.tensor.rewriting.shape
import
ShapeFeature
,
ShapeOptimizer
from
pytensor.tensor.subtensor
import
AdvancedSubtensor
,
AdvancedSubtensor1
,
Subtensor
from
pytensor.tensor.type
import
iscalar
,
vector
from
pytensor.tensor.type_other
import
NoneConst
no_mode
=
Mode
(
"py"
,
RewriteDatabaseQuery
(
include
=
[],
exclude
=
[]))
...
...
@@ -44,6 +45,9 @@ def apply_local_rewrite_to_rv(
p_pt
.
tag
.
test_value
=
p
dist_params_pt
.
append
(
p_pt
)
if
size
is
None
:
size_pt
=
NoneConst
else
:
size_pt
=
[]
for
s
in
size
:
# To test DimShuffle with dropping dims we need that size dimension to be constant
...
...
@@ -57,7 +61,9 @@ def apply_local_rewrite_to_rv(
dist_st
=
op_fn
(
dist_op
(
*
dist_params_pt
,
size
=
size_pt
,
rng
=
rng
,
name
=
name
))
f_inputs
=
[
p
for
p
in
dist_params_pt
+
size_pt
if
not
isinstance
(
p
,
slice
|
Constant
)
p
for
p
in
dist_params_pt
+
([]
if
size
is
None
else
size_pt
)
if
not
isinstance
(
p
,
slice
|
Constant
)
]
mode
=
Mode
(
...
...
@@ -135,7 +141,7 @@ def test_inplace_rewrites(rv_op):
np
.
array
([
0.0
,
1.0
],
dtype
=
config
.
floatX
),
np
.
array
(
5.0
,
dtype
=
config
.
floatX
),
],
[]
,
None
,
),
(
normal
,
...
...
@@ -180,7 +186,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
rng
,
)
assert
pt
.
get_vector_length
(
new_out
.
owner
.
inputs
[
1
])
==
0
assert
new_out
.
owner
.
op
.
size_param
(
new_out
.
owner
)
.
data
is
None
@pytest.mark.parametrize
(
...
...
@@ -194,7 +200,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
np
.
array
([
0.0
,
-
100.0
],
dtype
=
np
.
float64
),
np
.
array
(
1e-6
,
dtype
=
np
.
float64
),
),
()
,
None
,
1e-7
,
),
(
...
...
@@ -205,7 +211,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
np
.
array
(
-
10.0
,
dtype
=
np
.
float64
),
np
.
array
(
1e-6
,
dtype
=
np
.
float64
),
),
()
,
None
,
1e-7
,
),
(
...
...
@@ -216,7 +222,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
np
.
array
(
-
10.0
,
dtype
=
np
.
float64
),
np
.
array
(
1e-6
,
dtype
=
np
.
float64
),
),
()
,
None
,
1e-7
,
),
(
...
...
@@ -227,7 +233,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
np
.
arange
(
2
*
2
*
2
)
.
reshape
((
2
,
2
,
2
))
.
astype
(
config
.
floatX
),
np
.
array
(
1e-6
)
.
astype
(
config
.
floatX
),
),
()
,
None
,
1e-3
,
),
(
...
...
@@ -440,7 +446,7 @@ def rand_bool_mask(shape, rng=None):
np
.
arange
(
30
,
dtype
=
config
.
floatX
)
.
reshape
(
3
,
5
,
2
),
np
.
full
((
1
,
5
,
1
),
1e-6
),
),
()
,
None
,
),
(
# `size`-only slice
...
...
@@ -462,7 +468,7 @@ def rand_bool_mask(shape, rng=None):
np
.
arange
(
30
,
dtype
=
config
.
floatX
)
.
reshape
(
3
,
5
,
2
),
np
.
full
((
1
,
5
,
1
),
1e-6
),
),
()
,
None
,
),
(
# `size`-only slice
...
...
@@ -484,7 +490,7 @@ def rand_bool_mask(shape, rng=None):
(
0.1
-
1e-5
)
*
np
.
arange
(
4
)
.
astype
(
dtype
=
config
.
floatX
),
0.1
*
np
.
arange
(
4
)
.
astype
(
dtype
=
config
.
floatX
),
),
()
,
None
,
),
# 5
(
...
...
@@ -570,7 +576,7 @@ def rand_bool_mask(shape, rng=None):
dtype
=
config
.
floatX
,
),
),
()
,
None
,
),
(
# Univariate distribution with core-vector parameters
...
...
@@ -627,7 +633,7 @@ def rand_bool_mask(shape, rng=None):
np
.
arange
(
30
)
.
reshape
(
5
,
3
,
2
),
1e-6
,
),
()
,
None
,
),
(
# Multidimensional boolean indexing
...
...
@@ -638,7 +644,7 @@ def rand_bool_mask(shape, rng=None):
np
.
arange
(
30
)
.
reshape
(
5
,
3
,
2
),
1e-6
,
),
()
,
None
,
),
(
# Multidimensional boolean indexing
...
...
@@ -649,7 +655,7 @@ def rand_bool_mask(shape, rng=None):
np
.
arange
(
30
)
.
reshape
(
5
,
3
,
2
),
1e-6
,
),
()
,
None
,
),
# 20
(
...
...
@@ -661,7 +667,7 @@ def rand_bool_mask(shape, rng=None):
np
.
arange
(
30
)
.
reshape
(
5
,
3
,
2
),
1e-6
,
),
()
,
None
,
),
(
# Multidimensional boolean indexing
...
...
@@ -687,7 +693,7 @@ def rand_bool_mask(shape, rng=None):
np
.
arange
(
30
)
.
reshape
(
5
,
3
,
2
),
1e-6
,
),
()
,
None
,
),
(
# Multidimensional boolean indexing,
...
...
@@ -703,7 +709,7 @@ def rand_bool_mask(shape, rng=None):
np
.
arange
(
30
)
.
reshape
(
5
,
3
,
2
),
1e-6
,
),
()
,
None
,
),
(
# Multivariate distribution: indexing dips into core dimension
...
...
@@ -714,7 +720,7 @@ def rand_bool_mask(shape, rng=None):
np
.
array
([[
-
1
,
20
],
[
300
,
-
4000
]],
dtype
=
config
.
floatX
),
np
.
eye
(
2
)
.
astype
(
config
.
floatX
)
*
1e-6
,
),
()
,
None
,
),
# 25
(
...
...
@@ -726,7 +732,7 @@ def rand_bool_mask(shape, rng=None):
np
.
array
([[
-
1
,
20
],
[
300
,
-
4000
]],
dtype
=
config
.
floatX
),
np
.
eye
(
2
)
.
astype
(
config
.
floatX
)
*
1e-6
,
),
()
,
None
,
),
(
# Multivariate distribution: advanced integer indexing
...
...
@@ -740,7 +746,7 @@ def rand_bool_mask(shape, rng=None):
),
np
.
eye
(
3
,
dtype
=
config
.
floatX
)
*
1e-6
,
),
()
,
None
,
),
(
# Multivariate distribution: dummy slice "dips" into core dimension
...
...
tests/tensor/random/test_basic.py
浏览文件 @
591c47e6
...
...
@@ -212,7 +212,7 @@ sd_pt.tag.test_value = np.array(1.0, dtype=config.floatX)
@pytest.mark.parametrize
(
"M, sd, size"
,
[
(
pt
.
as_tensor_variable
(
np
.
array
(
1.0
,
dtype
=
config
.
floatX
)),
sd_pt
,
()
),
(
pt
.
as_tensor_variable
(
np
.
array
(
1.0
,
dtype
=
config
.
floatX
)),
sd_pt
,
None
),
(
pt
.
as_tensor_variable
(
np
.
array
(
1.0
,
dtype
=
config
.
floatX
)),
sd_pt
,
...
...
@@ -223,10 +223,10 @@ sd_pt.tag.test_value = np.array(1.0, dtype=config.floatX)
sd_pt
,
(
2
,
M_pt
),
),
(
pt
.
zeros
((
M_pt
,)),
sd_pt
,
()
),
(
pt
.
zeros
((
M_pt
,)),
sd_pt
,
None
),
(
pt
.
zeros
((
M_pt
,)),
sd_pt
,
(
M_pt
,)),
(
pt
.
zeros
((
M_pt
,)),
sd_pt
,
(
2
,
M_pt
)),
(
pt
.
zeros
((
M_pt
,)),
pt
.
ones
((
M_pt
,)),
()
),
(
pt
.
zeros
((
M_pt
,)),
pt
.
ones
((
M_pt
,)),
None
),
(
pt
.
zeros
((
M_pt
,)),
pt
.
ones
((
M_pt
,)),
(
2
,
M_pt
)),
(
create_pytensor_param
(
...
...
@@ -244,9 +244,10 @@ sd_pt.tag.test_value = np.array(1.0, dtype=config.floatX)
)
def
test_normal_infer_shape
(
M
,
sd
,
size
):
rv
=
normal
(
M
,
sd
,
size
=
size
)
rv_shape
=
list
(
normal
.
_infer_shape
(
size
or
(),
[
M
,
sd
],
None
))
size_pt
=
rv
.
owner
.
op
.
size_param
(
rv
.
owner
)
rv_shape
=
list
(
normal
.
_infer_shape
(
size_pt
,
[
M
,
sd
],
None
))
all_args
=
(
M
,
sd
,
*
size
)
all_args
=
(
M
,
sd
,
*
(()
if
size
is
None
else
size
)
)
fn_inputs
=
[
i
for
i
in
graph_inputs
([
a
for
a
in
all_args
if
isinstance
(
a
,
Variable
)])
...
...
@@ -525,8 +526,8 @@ def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None):
mean
=
np
.
array
([
0.0
],
dtype
=
config
.
floatX
)
if
cov
is
None
:
cov
=
np
.
array
([[
1.0
]],
dtype
=
config
.
floatX
)
if
size
is
None
:
size
=
(
)
if
size
is
not
None
:
size
=
tuple
(
size
)
return
multivariate_normal
.
rng_fn
(
random_state
,
mean
,
cov
,
size
)
...
...
@@ -713,19 +714,20 @@ M_pt.tag.test_value = 3
@pytest.mark.parametrize
(
"M, size"
,
[
(
pt
.
ones
((
M_pt
,)),
()
),
(
pt
.
ones
((
M_pt
,)),
None
),
(
pt
.
ones
((
M_pt
,)),
(
M_pt
+
1
,)),
(
pt
.
ones
((
M_pt
,)),
(
2
,
M_pt
)),
(
pt
.
ones
((
M_pt
,
M_pt
+
1
)),
()
),
(
pt
.
ones
((
M_pt
,
M_pt
+
1
)),
None
),
(
pt
.
ones
((
M_pt
,
M_pt
+
1
)),
(
M_pt
+
2
,
M_pt
)),
(
pt
.
ones
((
M_pt
,
M_pt
+
1
)),
(
2
,
M_pt
+
2
,
M_pt
+
3
,
M_pt
)),
],
)
def
test_dirichlet_infer_shape
(
M
,
size
):
rv
=
dirichlet
(
M
,
size
=
size
)
rv_shape
=
list
(
dirichlet
.
_infer_shape
(
size
or
(),
[
M
],
None
))
size_pt
=
rv
.
owner
.
op
.
size_param
(
rv
.
owner
)
rv_shape
=
list
(
dirichlet
.
_infer_shape
(
size_pt
,
[
M
],
None
))
all_args
=
(
M
,
*
size
)
all_args
=
(
M
,
*
(()
if
size
is
None
else
size
)
)
fn_inputs
=
[
i
for
i
in
graph_inputs
([
a
for
a
in
all_args
if
isinstance
(
a
,
Variable
)])
...
...
@@ -1620,8 +1622,7 @@ def test_unnatural_batched_dims(batch_dims_tester):
@config.change_flags
(
compute_test_value
=
"off"
)
def
test_pickle
():
# This is an interesting `Op` case, because it has `None` types and a
# conditional dtype
# This is an interesting `Op` case, because it has a conditional dtype
sample_a
=
choice
(
5
,
replace
=
False
,
size
=
(
2
,
3
))
a_pkl
=
pickle
.
dumps
(
sample_a
)
...
...
tests/tensor/random/test_op.py
浏览文件 @
591c47e6
...
...
@@ -69,7 +69,7 @@ def test_RandomVariable_basics(strict_test_value_flags):
# `RandomVariable._infer_shape` should handle no parameters
rv_shape
=
rv
.
_infer_shape
(
pt
.
constant
([]),
(),
[])
assert
rv_shape
.
equals
(
pt
.
constant
([],
dtype
=
"int64"
)
)
assert
rv_shape
==
(
)
# `dtype` is respected
rv
=
RandomVariable
(
"normal"
,
signature
=
"(),()->()"
,
dtype
=
"int32"
)
...
...
@@ -299,3 +299,16 @@ def test_vectorize():
vect_node
=
vectorize_graph
(
out
,
{
mu
:
new_mu
,
sigma
:
new_sigma
})
.
owner
assert
isinstance
(
vect_node
.
op
,
NormalRV
)
assert
vect_node
.
default_output
()
.
type
.
shape
==
(
10
,
2
,
5
)
def
test_size_none_vs_empty
():
rv
=
RandomVariable
(
"normal"
,
signature
=
"(),()->()"
,
)
assert
rv
([
0
],
[
1
],
size
=
None
)
.
type
.
shape
==
(
1
,)
with
pytest
.
raises
(
ValueError
,
match
=
"Size length is incompatible with batched dimensions"
):
rv
([
0
],
[
1
],
size
=
())
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论