Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
98d73d78
提交
98d73d78
authored
5月 10, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
5月 29, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove RandomVariable dtype input
上级
df32683c
隐藏空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
118 行增加
和
122 行删除
+118
-122
random.py
pytensor/link/jax/dispatch/random.py
+2
-2
random.py
pytensor/link/numba/dispatch/random.py
+8
-9
op.py
pytensor/tensor/random/op.py
+35
-33
basic.py
pytensor/tensor/random/rewriting/basic.py
+6
-6
test_basic.py
tests/tensor/random/rewriting/test_basic.py
+6
-5
test_op.py
tests/tensor/random/test_op.py
+61
-67
没有找到文件。
pytensor/link/jax/dispatch/random.py
浏览文件 @
98d73d78
...
@@ -114,7 +114,7 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
...
@@ -114,7 +114,7 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
if
None
in
static_size
:
if
None
in
static_size
:
assert_size_argument_jax_compatible
(
node
)
assert_size_argument_jax_compatible
(
node
)
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
def
sample_fn
(
rng
,
size
,
*
parameters
):
# PyTensor uses empty size to represent size = None
# PyTensor uses empty size to represent size = None
if
jax
.
numpy
.
asarray
(
size
)
.
shape
==
(
0
,):
if
jax
.
numpy
.
asarray
(
size
)
.
shape
==
(
0
,):
size
=
None
size
=
None
...
@@ -122,7 +122,7 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
...
@@ -122,7 +122,7 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
else
:
else
:
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
def
sample_fn
(
rng
,
size
,
*
parameters
):
return
jax_sample_fn
(
op
,
node
=
node
)(
return
jax_sample_fn
(
op
,
node
=
node
)(
rng
,
static_size
,
out_dtype
,
*
parameters
rng
,
static_size
,
out_dtype
,
*
parameters
)
)
...
...
pytensor/link/numba/dispatch/random.py
浏览文件 @
98d73d78
...
@@ -123,7 +123,6 @@ def make_numba_random_fn(node, np_random_func):
...
@@ -123,7 +123,6 @@ def make_numba_random_fn(node, np_random_func):
"size_dims"
,
"size_dims"
,
"rng"
,
"rng"
,
"size"
,
"size"
,
"dtype"
,
],
],
suffix_sep
=
"_"
,
suffix_sep
=
"_"
,
)
)
...
@@ -146,7 +145,7 @@ def {bcast_fn_name}({bcast_fn_input_names}):
...
@@ -146,7 +145,7 @@ def {bcast_fn_name}({bcast_fn_input_names}):
)
)
random_fn_input_names
=
", "
.
join
(
random_fn_input_names
=
", "
.
join
(
[
"rng"
,
"size"
,
"dtype"
]
+
[
unique_names
(
i
)
for
i
in
dist_params
]
[
"rng"
,
"size"
]
+
[
unique_names
(
i
)
for
i
in
dist_params
]
)
)
# Now, create a Numba JITable function that implements the `size` parameter
# Now, create a Numba JITable function that implements the `size` parameter
...
@@ -243,7 +242,7 @@ def create_numba_random_fn(
...
@@ -243,7 +242,7 @@ def create_numba_random_fn(
np_global_env
[
"numba_vectorize"
]
=
numba_basic
.
numba_vectorize
np_global_env
[
"numba_vectorize"
]
=
numba_basic
.
numba_vectorize
unique_names
=
unique_name_generator
(
unique_names
=
unique_name_generator
(
[
np_random_fn_name
,
*
np_global_env
.
keys
(),
"rng"
,
"size"
,
"dtype"
],
[
np_random_fn_name
,
*
np_global_env
.
keys
(),
"rng"
,
"size"
],
suffix_sep
=
"_"
,
suffix_sep
=
"_"
,
)
)
...
@@ -310,7 +309,7 @@ def numba_funcify_CategoricalRV(op: ptr.CategoricalRV, node, **kwargs):
...
@@ -310,7 +309,7 @@ def numba_funcify_CategoricalRV(op: ptr.CategoricalRV, node, **kwargs):
p_ndim
=
node
.
inputs
[
-
1
]
.
ndim
p_ndim
=
node
.
inputs
[
-
1
]
.
ndim
@numba_basic.numba_njit
@numba_basic.numba_njit
def
categorical_rv
(
rng
,
size
,
dtype
,
p
):
def
categorical_rv
(
rng
,
size
,
p
):
if
not
size_len
:
if
not
size_len
:
size_tpl
=
p
.
shape
[:
-
1
]
size_tpl
=
p
.
shape
[:
-
1
]
else
:
else
:
...
@@ -342,7 +341,7 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
...
@@ -342,7 +341,7 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
if
alphas_ndim
>
1
:
if
alphas_ndim
>
1
:
@numba_basic.numba_njit
@numba_basic.numba_njit
def
dirichlet_rv
(
rng
,
size
,
dtype
,
alphas
):
def
dirichlet_rv
(
rng
,
size
,
alphas
):
if
size_len
>
0
:
if
size_len
>
0
:
size_tpl
=
numba_ndarray
.
to_fixed_tuple
(
size
,
size_len
)
size_tpl
=
numba_ndarray
.
to_fixed_tuple
(
size
,
size_len
)
if
(
if
(
...
@@ -365,7 +364,7 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
...
@@ -365,7 +364,7 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
else
:
else
:
@numba_basic.numba_njit
@numba_basic.numba_njit
def
dirichlet_rv
(
rng
,
size
,
dtype
,
alphas
):
def
dirichlet_rv
(
rng
,
size
,
alphas
):
size
=
numba_ndarray
.
to_fixed_tuple
(
size
,
size_len
)
size
=
numba_ndarray
.
to_fixed_tuple
(
size
,
size_len
)
return
(
rng
,
np
.
random
.
dirichlet
(
alphas
,
size
))
return
(
rng
,
np
.
random
.
dirichlet
(
alphas
,
size
))
...
@@ -388,14 +387,14 @@ def numba_funcify_choice_without_replacement(op, node, **kwargs):
...
@@ -388,14 +387,14 @@ def numba_funcify_choice_without_replacement(op, node, **kwargs):
if
op
.
has_p_param
:
if
op
.
has_p_param
:
@numba_basic.numba_njit
@numba_basic.numba_njit
def
choice_without_replacement_rv
(
rng
,
size
,
dtype
,
a
,
p
,
core_shape
):
def
choice_without_replacement_rv
(
rng
,
size
,
a
,
p
,
core_shape
):
core_shape
=
numba_ndarray
.
to_fixed_tuple
(
core_shape
,
core_shape_len
)
core_shape
=
numba_ndarray
.
to_fixed_tuple
(
core_shape
,
core_shape_len
)
samples
=
np
.
random
.
choice
(
a
,
size
=
core_shape
,
replace
=
False
,
p
=
p
)
samples
=
np
.
random
.
choice
(
a
,
size
=
core_shape
,
replace
=
False
,
p
=
p
)
return
(
rng
,
samples
)
return
(
rng
,
samples
)
else
:
else
:
@numba_basic.numba_njit
@numba_basic.numba_njit
def
choice_without_replacement_rv
(
rng
,
size
,
dtype
,
a
,
core_shape
):
def
choice_without_replacement_rv
(
rng
,
size
,
a
,
core_shape
):
core_shape
=
numba_ndarray
.
to_fixed_tuple
(
core_shape
,
core_shape_len
)
core_shape
=
numba_ndarray
.
to_fixed_tuple
(
core_shape
,
core_shape_len
)
samples
=
np
.
random
.
choice
(
a
,
size
=
core_shape
,
replace
=
False
)
samples
=
np
.
random
.
choice
(
a
,
size
=
core_shape
,
replace
=
False
)
return
(
rng
,
samples
)
return
(
rng
,
samples
)
...
@@ -411,7 +410,7 @@ def numba_funcify_permutation(op: ptr.PermutationRV, node, **kwargs):
...
@@ -411,7 +410,7 @@ def numba_funcify_permutation(op: ptr.PermutationRV, node, **kwargs):
x_batch_ndim
=
node
.
inputs
[
-
1
]
.
type
.
ndim
-
op
.
ndims_params
[
0
]
x_batch_ndim
=
node
.
inputs
[
-
1
]
.
type
.
ndim
-
op
.
ndims_params
[
0
]
@numba_basic.numba_njit
@numba_basic.numba_njit
def
permutation_rv
(
rng
,
size
,
dtype
,
x
):
def
permutation_rv
(
rng
,
size
,
x
):
if
batch_ndim
:
if
batch_ndim
:
x_core_shape
=
x
.
shape
[
x_batch_ndim
:]
x_core_shape
=
x
.
shape
[
x_batch_ndim
:]
if
size_is_none
:
if
size_is_none
:
...
...
pytensor/tensor/random/op.py
浏览文件 @
98d73d78
...
@@ -27,7 +27,7 @@ from pytensor.tensor.random.utils import (
...
@@ -27,7 +27,7 @@ from pytensor.tensor.random.utils import (
normalize_size_param
,
normalize_size_param
,
)
)
from
pytensor.tensor.shape
import
shape_tuple
from
pytensor.tensor.shape
import
shape_tuple
from
pytensor.tensor.type
import
TensorType
,
all_dtypes
from
pytensor.tensor.type
import
TensorType
from
pytensor.tensor.type_other
import
NoneConst
from
pytensor.tensor.type_other
import
NoneConst
from
pytensor.tensor.utils
import
_parse_gufunc_signature
,
safe_signature
from
pytensor.tensor.utils
import
_parse_gufunc_signature
,
safe_signature
from
pytensor.tensor.variable
import
TensorVariable
from
pytensor.tensor.variable
import
TensorVariable
...
@@ -65,7 +65,7 @@ class RandomVariable(Op):
...
@@ -65,7 +65,7 @@ class RandomVariable(Op):
signature: str
signature: str
Numpy-like vectorized signature of the random variable.
Numpy-like vectorized signature of the random variable.
dtype: str (optional)
dtype: str (optional)
The dtype of the sampled output. If the value ``"floatX"`` is
The d
efault d
type of the sampled output. If the value ``"floatX"`` is
given, then ``dtype`` is set to ``pytensor.config.floatX``. If
given, then ``dtype`` is set to ``pytensor.config.floatX``. If
``None`` (the default), the `dtype` keyword must be set when
``None`` (the default), the `dtype` keyword must be set when
`RandomVariable.make_node` is called.
`RandomVariable.make_node` is called.
...
@@ -287,8 +287,8 @@ class RandomVariable(Op):
...
@@ -287,8 +287,8 @@ class RandomVariable(Op):
return
shape
return
shape
def
infer_shape
(
self
,
fgraph
,
node
,
input_shapes
):
def
infer_shape
(
self
,
fgraph
,
node
,
input_shapes
):
_
,
size
,
_
,
*
dist_params
=
node
.
inputs
_
,
size
,
*
dist_params
=
node
.
inputs
_
,
size_shape
,
_
,
*
param_shapes
=
input_shapes
_
,
size_shape
,
*
param_shapes
=
input_shapes
try
:
try
:
size_len
=
get_vector_length
(
size
)
size_len
=
get_vector_length
(
size
)
...
@@ -302,14 +302,34 @@ class RandomVariable(Op):
...
@@ -302,14 +302,34 @@ class RandomVariable(Op):
return
[
None
,
list
(
shape
)]
return
[
None
,
list
(
shape
)]
def
__call__
(
self
,
*
args
,
size
=
None
,
name
=
None
,
rng
=
None
,
dtype
=
None
,
**
kwargs
):
def
__call__
(
self
,
*
args
,
size
=
None
,
name
=
None
,
rng
=
None
,
dtype
=
None
,
**
kwargs
):
res
=
super
()
.
__call__
(
rng
,
size
,
dtype
,
*
args
,
**
kwargs
)
if
dtype
is
None
:
dtype
=
self
.
dtype
if
dtype
==
"floatX"
:
dtype
=
config
.
floatX
# We need to recreate the Op with the right dtype
if
dtype
!=
self
.
dtype
:
# Check we are not switching from float to int
if
self
.
dtype
is
not
None
:
if
dtype
.
startswith
(
"float"
)
!=
self
.
dtype
.
startswith
(
"float"
):
raise
ValueError
(
f
"Cannot change the dtype of a {self.name} RV from {self.dtype} to {dtype}"
)
props
=
self
.
_props_dict
()
props
[
"dtype"
]
=
dtype
new_op
=
type
(
self
)(
**
props
)
return
new_op
.
__call__
(
*
args
,
size
=
size
,
name
=
name
,
rng
=
rng
,
dtype
=
dtype
,
**
kwargs
)
res
=
super
()
.
__call__
(
rng
,
size
,
*
args
,
**
kwargs
)
if
name
is
not
None
:
if
name
is
not
None
:
res
.
name
=
name
res
.
name
=
name
return
res
return
res
def
make_node
(
self
,
rng
,
size
,
dtype
,
*
dist_params
):
def
make_node
(
self
,
rng
,
size
,
*
dist_params
):
"""Create a random variable node.
"""Create a random variable node.
Parameters
Parameters
...
@@ -349,23 +369,10 @@ class RandomVariable(Op):
...
@@ -349,23 +369,10 @@ class RandomVariable(Op):
shape
=
self
.
_infer_shape
(
size
,
dist_params
)
shape
=
self
.
_infer_shape
(
size
,
dist_params
)
_
,
static_shape
=
infer_static_shape
(
shape
)
_
,
static_shape
=
infer_static_shape
(
shape
)
dtype
=
self
.
dtype
or
dtype
if
dtype
==
"floatX"
:
inputs
=
(
rng
,
size
,
*
dist_params
)
dtype
=
config
.
floatX
out_type
=
TensorType
(
dtype
=
self
.
dtype
,
shape
=
static_shape
)
elif
dtype
is
None
or
(
isinstance
(
dtype
,
str
)
and
dtype
not
in
all_dtypes
):
outputs
=
(
rng
.
type
(),
out_type
())
raise
TypeError
(
"dtype is unspecified"
)
if
isinstance
(
dtype
,
str
):
dtype_idx
=
constant
(
all_dtypes
.
index
(
dtype
),
dtype
=
"int64"
)
else
:
dtype_idx
=
constant
(
dtype
,
dtype
=
"int64"
)
dtype
=
all_dtypes
[
dtype_idx
.
data
]
inputs
=
(
rng
,
size
,
dtype_idx
,
*
dist_params
)
out_var
=
TensorType
(
dtype
=
dtype
,
shape
=
static_shape
)()
outputs
=
(
rng
.
type
(),
out_var
)
return
Apply
(
self
,
inputs
,
outputs
)
return
Apply
(
self
,
inputs
,
outputs
)
...
@@ -382,14 +389,12 @@ class RandomVariable(Op):
...
@@ -382,14 +389,12 @@ class RandomVariable(Op):
def
dist_params
(
self
,
node
)
->
Sequence
[
Variable
]:
def
dist_params
(
self
,
node
)
->
Sequence
[
Variable
]:
"""Return the node inpust corresponding to dist params"""
"""Return the node inpust corresponding to dist params"""
return
node
.
inputs
[
3
:]
return
node
.
inputs
[
2
:]
def
perform
(
self
,
node
,
inputs
,
outputs
):
def
perform
(
self
,
node
,
inputs
,
outputs
):
rng_var_out
,
smpl_out
=
outputs
rng_var_out
,
smpl_out
=
outputs
rng
,
size
,
dtype
,
*
args
=
inputs
rng
,
size
,
*
args
=
inputs
out_var
=
node
.
outputs
[
1
]
# If `size == []`, that means no size is enforced, and NumPy is trusted
# If `size == []`, that means no size is enforced, and NumPy is trusted
# to draw the appropriate number of samples, NumPy uses `size=None` to
# to draw the appropriate number of samples, NumPy uses `size=None` to
...
@@ -408,11 +413,8 @@ class RandomVariable(Op):
...
@@ -408,11 +413,8 @@ class RandomVariable(Op):
smpl_val
=
self
.
rng_fn
(
rng
,
*
([
*
args
,
size
]))
smpl_val
=
self
.
rng_fn
(
rng
,
*
([
*
args
,
size
]))
if
(
if
not
isinstance
(
smpl_val
,
np
.
ndarray
)
or
str
(
smpl_val
.
dtype
)
!=
self
.
dtype
:
not
isinstance
(
smpl_val
,
np
.
ndarray
)
smpl_val
=
_asarray
(
smpl_val
,
dtype
=
self
.
dtype
)
or
str
(
smpl_val
.
dtype
)
!=
out_var
.
type
.
dtype
):
smpl_val
=
_asarray
(
smpl_val
,
dtype
=
out_var
.
type
.
dtype
)
smpl_out
[
0
]
=
smpl_val
smpl_out
[
0
]
=
smpl_val
...
@@ -463,7 +465,7 @@ default_rng = DefaultGeneratorMakerOp()
...
@@ -463,7 +465,7 @@ default_rng = DefaultGeneratorMakerOp()
@_vectorize_node.register
(
RandomVariable
)
@_vectorize_node.register
(
RandomVariable
)
def
vectorize_random_variable
(
def
vectorize_random_variable
(
op
:
RandomVariable
,
node
:
Apply
,
rng
,
size
,
dtype
,
*
dist_params
op
:
RandomVariable
,
node
:
Apply
,
rng
,
size
,
*
dist_params
)
->
Apply
:
)
->
Apply
:
# If size was provided originally and a new size hasn't been provided,
# If size was provided originally and a new size hasn't been provided,
# We extend it to accommodate the new input batch dimensions.
# We extend it to accommodate the new input batch dimensions.
...
@@ -491,4 +493,4 @@ def vectorize_random_variable(
...
@@ -491,4 +493,4 @@ def vectorize_random_variable(
new_size_dims
=
broadcasted_batch_shape
[:
new_ndim
]
new_size_dims
=
broadcasted_batch_shape
[:
new_ndim
]
size
=
concatenate
([
new_size_dims
,
size
])
size
=
concatenate
([
new_size_dims
,
size
])
return
op
.
make_node
(
rng
,
size
,
dtype
,
*
dist_params
)
return
op
.
make_node
(
rng
,
size
,
*
dist_params
)
pytensor/tensor/random/rewriting/basic.py
浏览文件 @
98d73d78
...
@@ -81,7 +81,7 @@ def local_rv_size_lift(fgraph, node):
...
@@ -81,7 +81,7 @@ def local_rv_size_lift(fgraph, node):
if
not
isinstance
(
node
.
op
,
RandomVariable
):
if
not
isinstance
(
node
.
op
,
RandomVariable
):
return
return
rng
,
size
,
dtype
,
*
dist_params
=
node
.
inputs
rng
,
size
,
*
dist_params
=
node
.
inputs
dist_params
=
broadcast_params
(
dist_params
,
node
.
op
.
ndims_params
)
dist_params
=
broadcast_params
(
dist_params
,
node
.
op
.
ndims_params
)
...
@@ -105,7 +105,7 @@ def local_rv_size_lift(fgraph, node):
...
@@ -105,7 +105,7 @@ def local_rv_size_lift(fgraph, node):
else
:
else
:
return
return
new_node
=
node
.
op
.
make_node
(
rng
,
None
,
dtype
,
*
dist_params
)
new_node
=
node
.
op
.
make_node
(
rng
,
None
,
*
dist_params
)
if
config
.
compute_test_value
!=
"off"
:
if
config
.
compute_test_value
!=
"off"
:
compute_test_value
(
new_node
)
compute_test_value
(
new_node
)
...
@@ -141,7 +141,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
...
@@ -141,7 +141,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
return
False
return
False
rv_op
=
rv_node
.
op
rv_op
=
rv_node
.
op
rng
,
size
,
dtype
,
*
dist_params
=
rv_node
.
inputs
rng
,
size
,
*
dist_params
=
rv_node
.
inputs
rv
=
rv_node
.
default_output
()
rv
=
rv_node
.
default_output
()
# Check that Dimshuffle does not affect support dims
# Check that Dimshuffle does not affect support dims
...
@@ -185,7 +185,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
...
@@ -185,7 +185,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
)
)
new_dist_params
.
append
(
param
.
dimshuffle
(
param_new_order
))
new_dist_params
.
append
(
param
.
dimshuffle
(
param_new_order
))
new_node
=
rv_op
.
make_node
(
rng
,
new_size
,
dtype
,
*
new_dist_params
)
new_node
=
rv_op
.
make_node
(
rng
,
new_size
,
*
new_dist_params
)
if
config
.
compute_test_value
!=
"off"
:
if
config
.
compute_test_value
!=
"off"
:
compute_test_value
(
new_node
)
compute_test_value
(
new_node
)
...
@@ -233,7 +233,7 @@ def local_subtensor_rv_lift(fgraph, node):
...
@@ -233,7 +233,7 @@ def local_subtensor_rv_lift(fgraph, node):
return
None
return
None
rv_op
=
rv_node
.
op
rv_op
=
rv_node
.
op
rng
,
size
,
dtype
,
*
dist_params
=
rv_node
.
inputs
rng
,
size
,
*
dist_params
=
rv_node
.
inputs
# Parse indices
# Parse indices
idx_list
=
getattr
(
subtensor_op
,
"idx_list"
,
None
)
idx_list
=
getattr
(
subtensor_op
,
"idx_list"
,
None
)
...
@@ -346,7 +346,7 @@ def local_subtensor_rv_lift(fgraph, node):
...
@@ -346,7 +346,7 @@ def local_subtensor_rv_lift(fgraph, node):
new_dist_params
.
append
(
batch_param
[
tuple
(
batch_indices
)])
new_dist_params
.
append
(
batch_param
[
tuple
(
batch_indices
)])
# Create new RV
# Create new RV
new_node
=
rv_op
.
make_node
(
rng
,
new_size
,
dtype
,
*
new_dist_params
)
new_node
=
rv_op
.
make_node
(
rng
,
new_size
,
*
new_dist_params
)
new_rv
=
new_node
.
default_output
()
new_rv
=
new_node
.
default_output
()
copy_stack_trace
(
rv
,
new_rv
)
copy_stack_trace
(
rv
,
new_rv
)
...
...
tests/tensor/random/rewriting/test_basic.py
浏览文件 @
98d73d78
...
@@ -12,6 +12,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
...
@@ -12,6 +12,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from
pytensor.tensor
import
constant
from
pytensor.tensor
import
constant
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.random.basic
import
(
from
pytensor.tensor.random.basic
import
(
NormalRV
,
categorical
,
categorical
,
dirichlet
,
dirichlet
,
multinomial
,
multinomial
,
...
@@ -397,7 +398,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
...
@@ -397,7 +398,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
)
)
if
lifted
:
if
lifted
:
assert
new_out
.
owner
.
op
==
dist_op
assert
isinstance
(
new_out
.
owner
.
op
,
type
(
dist_op
))
assert
all
(
assert
all
(
isinstance
(
i
.
owner
.
op
,
DimShuffle
)
isinstance
(
i
.
owner
.
op
,
DimShuffle
)
for
i
in
new_out
.
owner
.
op
.
dist_params
(
new_out
.
owner
)
for
i
in
new_out
.
owner
.
op
.
dist_params
(
new_out
.
owner
)
...
@@ -832,7 +833,7 @@ def test_Subtensor_lift_restrictions():
...
@@ -832,7 +833,7 @@ def test_Subtensor_lift_restrictions():
subtensor_node
=
fg
.
outputs
[
0
]
.
owner
.
inputs
[
1
]
.
owner
.
inputs
[
0
]
.
owner
subtensor_node
=
fg
.
outputs
[
0
]
.
owner
.
inputs
[
1
]
.
owner
.
inputs
[
0
]
.
owner
assert
subtensor_node
==
y
.
owner
assert
subtensor_node
==
y
.
owner
assert
isinstance
(
subtensor_node
.
op
,
Subtensor
)
assert
isinstance
(
subtensor_node
.
op
,
Subtensor
)
assert
subtensor_node
.
inputs
[
0
]
.
owner
.
op
==
normal
assert
isinstance
(
subtensor_node
.
inputs
[
0
]
.
owner
.
op
,
NormalRV
)
z
=
pt
.
ones
(
x
.
shape
)
-
x
[
1
]
z
=
pt
.
ones
(
x
.
shape
)
-
x
[
1
]
...
@@ -850,7 +851,7 @@ def test_Subtensor_lift_restrictions():
...
@@ -850,7 +851,7 @@ def test_Subtensor_lift_restrictions():
EquilibriumGraphRewriter
([
local_subtensor_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
EquilibriumGraphRewriter
([
local_subtensor_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
rv_node
=
fg
.
outputs
[
0
]
.
owner
.
inputs
[
1
]
.
owner
.
inputs
[
0
]
.
owner
rv_node
=
fg
.
outputs
[
0
]
.
owner
.
inputs
[
1
]
.
owner
.
inputs
[
0
]
.
owner
assert
rv_node
.
op
==
normal
assert
isinstance
(
rv_node
.
op
,
NormalRV
)
assert
isinstance
(
rv_node
.
inputs
[
-
1
]
.
owner
.
op
,
Subtensor
)
assert
isinstance
(
rv_node
.
inputs
[
-
1
]
.
owner
.
op
,
Subtensor
)
assert
isinstance
(
rv_node
.
inputs
[
-
2
]
.
owner
.
op
,
Subtensor
)
assert
isinstance
(
rv_node
.
inputs
[
-
2
]
.
owner
.
op
,
Subtensor
)
...
@@ -872,7 +873,7 @@ def test_Dimshuffle_lift_restrictions():
...
@@ -872,7 +873,7 @@ def test_Dimshuffle_lift_restrictions():
dimshuffle_node
=
fg
.
outputs
[
0
]
.
owner
.
inputs
[
1
]
.
owner
dimshuffle_node
=
fg
.
outputs
[
0
]
.
owner
.
inputs
[
1
]
.
owner
assert
dimshuffle_node
==
y
.
owner
assert
dimshuffle_node
==
y
.
owner
assert
isinstance
(
dimshuffle_node
.
op
,
DimShuffle
)
assert
isinstance
(
dimshuffle_node
.
op
,
DimShuffle
)
assert
dimshuffle_node
.
inputs
[
0
]
.
owner
.
op
==
normal
assert
isinstance
(
dimshuffle_node
.
inputs
[
0
]
.
owner
.
op
,
NormalRV
)
z
=
pt
.
ones
(
x
.
shape
)
-
y
z
=
pt
.
ones
(
x
.
shape
)
-
y
...
@@ -890,7 +891,7 @@ def test_Dimshuffle_lift_restrictions():
...
@@ -890,7 +891,7 @@ def test_Dimshuffle_lift_restrictions():
EquilibriumGraphRewriter
([
local_dimshuffle_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
EquilibriumGraphRewriter
([
local_dimshuffle_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
rv_node
=
fg
.
outputs
[
0
]
.
owner
.
inputs
[
1
]
.
owner
rv_node
=
fg
.
outputs
[
0
]
.
owner
.
inputs
[
1
]
.
owner
assert
rv_node
.
op
==
normal
assert
isinstance
(
rv_node
.
op
,
NormalRV
)
assert
isinstance
(
rv_node
.
inputs
[
-
1
]
.
owner
.
op
,
DimShuffle
)
assert
isinstance
(
rv_node
.
inputs
[
-
1
]
.
owner
.
op
,
DimShuffle
)
assert
isinstance
(
rv_node
.
inputs
[
-
2
]
.
owner
.
op
,
DimShuffle
)
assert
isinstance
(
rv_node
.
inputs
[
-
2
]
.
owner
.
op
,
DimShuffle
)
...
...
tests/tensor/random/test_op.py
浏览文件 @
98d73d78
...
@@ -3,14 +3,14 @@ import pytest
...
@@ -3,14 +3,14 @@ import pytest
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor
import
config
,
function
from
pytensor
import
config
,
function
from
pytensor.gradient
import
NullTypeGradError
,
grad
from
pytensor.graph.replace
import
vectorize_graph
from
pytensor.graph.replace
import
vectorize_node
from
pytensor.raise_op
import
Assert
from
pytensor.raise_op
import
Assert
from
pytensor.tensor.math
import
eq
from
pytensor.tensor.math
import
eq
from
pytensor.tensor.random
import
normal
from
pytensor.tensor.random
import
normal
from
pytensor.tensor.random.basic
import
NormalRV
from
pytensor.tensor.random.op
import
RandomState
,
RandomVariable
,
default_rng
from
pytensor.tensor.random.op
import
RandomState
,
RandomVariable
,
default_rng
from
pytensor.tensor.shape
import
specify_shape
from
pytensor.tensor.shape
import
specify_shape
from
pytensor.tensor.type
import
all_dtypes
,
iscalar
,
tensor
from
pytensor.tensor.type
import
iscalar
,
tensor
@pytest.fixture
(
scope
=
"function"
,
autouse
=
False
)
@pytest.fixture
(
scope
=
"function"
,
autouse
=
False
)
...
@@ -51,15 +51,6 @@ def test_RandomVariable_basics(strict_test_value_flags):
...
@@ -51,15 +51,6 @@ def test_RandomVariable_basics(strict_test_value_flags):
inplace
=
True
,
inplace
=
True
,
)(
0
,
1
,
size
=
{
1
,
2
})
)(
0
,
1
,
size
=
{
1
,
2
})
# No dtype
with
pytest
.
raises
(
TypeError
,
match
=
"^dtype*"
):
RandomVariable
(
"normal"
,
0
,
[
0
,
0
],
inplace
=
True
,
)(
0
,
1
)
# Confirm that `inplace` works
# Confirm that `inplace` works
rv
=
RandomVariable
(
rv
=
RandomVariable
(
"normal"
,
"normal"
,
...
@@ -80,16 +71,19 @@ def test_RandomVariable_basics(strict_test_value_flags):
...
@@ -80,16 +71,19 @@ def test_RandomVariable_basics(strict_test_value_flags):
rv_shape
=
rv
.
_infer_shape
(
pt
.
constant
([]),
(),
[])
rv_shape
=
rv
.
_infer_shape
(
pt
.
constant
([]),
(),
[])
assert
rv_shape
.
equals
(
pt
.
constant
([],
dtype
=
"int64"
))
assert
rv_shape
.
equals
(
pt
.
constant
([],
dtype
=
"int64"
))
# Integer-specified `dtype`
# `dtype` is respected
dtype_1
=
all_dtypes
[
1
]
rv
=
RandomVariable
(
"normal"
,
signature
=
"(),()->()"
,
dtype
=
"int32"
)
rv_node
=
rv
.
make_node
(
None
,
None
,
1
)
with
config
.
change_flags
(
compute_test_value
=
"off"
):
rv_out
=
rv_node
.
outputs
[
1
]
rv_out
=
rv
()
rv_out
.
tag
.
test_value
=
1
assert
rv_out
.
dtype
==
"int32"
rv_out
=
rv
(
dtype
=
"int64"
)
assert
rv_out
.
dtype
==
"int64"
assert
rv_out
.
dtype
==
dtype_1
with
pytest
.
raises
(
ValueError
,
with
pytest
.
raises
(
NullTypeGradError
):
match
=
"Cannot change the dtype of a normal RV from int32 to float32"
,
grad
(
rv_out
,
[
rv_node
.
inputs
[
0
]])
):
assert
rv
(
dtype
=
"float32"
)
.
dtype
==
"float32"
def
test_RandomVariable_bcast
(
strict_test_value_flags
):
def
test_RandomVariable_bcast
(
strict_test_value_flags
):
...
@@ -238,70 +232,70 @@ def test_multivariate_rv_infer_static_shape():
...
@@ -238,70 +232,70 @@ def test_multivariate_rv_infer_static_shape():
assert
mv_op
(
param1
,
param2
,
size
=
(
10
,
2
))
.
type
.
shape
==
(
10
,
2
,
3
)
assert
mv_op
(
param1
,
param2
,
size
=
(
10
,
2
))
.
type
.
shape
==
(
10
,
2
,
3
)
def
test_vectorize
_node
():
def
test_vectorize
():
vec
=
tensor
(
shape
=
(
None
,))
vec
=
tensor
(
shape
=
(
None
,))
mat
=
tensor
(
shape
=
(
None
,
None
))
mat
=
tensor
(
shape
=
(
None
,
None
))
# Test without size
# Test without size
node
=
normal
(
vec
)
.
owner
out
=
normal
(
vec
)
new_inputs
=
node
.
inputs
.
copy
()
vect_node
=
vectorize_graph
(
out
,
{
vec
:
mat
})
.
owner
new_inputs
[
3
]
=
mat
# mu
assert
isinstance
(
vect_node
.
op
,
NormalRV
)
vect_node
=
vectorize_node
(
node
,
*
new_inputs
)
assert
vect_node
.
op
.
dist_params
(
vect_node
)[
0
]
is
mat
assert
vect_node
.
op
is
normal
assert
vect_node
.
inputs
[
3
]
is
mat
# Test with size, new size provided
# Test with size, new size provided
node
=
normal
(
vec
,
size
=
(
3
,))
.
owner
size
=
pt
.
as_tensor
(
np
.
array
((
3
,),
dtype
=
"int64"
))
new_inputs
=
node
.
inputs
.
copy
()
out
=
normal
(
vec
,
size
=
size
)
new_inputs
[
1
]
=
(
2
,
3
)
# size
vect_node
=
vectorize_graph
(
out
,
{
vec
:
mat
,
size
:
(
2
,
3
)})
.
owner
new_inputs
[
3
]
=
mat
# mu
assert
isinstance
(
vect_node
.
op
,
NormalRV
)
vect_node
=
vectorize_node
(
node
,
*
new_inputs
)
assert
tuple
(
vect_node
.
op
.
size_param
(
vect_node
)
.
eval
())
==
(
2
,
3
)
assert
vect_node
.
op
is
normal
assert
vect_node
.
op
.
dist_params
(
vect_node
)[
0
]
is
mat
assert
tuple
(
vect_node
.
inputs
[
1
]
.
eval
())
==
(
2
,
3
)
assert
vect_node
.
inputs
[
3
]
is
mat
# Test with size, new size not provided
# Test with size, new size not provided
node
=
normal
(
vec
,
size
=
(
3
,))
.
owner
out
=
normal
(
vec
,
size
=
(
3
,))
new_inputs
=
node
.
inputs
.
copy
()
vect_node
=
vectorize_graph
(
out
,
{
vec
:
mat
})
.
owner
new_inputs
[
3
]
=
mat
# mu
assert
isinstance
(
vect_node
.
op
,
NormalRV
)
vect_node
=
vectorize_node
(
node
,
*
new_inputs
)
assert
vect_node
.
op
.
dist_params
(
vect_node
)[
0
]
is
mat
assert
vect_node
.
op
is
normal
assert
vect_node
.
inputs
[
3
]
is
mat
assert
tuple
(
assert
tuple
(
vect_node
.
inputs
[
1
]
.
eval
({
mat
:
np
.
zeros
((
2
,
3
),
dtype
=
config
.
floatX
)})
vect_node
.
op
.
size_param
(
vect_node
)
.
eval
(
{
mat
:
np
.
zeros
((
2
,
3
),
dtype
=
config
.
floatX
)}
)
)
==
(
2
,
3
)
)
==
(
2
,
3
)
# Test parameter broadcasting
# Test parameter broadcasting
node
=
normal
(
vec
)
.
owner
mu
=
vec
new_inputs
=
node
.
inputs
.
copy
()
sigma
=
pt
.
as_tensor
(
np
.
array
(
1.0
))
new_inputs
[
3
]
=
tensor
(
"mu"
,
shape
=
(
10
,
5
))
# mu
out
=
normal
(
mu
,
sigma
)
new_inputs
[
4
]
=
tensor
(
"sigma"
,
shape
=
(
10
,))
# sigma
new_mu
=
tensor
(
"mu"
,
shape
=
(
10
,
5
))
vect_node
=
vectorize_node
(
node
,
*
new_inputs
)
new_sigma
=
tensor
(
"sigma"
,
shape
=
(
10
,))
assert
vect_node
.
op
is
normal
vect_node
=
vectorize_graph
(
out
,
{
mu
:
new_mu
,
sigma
:
new_sigma
})
.
owner
assert
isinstance
(
vect_node
.
op
,
NormalRV
)
assert
vect_node
.
default_output
()
.
type
.
shape
==
(
10
,
5
)
assert
vect_node
.
default_output
()
.
type
.
shape
==
(
10
,
5
)
# Test parameter broadcasting with non-expanding size
# Test parameter broadcasting with non-expanding size
node
=
normal
(
vec
,
size
=
(
5
,))
.
owner
mu
=
vec
new_inputs
=
node
.
inputs
.
copy
()
sigma
=
pt
.
as_tensor
(
np
.
array
(
1.0
))
new_inputs
[
3
]
=
tensor
(
"mu"
,
shape
=
(
10
,
5
))
# mu
out
=
normal
(
mu
,
sigma
,
size
=
(
5
,))
new_inputs
[
4
]
=
tensor
(
"sigma"
,
shape
=
(
10
,))
# sigma
new_mu
=
tensor
(
"mu"
,
shape
=
(
10
,
5
))
vect_node
=
vectorize_node
(
node
,
*
new_inputs
)
new_sigma
=
tensor
(
"sigma"
,
shape
=
(
10
,))
assert
vect_node
.
op
is
normal
vect_node
=
vectorize_graph
(
out
,
{
mu
:
new_mu
,
sigma
:
new_sigma
})
.
owner
assert
isinstance
(
vect_node
.
op
,
NormalRV
)
assert
vect_node
.
default_output
()
.
type
.
shape
==
(
10
,
5
)
assert
vect_node
.
default_output
()
.
type
.
shape
==
(
10
,
5
)
node
=
normal
(
vec
,
size
=
(
5
,))
.
owner
mu
=
vec
new_inputs
=
node
.
inputs
.
copy
()
sigma
=
pt
.
as_tensor
(
np
.
array
(
1.0
))
new_inputs
[
3
]
=
tensor
(
"mu"
,
shape
=
(
1
,
5
))
# mu
out
=
normal
(
mu
,
sigma
,
size
=
(
5
,))
new_inputs
[
4
]
=
tensor
(
"sigma"
,
shape
=
(
10
,))
# sigma
new_mu
=
tensor
(
"mu"
,
shape
=
(
1
,
5
))
# mu
vect_node
=
vectorize_node
(
node
,
*
new_inputs
)
new_sigma
=
tensor
(
"sigma"
,
shape
=
(
10
,))
# sigma
assert
vect_node
.
op
is
normal
vect_node
=
vectorize_graph
(
out
,
{
mu
:
new_mu
,
sigma
:
new_sigma
})
.
owner
assert
isinstance
(
vect_node
.
op
,
NormalRV
)
assert
vect_node
.
default_output
()
.
type
.
shape
==
(
10
,
5
)
assert
vect_node
.
default_output
()
.
type
.
shape
==
(
10
,
5
)
# Test parameter broadcasting with expanding size
# Test parameter broadcasting with expanding size
node
=
normal
(
vec
,
size
=
(
2
,
5
))
.
owner
mu
=
vec
new_inputs
=
node
.
inputs
.
copy
()
sigma
=
pt
.
as_tensor
(
np
.
array
(
1.0
))
new_inputs
[
3
]
=
tensor
(
"mu"
,
shape
=
(
10
,
5
))
# mu
out
=
normal
(
mu
,
sigma
,
size
=
(
2
,
5
))
new_inputs
[
4
]
=
tensor
(
"sigma"
,
shape
=
(
10
,))
# sigma
new_mu
=
tensor
(
"mu"
,
shape
=
(
1
,
5
))
vect_node
=
vectorize_node
(
node
,
*
new_inputs
)
new_sigma
=
tensor
(
"sigma"
,
shape
=
(
10
,))
assert
vect_node
.
op
is
normal
vect_node
=
vectorize_graph
(
out
,
{
mu
:
new_mu
,
sigma
:
new_sigma
})
.
owner
assert
isinstance
(
vect_node
.
op
,
NormalRV
)
assert
vect_node
.
default_output
()
.
type
.
shape
==
(
10
,
2
,
5
)
assert
vect_node
.
default_output
()
.
type
.
shape
==
(
10
,
2
,
5
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论