Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
cfa76f5d
提交
cfa76f5d
authored
11月 18, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
11月 19, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Handle non-constant NoneTypeT variables
上级
4e4f237a
隐藏空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
78 行增加
和
34 行删除
+78
-34
op.py
pytensor/tensor/random/op.py
+3
-1
basic.py
pytensor/tensor/random/rewriting/basic.py
+17
-14
utils.py
pytensor/tensor/random/utils.py
+14
-11
shape.py
pytensor/tensor/rewriting/shape.py
+4
-4
shape.py
pytensor/tensor/shape.py
+8
-3
test_op.py
tests/tensor/random/test_op.py
+10
-0
test_utils.py
tests/tensor/random/test_utils.py
+22
-1
没有找到文件。
pytensor/tensor/random/op.py
浏览文件 @
cfa76f5d
...
...
@@ -385,7 +385,9 @@ class RandomVariable(RNGConsumerOp):
dist_params
=
explicit_expand_dims
(
dist_params
,
self
.
ndims_params
,
size_length
=
None
if
NoneConst
.
equals
(
size
)
else
get_vector_length
(
size
),
size_length
=
None
if
isinstance
(
size
.
type
,
NoneTypeT
)
else
get_vector_length
(
size
),
)
inputs
=
(
rng
,
size
,
*
dist_params
)
...
...
pytensor/tensor/random/rewriting/basic.py
浏览文件 @
cfa76f5d
...
...
@@ -9,7 +9,7 @@ from pytensor.graph.rewriting.basic import (
dfs_rewriter
,
node_rewriter
,
)
from
pytensor.tensor
import
NoneConst
,
TensorVariable
from
pytensor.tensor
import
TensorVariable
from
pytensor.tensor.basic
import
constant
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.extra_ops
import
broadcast_to
...
...
@@ -20,7 +20,7 @@ from pytensor.tensor.subtensor import (
AdvancedSubtensor
,
AdvancedSubtensor1
,
Subtensor
,
get_idx_list
,
indices_from_subtensor
,
)
from
pytensor.tensor.type
import
integer_dtypes
from
pytensor.tensor.type_other
import
NoneTypeT
,
SliceType
...
...
@@ -237,17 +237,20 @@ def local_subtensor_rv_lift(fgraph, node):
return
False
# Parse indices
indices
=
get_idx_list
(
node
.
inputs
,
getattr
(
subtensor_op
,
"idx_list"
,
None
))
# The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
# Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
# If we wanted to support that we could rewrite it as subtensor + dimshuffle
# and make use of the dimshuffle lift rewrite
if
any
(
is_nd_advanced_idx
(
idx
,
integer_dtypes
)
or
NoneConst
.
equals
(
idx
)
for
idx
in
indices
):
return
False
if
isinstance
(
subtensor_op
,
Subtensor
):
indices
=
indices_from_subtensor
(
node
.
inputs
[
1
:],
subtensor_op
.
idx_list
)
else
:
indices
=
node
.
inputs
[
1
:]
# The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
# Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
# If we wanted to support that we could rewrite it as subtensor + dimshuffle
# and make use of the dimshuffle lift rewrite
# TODO: This rewrite is aborting with dummy indexing dimensions which aren't a problem
if
any
(
is_nd_advanced_idx
(
idx
,
integer_dtypes
)
or
isinstance
(
idx
.
type
,
NoneTypeT
)
for
idx
in
indices
):
return
False
# Check that indexing does not act on support dims
batch_ndims
=
rv_op
.
batch_ndim
(
rv_node
)
...
...
@@ -267,7 +270,7 @@ def local_subtensor_rv_lift(fgraph, node):
for
idx
in
supp_indices
:
if
not
(
isinstance
(
idx
.
type
,
SliceType
)
and
all
(
NoneConst
.
equals
(
i
)
for
i
in
idx
.
owner
.
inputs
)
and
all
(
isinstance
(
i
.
type
,
NoneTypeT
)
for
i
in
idx
.
owner
.
inputs
)
):
return
False
n_discarded_idxs
=
len
(
supp_indices
)
...
...
pytensor/tensor/random/utils.py
浏览文件 @
cfa76f5d
...
...
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING
import
numpy
as
np
from
pytensor.compile.sharedvalue
import
shared
from
pytensor.graph.basic
import
Constant
,
Variable
from
pytensor.graph.basic
import
Variable
from
pytensor.scalar
import
ScalarVariable
from
pytensor.tensor
import
NoneConst
,
get_vector_length
from
pytensor.tensor.basic
import
as_tensor_variable
,
cast
...
...
@@ -15,6 +15,7 @@ from pytensor.tensor.extra_ops import broadcast_arrays, broadcast_to
from
pytensor.tensor.math
import
maximum
from
pytensor.tensor.shape
import
shape_padleft
,
specify_shape
from
pytensor.tensor.type
import
int_dtypes
from
pytensor.tensor.type_other
import
NoneTypeT
from
pytensor.tensor.utils
import
faster_broadcast_to
from
pytensor.tensor.variable
import
TensorVariable
...
...
@@ -178,24 +179,26 @@ def normalize_size_param(
shape
:
int
|
np
.
ndarray
|
Variable
|
Sequence
|
None
,
)
->
Variable
:
"""Create an PyTensor value for a ``RandomVariable`` ``size`` parameter."""
if
shape
is
None
or
NoneConst
.
equals
(
shape
)
:
if
shape
is
None
:
return
NoneConst
elif
isinstance
(
shape
,
int
):
if
isinstance
(
shape
,
Variable
)
and
isinstance
(
shape
.
type
,
NoneTypeT
):
return
shape
if
isinstance
(
shape
,
int
):
shape
=
as_tensor_variable
([
shape
],
ndim
=
1
)
elif
not
isinstance
(
shape
,
np
.
ndarray
|
Variable
|
Sequence
):
raise
TypeError
(
"Parameter size must be None, an integer, or a sequence with integers."
)
else
:
if
not
isinstance
(
shape
,
Sequence
|
Variable
|
np
.
ndarray
):
raise
TypeError
(
"Parameter size must be None, an integer, or a sequence with integers."
)
shape
=
cast
(
as_tensor_variable
(
shape
,
ndim
=
1
,
dtype
=
"int64"
),
"int64"
)
if
not
isinstance
(
shape
,
Constant
):
if
shape
.
type
.
shape
==
(
None
,
):
# This should help ensure that the length of non-constant `size`s
# will be available after certain types of cloning (e.g. the kind
# `Scan` performs)
# will be available after certain types of cloning (e.g. the kind `Scan` performs)
shape
=
specify_shape
(
shape
,
(
get_vector_length
(
shape
),))
assert
not
any
(
s
is
None
for
s
in
shape
.
type
.
shape
)
assert
shape
.
type
.
shape
!=
(
None
,
)
assert
shape
.
dtype
in
int_dtypes
return
shape
...
...
pytensor/tensor/rewriting/shape.py
浏览文件 @
cfa76f5d
...
...
@@ -47,7 +47,7 @@ from pytensor.tensor.shape import (
)
from
pytensor.tensor.subtensor
import
Subtensor
,
get_idx_list
from
pytensor.tensor.type
import
TensorType
,
discrete_dtypes
,
integer_dtypes
from
pytensor.tensor.type_other
import
None
Const
,
None
TypeT
from
pytensor.tensor.type_other
import
NoneTypeT
from
pytensor.tensor.variable
import
TensorVariable
...
...
@@ -1137,7 +1137,7 @@ def local_merge_consecutive_specify_shape(fgraph, node):
inner_obj
,
*
shape
=
obj
.
owner
.
inputs
for
dim
,
sh
in
enumerate
(
node
.
inputs
[
1
:]):
if
not
NoneConst
.
equals
(
sh
):
if
not
isinstance
(
sh
.
type
,
NoneTypeT
):
shape
[
dim
]
=
sh
# TODO: We could make sure that the overlapping shapes of the two `SpecifyShape`s are
...
...
@@ -1183,7 +1183,7 @@ def local_Shape_of_SpecifyShape(fgraph, node):
# Replace `NoneConst` by `shape_i`
for
i
,
sh
in
enumerate
(
shape
):
if
NoneConst
.
equals
(
sh
):
if
isinstance
(
sh
.
type
,
NoneTypeT
):
shape
[
i
]
=
x
.
shape
[
i
]
return
[
stack
(
shape
)
.
astype
(
np
.
int64
)]
...
...
@@ -1219,7 +1219,7 @@ def local_specify_shape_lift(fgraph, node):
for
i
,
(
dim
,
bcast
)
in
enumerate
(
zip
(
shape
,
out_broadcastable
,
strict
=
True
)
)
if
(
not
bcast
and
not
NoneConst
.
equals
(
dim
))
if
(
not
bcast
and
not
isinstance
(
dim
.
type
,
NoneTypeT
))
}
new_elem_inps
=
elem_inps
.
copy
()
for
i
,
elem_inp
in
enumerate
(
elem_inps
):
...
...
pytensor/tensor/shape.py
浏览文件 @
cfa76f5d
...
...
@@ -408,7 +408,9 @@ class SpecifyShape(COp):
shape
=
tuple
(
NoneConst
if
(
s
is
None
or
NoneConst
.
equals
(
s
))
if
(
s
is
None
or
(
isinstance
(
s
,
Variable
)
and
isinstance
(
s
.
type
,
NoneTypeT
))
)
else
ptb
.
as_tensor_variable
(
s
,
ndim
=
0
)
for
s
in
shape
)
...
...
@@ -506,7 +508,7 @@ class SpecifyShape(COp):
for
i
,
(
shp_name
,
shp
)
in
enumerate
(
zip
(
shape_names
,
node
.
inputs
[
1
:],
strict
=
True
)
):
if
NoneConst
.
equals
(
shp
):
if
isinstance
(
shp
.
type
,
NoneTypeT
):
continue
code
+=
dedent
(
f
"""
...
...
@@ -594,7 +596,10 @@ def _vectorize_specify_shape(op, node, x, *shape):
if
any
(
as_tensor_variable
(
dim
)
.
type
.
ndim
!=
0
for
dim
in
shape
if
not
(
NoneConst
.
equals
(
dim
)
or
dim
is
None
)
if
not
(
(
isinstance
(
dim
,
Variable
)
and
isinstance
(
dim
.
type
,
NoneTypeT
))
or
dim
is
None
)
):
raise
NotImplementedError
(
"It is not possible to vectorize the shape argument of SpecifyShape"
...
...
tests/tensor/random/test_op.py
浏览文件 @
cfa76f5d
...
...
@@ -11,6 +11,7 @@ from pytensor.tensor.random.basic import NormalRV
from
pytensor.tensor.random.op
import
RandomVariable
,
default_rng
from
pytensor.tensor.shape
import
specify_shape
from
pytensor.tensor.type
import
iscalar
,
tensor
from
pytensor.tensor.type_other
import
none_type_t
@pytest.fixture
(
scope
=
"function"
,
autouse
=
False
)
...
...
@@ -317,3 +318,12 @@ def test_size_none_vs_empty():
ValueError
,
match
=
"Size length is incompatible with batched dimensions"
):
rv
([
0
],
[
1
],
size
=
())
def
test_non_constant_none_size
():
# Regression test for https://github.com/pymc-devs/pymc/issues/7901#issuecomment-3528479876
loc
=
pt
.
vector
(
"loc"
,
dtype
=
"float64"
)
size
=
none_type_t
(
"none_size"
)
rv
=
normal
(
loc
,
size
=
size
)
rv
.
eval
({
loc
:
np
.
arange
(
5
,
dtype
=
"float64"
),
size
:
None
},
mode
=
"FAST_COMPILE"
)
tests/tensor/random/test_utils.py
浏览文件 @
cfa76f5d
...
...
@@ -7,9 +7,11 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from
pytensor.tensor.random.utils
import
(
RandomStream
,
broadcast_params
,
normalize_size_param
,
supp_shape_from_ref_param_shape
,
)
from
pytensor.tensor.type
import
matrix
,
tensor
from
pytensor.tensor.type
import
TensorType
,
matrix
,
tensor
from
pytensor.tensor.type_other
import
NoneTypeT
,
none_type_t
from
tests
import
unittest_tools
as
utt
...
...
@@ -327,3 +329,22 @@ def test_supp_shape_from_ref_param_shape():
ref_param_idx
=
1
,
)
assert
res
==
(
3
,
4
)
def
test_normalize_size_param
():
assert
normalize_size_param
(
None
)
.
type
==
NoneTypeT
()
sym_none_size
=
none_type_t
()
assert
normalize_size_param
(
sym_none_size
)
is
sym_none_size
empty_size
=
normalize_size_param
(())
assert
empty_size
.
type
==
TensorType
(
dtype
=
"int64"
,
shape
=
(
0
,))
int_size
=
normalize_size_param
(
5
)
assert
int_size
.
type
==
TensorType
(
dtype
=
"int64"
,
shape
=
(
1
,))
seq_int_size
=
normalize_size_param
((
5
,
3
,
4
))
assert
seq_int_size
.
type
==
TensorType
(
dtype
=
"int64"
,
shape
=
(
3
,))
sym_tensor_size
=
tensor
(
shape
=
(
3
,),
dtype
=
"int64"
)
assert
normalize_size_param
(
sym_tensor_size
)
is
sym_tensor_size
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论