Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
2d81ccae
提交
2d81ccae
authored
6月 28, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
7月 03, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Simplify RV rewrites
上级
94e9ef06
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
61 行增加
和
76 行删除
+61
-76
basic.py
pytensor/tensor/random/rewriting/basic.py
+61
-76
没有找到文件。
pytensor/tensor/random/rewriting/basic.py
浏览文件 @
2d81ccae
...
@@ -5,21 +5,20 @@ from pytensor.configdefaults import config
...
@@ -5,21 +5,20 @@ from pytensor.configdefaults import config
from
pytensor.graph
import
ancestors
from
pytensor.graph
import
ancestors
from
pytensor.graph.op
import
compute_test_value
from
pytensor.graph.op
import
compute_test_value
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
in2out
,
node_rewriter
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
in2out
,
node_rewriter
from
pytensor.scalar
import
integer_types
from
pytensor.tensor
import
NoneConst
,
TensorVariable
from
pytensor.tensor
import
NoneConst
from
pytensor.tensor.basic
import
constant
from
pytensor.tensor.basic
import
constant
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.extra_ops
import
broadcast_to
from
pytensor.tensor.extra_ops
import
broadcast_to
from
pytensor.tensor.random.op
import
RandomVariable
from
pytensor.tensor.random.op
import
RandomVariable
from
pytensor.tensor.random.utils
import
broadcast_params
from
pytensor.tensor.random.utils
import
broadcast_params
from
pytensor.tensor.shape
import
Shape
,
Shape_i
,
shape_padleft
from
pytensor.tensor.shape
import
Shape
,
Shape_i
from
pytensor.tensor.subtensor
import
(
from
pytensor.tensor.subtensor
import
(
AdvancedSubtensor
,
AdvancedSubtensor
,
AdvancedSubtensor1
,
AdvancedSubtensor1
,
Subtensor
,
Subtensor
,
as_index_variable
,
get_idx_list
,
get_idx_list
,
)
)
from
pytensor.tensor.type
import
integer_dtypes
from
pytensor.tensor.type_other
import
NoneTypeT
,
SliceType
from
pytensor.tensor.type_other
import
NoneTypeT
,
SliceType
...
@@ -127,22 +126,23 @@ def local_dimshuffle_rv_lift(fgraph, node):
...
@@ -127,22 +126,23 @@ def local_dimshuffle_rv_lift(fgraph, node):
ds_op
=
node
.
op
ds_op
=
node
.
op
if
not
isinstance
(
ds_op
,
DimShuffle
):
# Dimshuffle which drop dimensions not supported yet
if
ds_op
.
drop
:
return
False
return
False
base_rv
=
node
.
inputs
[
0
]
rv_node
=
node
.
inputs
[
0
]
.
owner
rv_node
=
base_rv
.
owner
if
not
(
rv_node
and
isinstance
(
rv_node
.
op
,
RandomVariable
)):
if
not
(
rv_node
and
isinstance
(
rv_node
.
op
,
RandomVariable
)):
return
False
return
False
# Dimshuffle which drop dimensions not supported yet
if
ds_op
.
drop
:
return
False
rv_op
=
rv_node
.
op
rv_op
=
rv_node
.
op
rng
,
size
,
*
dist_params
=
rv_node
.
inputs
rng
,
size
,
*
dist_params
=
rv_node
.
inputs
rv
=
rv_node
.
default_output
()
next_rng
,
rv
=
rv_node
.
outputs
# If no one else is using the underlying `RandomVariable`, then we can
# do this; otherwise, the graph would be internally inconsistent.
if
is_rv_used_in_graph
(
rv
,
node
,
fgraph
):
return
False
# Check that Dimshuffle does not affect support dims
# Check that Dimshuffle does not affect support dims
supp_dims
=
set
(
range
(
rv
.
ndim
-
rv_op
.
ndim_supp
,
rv
.
ndim
))
supp_dims
=
set
(
range
(
rv
.
ndim
-
rv_op
.
ndim_supp
,
rv
.
ndim
))
...
@@ -153,31 +153,24 @@ def local_dimshuffle_rv_lift(fgraph, node):
...
@@ -153,31 +153,24 @@ def local_dimshuffle_rv_lift(fgraph, node):
# If no one else is using the underlying RandomVariable, then we can
# If no one else is using the underlying RandomVariable, then we can
# do this; otherwise, the graph would be internally inconsistent.
# do this; otherwise, the graph would be internally inconsistent.
if
is_rv_used_in_graph
(
base_
rv
,
node
,
fgraph
):
if
is_rv_used_in_graph
(
rv
,
node
,
fgraph
):
return
False
return
False
batched_dims
=
rv
.
ndim
-
rv_op
.
ndim_supp
batched_dims
=
rv
.
ndim
-
rv_op
.
ndim_supp
batched_dims_ds_order
=
tuple
(
o
for
o
in
ds_op
.
new_order
if
o
not
in
supp_dims
)
batched_dims_ds_order
=
tuple
(
o
for
o
in
ds_op
.
new_order
if
o
not
in
supp_dims
)
if
isinstance
(
size
.
type
,
NoneTypeT
):
if
isinstance
(
size
.
type
,
NoneTypeT
):
# Make size explicit
new_size
=
size
shape
=
tuple
(
broadcast_params
(
dist_params
,
rv_op
.
ndims_params
)[
0
]
.
shape
)
else
:
size
=
shape
[:
batched_dims
]
# Update the size to reflect the DimShuffled dimensions
new_size
=
[
# Update the size to reflect the DimShuffled dimensions
constant
(
1
,
dtype
=
"int64"
)
if
o
==
"x"
else
size
[
o
]
new_size
=
[
for
o
in
batched_dims_ds_order
constant
(
1
,
dtype
=
"int64"
)
if
o
==
"x"
else
size
[
o
]
]
for
o
in
batched_dims_ds_order
]
# Updates the params to reflect the Dimshuffled dimensions
# Updates the params to reflect the Dimshuffled dimensions
new_dist_params
=
[]
new_dist_params
=
[]
for
param
,
param_ndim_supp
in
zip
(
dist_params
,
rv_op
.
ndims_params
):
for
param
,
param_ndim_supp
in
zip
(
dist_params
,
rv_op
.
ndims_params
):
# Add broadcastable dimensions to the parameters that would have been expanded by the size
padleft
=
batched_dims
-
(
param
.
ndim
-
param_ndim_supp
)
if
padleft
>
0
:
param
=
shape_padleft
(
param
,
padleft
)
# Add the parameter support dimension indexes to the batched dimensions Dimshuffle
# Add the parameter support dimension indexes to the batched dimensions Dimshuffle
param_new_order
=
batched_dims_ds_order
+
tuple
(
param_new_order
=
batched_dims_ds_order
+
tuple
(
range
(
batched_dims
,
batched_dims
+
param_ndim_supp
)
range
(
batched_dims
,
batched_dims
+
param_ndim_supp
)
...
@@ -189,10 +182,10 @@ def local_dimshuffle_rv_lift(fgraph, node):
...
@@ -189,10 +182,10 @@ def local_dimshuffle_rv_lift(fgraph, node):
if
config
.
compute_test_value
!=
"off"
:
if
config
.
compute_test_value
!=
"off"
:
compute_test_value
(
new_node
)
compute_test_value
(
new_node
)
out
=
new_node
.
outputs
[
1
]
new_rv
=
new_node
.
default_output
()
if
base_
rv
.
name
:
if
rv
.
name
:
out
.
name
=
f
"{base_
rv.name}_lifted"
new_rv
.
name
=
f
"{
rv.name}_lifted"
return
[
out
]
return
[
new_rv
]
@node_rewriter
([
Subtensor
,
AdvancedSubtensor1
,
AdvancedSubtensor
])
@node_rewriter
([
Subtensor
,
AdvancedSubtensor1
,
AdvancedSubtensor
])
...
@@ -206,7 +199,9 @@ def local_subtensor_rv_lift(fgraph, node):
...
@@ -206,7 +199,9 @@ def local_subtensor_rv_lift(fgraph, node):
``mvnormal(mu, cov, size=(2,))[0, 0]``.
``mvnormal(mu, cov, size=(2,))[0, 0]``.
"""
"""
def
is_nd_advanced_idx
(
idx
,
dtype
):
def
is_nd_advanced_idx
(
idx
,
dtype
)
->
bool
:
if
not
isinstance
(
idx
,
TensorVariable
):
return
False
if
isinstance
(
dtype
,
str
):
if
isinstance
(
dtype
,
str
):
return
(
getattr
(
idx
.
type
,
"dtype"
,
None
)
==
dtype
)
and
(
idx
.
type
.
ndim
>=
1
)
return
(
getattr
(
idx
.
type
,
"dtype"
,
None
)
==
dtype
)
and
(
idx
.
type
.
ndim
>=
1
)
else
:
else
:
...
@@ -214,39 +209,28 @@ def local_subtensor_rv_lift(fgraph, node):
...
@@ -214,39 +209,28 @@ def local_subtensor_rv_lift(fgraph, node):
subtensor_op
=
node
.
op
subtensor_op
=
node
.
op
old_subtensor
=
node
.
outputs
[
0
]
[
indexed_rv
]
=
node
.
outputs
rv
=
node
.
inputs
[
0
]
rv_node
=
node
.
inputs
[
0
]
.
owner
rv_node
=
rv
.
owner
if
not
(
rv_node
and
isinstance
(
rv_node
.
op
,
RandomVariable
)):
if
not
(
rv_node
and
isinstance
(
rv_node
.
op
,
RandomVariable
)):
return
False
return
False
shape_feature
=
getattr
(
fgraph
,
"shape_feature"
,
None
)
if
not
shape_feature
:
return
None
# Use shape_feature to facilitate inferring final shape.
# Check that neither the RV nor the old Subtensor are in the shape graph.
output_shape
=
fgraph
.
shape_feature
.
shape_of
.
get
(
old_subtensor
,
None
)
if
output_shape
is
None
or
{
old_subtensor
,
rv
}
&
set
(
ancestors
(
output_shape
)):
return
None
rv_op
=
rv_node
.
op
rv_op
=
rv_node
.
op
rng
,
size
,
*
dist_params
=
rv_node
.
inputs
rng
,
size
,
*
dist_params
=
rv_node
.
inputs
rv
=
rv_node
.
default_output
()
# If no one else is using the underlying `RandomVariable`, then we can
# do this; otherwise, the graph would be internally inconsistent.
if
is_rv_used_in_graph
(
rv
,
node
,
fgraph
):
return
False
# Parse indices
# Parse indices
idx_list
=
getattr
(
subtensor_op
,
"idx_list"
,
None
)
indices
=
get_idx_list
(
node
.
inputs
,
getattr
(
subtensor_op
,
"idx_list"
,
None
))
if
idx_list
:
idx_vars
=
get_idx_list
(
node
.
inputs
,
idx_list
)
else
:
idx_vars
=
node
.
inputs
[
1
:]
indices
=
tuple
(
as_index_variable
(
idx
)
for
idx
in
idx_vars
)
# The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
# The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
# Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
# Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
# If we wanted to support that we could rewrite it as subtensor + dimshuffle
# If we wanted to support that we could rewrite it as subtensor + dimshuffle
# and make use of the dimshuffle lift rewrite
# and make use of the dimshuffle lift rewrite
integer_dtypes
=
{
type
.
dtype
for
type
in
integer_types
}
if
any
(
if
any
(
is_nd_advanced_idx
(
idx
,
integer_dtypes
)
or
NoneConst
.
equals
(
idx
)
is_nd_advanced_idx
(
idx
,
integer_dtypes
)
or
NoneConst
.
equals
(
idx
)
for
idx
in
indices
for
idx
in
indices
...
@@ -277,13 +261,21 @@ def local_subtensor_rv_lift(fgraph, node):
...
@@ -277,13 +261,21 @@ def local_subtensor_rv_lift(fgraph, node):
n_discarded_idxs
=
len
(
supp_indices
)
n_discarded_idxs
=
len
(
supp_indices
)
indices
=
indices
[:
-
n_discarded_idxs
]
indices
=
indices
[:
-
n_discarded_idxs
]
# If no one else is using the underlying `RandomVariable`, then we can
# do this; otherwise, the graph would be internally inconsistent.
if
is_rv_used_in_graph
(
rv
,
node
,
fgraph
):
return
False
# Update the size to reflect the indexed dimensions
# Update the size to reflect the indexed dimensions
new_size
=
output_shape
[:
len
(
output_shape
)
-
rv_op
.
ndim_supp
]
if
isinstance
(
size
.
type
,
NoneTypeT
):
new_size
=
size
else
:
shape_feature
=
getattr
(
fgraph
,
"shape_feature"
,
None
)
if
not
shape_feature
:
return
None
# Use shape_feature to facilitate inferring final shape.
# Check that neither the RV nor the old Subtensor are in the shape graph.
output_shape
=
fgraph
.
shape_feature
.
shape_of
.
get
(
indexed_rv
,
None
)
if
output_shape
is
None
or
{
indexed_rv
,
rv
}
&
set
(
ancestors
(
output_shape
)):
return
None
new_size
=
output_shape
[:
len
(
output_shape
)
-
rv_op
.
ndim_supp
]
# Propagate indexing to the parameters' batch dims.
# Propagate indexing to the parameters' batch dims.
# We try to avoid broadcasting the parameters together (and with size), by only indexing
# We try to avoid broadcasting the parameters together (and with size), by only indexing
...
@@ -291,20 +283,13 @@ def local_subtensor_rv_lift(fgraph, node):
...
@@ -291,20 +283,13 @@ def local_subtensor_rv_lift(fgraph, node):
# should still correctly broadcast any degenerate parameter dims.
# should still correctly broadcast any degenerate parameter dims.
new_dist_params
=
[]
new_dist_params
=
[]
for
param
,
param_ndim_supp
in
zip
(
dist_params
,
rv_op
.
ndims_params
):
for
param
,
param_ndim_supp
in
zip
(
dist_params
,
rv_op
.
ndims_params
):
# We first expand any missing parameter dims (and later index them away or keep them with none-slicing)
# Check which dims are broadcasted by either size or other parameters
batch_param_dims_missing
=
batch_ndims
-
(
param
.
ndim
-
param_ndim_supp
)
bcast_param_dims
=
tuple
(
batch_param
=
(
shape_padleft
(
param
,
batch_param_dims_missing
)
if
batch_param_dims_missing
else
param
)
# Check which dims are actually broadcasted
bcast_batch_param_dims
=
tuple
(
dim
dim
for
dim
,
(
param_dim
,
output_dim
)
in
enumerate
(
for
dim
,
(
param_dim
_bcast
,
output_dim_bcast
)
in
enumerate
(
zip
(
batch_param
.
type
.
shape
,
rv
.
type
.
shap
e
)
zip
(
param
.
type
.
broadcastable
,
rv
.
type
.
broadcastabl
e
)
)
)
if
(
param_dim
==
1
)
and
(
output_dim
!=
1
)
if
param_dim_bcast
and
not
output_dim_bcast
)
)
batch_indices
=
[]
batch_indices
=
[]
curr_dim
=
0
curr_dim
=
0
...
@@ -315,23 +300,23 @@ def local_subtensor_rv_lift(fgraph, node):
...
@@ -315,23 +300,23 @@ def local_subtensor_rv_lift(fgraph, node):
# If not, we use that directly, instead of the more inefficient `nonzero` form
# If not, we use that directly, instead of the more inefficient `nonzero` form
bool_dims
=
range
(
curr_dim
,
curr_dim
+
idx
.
type
.
ndim
)
bool_dims
=
range
(
curr_dim
,
curr_dim
+
idx
.
type
.
ndim
)
# There's an overlap, we have to decompose the boolean mask as a `nonzero`
# There's an overlap, we have to decompose the boolean mask as a `nonzero`
if
set
(
bool_dims
)
&
set
(
bcast_
batch_
param_dims
):
if
set
(
bool_dims
)
&
set
(
bcast_param_dims
):
int_indices
=
list
(
idx
.
nonzero
())
int_indices
=
list
(
idx
.
nonzero
())
# Indexing by 0 drops the degenerate dims
# Indexing by 0 drops the degenerate dims
for
bool_dim
in
bool_dims
:
for
bool_dim
in
bool_dims
:
if
bool_dim
in
bcast_
batch_
param_dims
:
if
bool_dim
in
bcast_param_dims
:
int_indices
[
bool_dim
-
curr_dim
]
=
0
int_indices
[
bool_dim
-
curr_dim
]
=
0
batch_indices
.
extend
(
int_indices
)
batch_indices
.
extend
(
int_indices
)
# No overlap, use index as is
# No overlap, use
boolean
index as is
else
:
else
:
batch_indices
.
append
(
idx
)
batch_indices
.
append
(
idx
)
curr_dim
+=
len
(
bool_dims
)
curr_dim
+=
len
(
bool_dims
)
# Basic-indexing (slice or integer)
# Basic-indexing (slice or integer)
else
:
else
:
# Broadcasted dim
# Broadcasted dim
if
curr_dim
in
bcast_
batch_
param_dims
:
if
curr_dim
in
bcast_param_dims
:
# Slice indexing, keep degenerate dim by none-slicing
# Slice indexing, keep degenerate dim by none-slicing
if
isinstance
(
idx
.
type
,
SliceType
):
if
isinstance
(
idx
,
slice
)
or
isinstance
(
idx
.
type
,
SliceType
):
batch_indices
.
append
(
slice
(
None
))
batch_indices
.
append
(
slice
(
None
))
# Integer indexing, drop degenerate dim by 0-indexing
# Integer indexing, drop degenerate dim by 0-indexing
else
:
else
:
...
@@ -342,7 +327,7 @@ def local_subtensor_rv_lift(fgraph, node):
...
@@ -342,7 +327,7 @@ def local_subtensor_rv_lift(fgraph, node):
batch_indices
.
append
(
idx
)
batch_indices
.
append
(
idx
)
curr_dim
+=
1
curr_dim
+=
1
new_dist_params
.
append
(
batch_
param
[
tuple
(
batch_indices
)])
new_dist_params
.
append
(
param
[
tuple
(
batch_indices
)])
# Create new RV
# Create new RV
new_node
=
rv_op
.
make_node
(
rng
,
new_size
,
*
new_dist_params
)
new_node
=
rv_op
.
make_node
(
rng
,
new_size
,
*
new_dist_params
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论