testgroup / pytensor · Commits

Commit b8e939e9
Authored by Ricardo Vieira on Apr 21, 2023; committed by Ricardo Vieira on Apr 24, 2023
Fix `local_subtensor_rv_lift` rewrite bug with vector parameters
Also allow rewrite to work with multivariate variables, when indexing does not act on support dims.
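The equivalence this rewrite exploits can be sketched in plain NumPy (illustrative only, not pytensor code): indexing a batch of draws selects one draw, so the same result shape is obtained by indexing the parameters before drawing.

```python
import numpy as np

rng = np.random.default_rng(123)

mu = np.array([0.0, 10.0, 20.0])
std = np.array([1.0, 1.0, 1.0])

# Draw all three, then index ...
indexed_draw = rng.normal(mu, std)[0]
# ... versus index the parameters, then draw.
lifted_draw = rng.normal(mu[0], std[0])

# Both are scalar draws from Normal(mu[0], std[0]); shapes agree even though
# the sampled values differ for a given random stream.
assert np.shape(indexed_draw) == np.shape(lifted_draw) == ()
```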
Parent: bfeabc82

Showing 2 changed files with 159 additions and 103 deletions:
- pytensor/tensor/random/rewriting/basic.py (+62, -95)
- tests/tensor/random/rewriting/test_basic.py (+97, -8)
pytensor/tensor/random/rewriting/basic.py
+from itertools import zip_longest
+
 from pytensor.compile import optdb
 from pytensor.configdefaults import config
 from pytensor.graph.op import compute_test_value
 from pytensor.graph.rewriting.basic import in2out, node_rewriter
+from pytensor.tensor import NoneConst
 from pytensor.tensor.basic import constant, get_vector_length
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.extra_ops import broadcast_to
...
@@ -17,6 +20,7 @@ from pytensor.tensor.subtensor import (
     get_idx_list,
     indexed_result_shape,
 )
+from pytensor.tensor.type_other import SliceType
def is_rv_used_in_graph(base_rv, node, fgraph):
...
...
@@ -196,37 +200,11 @@ def local_dimshuffle_rv_lift(fgraph, node):
 def local_subtensor_rv_lift(fgraph, node):
     """Lift a ``*Subtensor`` through ``RandomVariable`` inputs.
 
-    In a fashion similar to ``local_dimshuffle_rv_lift``, the indexed dimensions
-    need to be separated into distinct replication-space and (independent)
-    parameter-space ``*Subtensor``s.
-
-    The replication-space ``*Subtensor`` can be used to determine a
-    sub/super-set of the replication-space and, thus, a "smaller"/"larger"
-    ``size`` tuple.  The parameter-space ``*Subtensor`` is simply lifted and
-    applied to the distribution parameters.
-
-    Consider the following example graph:
-    ``normal(mu, std, size=(d1, d2, d3))[idx1, idx2, idx3]``.  The
-    ``*Subtensor`` ``Op`` requests indices ``idx1``, ``idx2``, and ``idx3``,
-    which correspond to all three ``size`` dimensions.  Now, depending on the
-    broadcasted dimensions of ``mu`` and ``std``, this ``*Subtensor`` ``Op``
-    could be reducing the ``size`` parameter and/or sub-setting the independent
-    ``mu`` and ``std`` parameters.  Only once the dimensions are properly
-    separated into the two replication/parameter subspaces can we determine how
-    the ``*Subtensor`` indices are distributed.
-
-    For instance, ``normal(mu, std, size=(d1, d2, d3))[idx1, idx2, idx3]``
-    could become
-    ``normal(mu[idx1], std[idx2], size=np.shape(idx1) + np.shape(idx2) + np.shape(idx3))``
-    if ``mu.shape == std.shape == ()``.
-
-    ``normal`` is a rather simple case, because it's univariate.  Multivariate
-    cases require a mapping between the parameter space and the image of the
-    random variable.  This may not always be possible, but for many common
-    distributions it is.  For example, the dimensions of the multivariate
-    normal's image can be mapped directly to each dimension of its parameters.
-    We use these mappings to change a graph like ``multivariate_normal(mu, Sigma)[idx1]``
-    into ``multivariate_normal(mu[idx1], Sigma[idx1, idx1])``.
+    For example, ``normal(mu, std)[0] == normal(mu[0], std[0])``.
+
+    This rewrite also applies to multivariate distributions as long
+    as indexing does not happen within core dimensions, such as in
+    ``mvnormal(mu, cov, size=(2,))[0, 0]``.
     """
 
     st_op = node.op
...
...
@@ -234,103 +212,92 @@ def local_subtensor_rv_lift(fgraph, node):
     if not isinstance(st_op, (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)):
         return False
 
-    base_rv = node.inputs[0]
-
-    rv_node = base_rv.owner
+    rv = node.inputs[0]
+    rv_node = rv.owner
 
     if not (rv_node and isinstance(rv_node.op, RandomVariable)):
         return False
 
-    # If no one else is using the underlying `RandomVariable`, then we can
-    # do this; otherwise, the graph would be internally inconsistent.
-    if is_rv_used_in_graph(base_rv, node, fgraph):
-        return False
-
-    rv_op = base_rv.owner.op
-    rng, size, dtype, *dist_params = base_rv.owner.inputs
-
-    # TODO: Remove this once the multi-dimensional changes described below are
-    # in place.
-    if rv_op.ndim_supp > 0:
-        return False
+    rv_op = rv_node.op
+    rng, size, dtype, *dist_params = rv_node.inputs
 
+    # Parse indices
     idx_list = getattr(st_op, "idx_list", None)
     if idx_list:
         cdata = get_idx_list(node.inputs, idx_list)
     else:
         cdata = node.inputs[1:]
 
     st_indices, st_is_bool = zip(
         *tuple(
             (as_index_variable(i), getattr(i, "dtype", None) == "bool") for i in cdata
         )
     )
 
-    # We need to separate dimensions into replications and independents
-    num_ind_dims = None
-    if len(dist_params) == 1:
-        num_ind_dims = dist_params[0].ndim
-    else:
-        # When there is more than one distribution parameter, assume that all
-        # of them will broadcast to the maximum number of dimensions
-        num_ind_dims = max(d.ndim for d in dist_params)
-
-    reps_ind_split_idx = base_rv.ndim - (num_ind_dims + rv_op.ndim_supp)
-
-    if len(st_indices) > reps_ind_split_idx:
-        # These are the indices that need to be applied to the parameters
-        ind_indices = tuple(st_indices[reps_ind_split_idx:])
-
-        # We need to broadcast the parameters before applying the `*Subtensor*`
-        # with these indices, because the indices could be referencing broadcast
-        # dimensions that don't exist (yet)
-        bcast_dist_params = broadcast_params(dist_params, rv_op.ndims_params)
-
-        # TODO: For multidimensional distributions, we need a map that tells us
-        # which dimensions of the parameters need to be indexed.
-        #
-        # For example, `multivariate_normal` would have the following:
-        # `RandomVariable.param_to_image_dims = ((0,), (0, 1))`
-        #
-        # I.e. the first parameter's (i.e. mean's) first dimension maps directly to
-        # the dimension of the RV's image, and its second parameter's
-        # (i.e. covariance's) first and second dimensions map directly to the
-        # dimension of the RV's image.
-        args_lifted = tuple(p[ind_indices] for p in bcast_dist_params)
-    else:
-        # In this case, no indexing is applied to the parameters; only the
-        # `size` parameter is affected.
-        args_lifted = dist_params
+    # Check that indexing does not act on support dims
+    batched_ndims = rv.ndim - rv_op.ndim_supp
+    if len(st_indices) > batched_ndims:
+        # If the last indexes are just dummy `slice(None)` we discard them
+        st_is_bool = st_is_bool[:batched_ndims]
+        st_indices, supp_indices = (
+            st_indices[:batched_ndims],
+            st_indices[batched_ndims:],
+        )
+        for index in supp_indices:
+            if not (
+                isinstance(index.type, SliceType)
+                and all(NoneConst.equals(i) for i in index.owner.inputs)
+            ):
+                return False
 
+    # If no one else is using the underlying `RandomVariable`, then we can
+    # do this; otherwise, the graph would be internally inconsistent.
+    if is_rv_used_in_graph(rv, node, fgraph):
+        return False
 
     # Update the size to reflect the indexed dimensions
-    # TODO: Could use `ShapeFeature` info.  We would need to be sure that
-    # `node` isn't in the results, though.
-    # if hasattr(fgraph, "shape_feature"):
-    #     output_shape = fgraph.shape_feature.shape_of(node.outputs[0])
-    # else:
-    output_shape = indexed_result_shape(base_rv.shape, st_indices)
-
-    size_lifted = (
-        output_shape if rv_op.ndim_supp == 0 else output_shape[: -rv_op.ndim_supp]
-    )
+    output_shape_ignoring_bool = indexed_result_shape(rv.shape, st_indices)
+    new_size_ignoring_boolean = (
+        output_shape_ignoring_bool
+        if rv_op.ndim_supp == 0
+        else output_shape_ignoring_bool[: -rv_op.ndim_supp]
+    )
 
     # Boolean indices can actually change the `size` value (compared to just
     # *which* dimensions of `size` are used).
+    # The `indexed_result_shape` helper does not consider this
     if any(st_is_bool):
-        size_lifted = tuple(
+        new_size = tuple(
             at_sum(idx) if is_bool else s
-            for s, is_bool, idx in zip(
-                size_lifted, st_is_bool, st_indices[: (reps_ind_split_idx + 1)]
-            )
+            for s, is_bool, idx in zip_longest(
+                new_size_ignoring_boolean, st_is_bool, st_indices, fillvalue=False
+            )
         )
+    else:
+        new_size = new_size_ignoring_boolean
 
+    # Update the parameters to reflect the indexed dimensions
+    new_dist_params = []
+    for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
+        # Apply indexing on the batched dimensions of the parameter
+        batched_param_dims_missing = batched_ndims - (param.ndim - param_ndim_supp)
+        batched_param = shape_padleft(param, batched_param_dims_missing)
+        batched_st_indices = []
+        for st_index, batched_param_shape in zip(st_indices, batched_param.type.shape):
+            # If we have a degenerate dimension indexing it should always do the job
+            if batched_param_shape == 1:
+                batched_st_indices.append(0)
+            else:
+                batched_st_indices.append(st_index)
+        new_dist_params.append(batched_param[tuple(batched_st_indices)])
 
-    new_node = rv_op.make_node(rng, size_lifted, dtype, *args_lifted)
-    _, new_rv = new_node.outputs
+    # Create new RV
+    new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
+    new_rv = new_node.default_output()
 
     # Calling `Op.make_node` directly circumvents test value computations, so
     # we need to compute the test values manually
     if config.compute_test_value != "off":
         compute_test_value(new_node)
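The new support-dims guard can be sketched standalone (a hypothetical helper mirroring the check above, assuming plain Python indices rather than pytensor variables): indexing may only touch batch dimensions, and any index reaching into the support (core) dimensions must be a dummy full slice, otherwise the lift is rejected.

```python
def indexing_acts_only_on_batch_dims(indices, rv_ndim, ndim_supp):
    # Number of batch dimensions preceding the support (core) dimensions.
    batched_ndims = rv_ndim - ndim_supp
    # Any remaining indices fall on support dims; only full slices are allowed.
    supp_indices = indices[batched_ndims:]
    return all(idx == slice(None) for idx in supp_indices)

# mvnormal(mu, cov, size=(2,)) has one batch dim and one support dim:
assert indexing_acts_only_on_batch_dims((0,), rv_ndim=2, ndim_supp=1)
assert indexing_acts_only_on_batch_dims((0, slice(None)), rv_ndim=2, ndim_supp=1)
assert not indexing_acts_only_on_batch_dims((0, 0), rv_ndim=2, ndim_supp=1)
```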
...
...
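The degenerate-dimension handling in `new_dist_params` (the vector-parameters fix) can be illustrated with NumPy; the variable names below are illustrative, not from the commit.

```python
import numpy as np

# A parameter with a length-1 (broadcastable) batch dimension: applying the
# original batch index could be out of bounds, but index 0 is always valid
# and selects the same (broadcast) value.
mu = np.array([[0.0, 1.0, 2.0]])             # shape (1, 3); broadcasts to (4, 3)
idx = 2                                      # index into the length-4 batch dim
safe_idx = 0 if mu.shape[0] == 1 else idx    # mirrors the `batched_st_indices` logic
assert np.array_equal(mu[safe_idx], np.broadcast_to(mu, (4, 3))[idx])
```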
tests/tensor/random/rewriting/test_basic.py
...
...
@@ -12,6 +12,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.tensor import constant
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.random.basic import (
     categorical,
     dirichlet,
     multinomial,
+    multivariate_normal,
...
...
@@ -36,8 +37,8 @@ def apply_local_rewrite_to_rv(
     rewrite, op_fn, dist_op, dist_params, size, rng, name=None
 ):
     dist_params_at = []
-    for p in dist_params:
-        p_at = at.as_tensor(p).type()
+    for i, p in enumerate(dist_params):
+        p_at = at.as_tensor(p).type(f"p_{i}")
         p_at.tag.test_value = p
         dist_params_at.append(p_at)
...
...
@@ -495,8 +496,79 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
         ),
         (3, 2, 2),
     ),
     # A multi-dimensional case
+    # Only one distribution parameter
+    (
+        (0,),
+        True,
+        poisson,
+        (np.array([[1, 2], [3, 4]], dtype=config.floatX),),
+        (3, 2, 2),
+    ),
+    # Univariate distribution with vector parameters
+    (
+        (np.array([0, 2]),),
+        True,
+        categorical,
+        (np.array([0.0, 0.0, 1.0], dtype=config.floatX),),
+        (4,),
+    ),
+    (
+        (np.array([True, False, True, True]),),
+        True,
+        categorical,
+        (np.array([0.0, 0.0, 1.0], dtype=config.floatX),),
+        (4,),
+    ),
+    (
+        (np.array([True, False, True]),),
+        True,
+        categorical,
+        (
+            np.array(
+                [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
+                dtype=config.floatX,
+            ),
+        ),
+        (),
+    ),
+    (
+        (
+            slice(None),
+            np.array([True, False, True]),
+        ),
+        True,
+        categorical,
+        (
+            np.array(
+                [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
+                dtype=config.floatX,
+            ),
+        ),
+        (4, 3),
+    ),
+    # Boolean indexing where output is empty
+    (
+        (np.array([False, False]),),
+        True,
+        normal,
+        (np.array([[1.0, 0.0, 0.0]], dtype=config.floatX),),
+        (2, 3),
+    ),
+    (
+        (np.array([False, False]),),
+        True,
+        categorical,
+        (
+            np.array(
+                [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
+                dtype=config.floatX,
+            ),
+        ),
+        (2, 3),
+    ),
+    # Multivariate cases, indexing only supported if it does not affect core dimensions
     (
+        # Indexing dips into core dimension
         (np.array([1]), 0),
         False,
         multivariate_normal,
...
@@ -506,13 +578,30 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
         ),
         (),
     ),
-    # Only one distribution parameter
-    (
-        (0,),
-        True,
-        poisson,
-        (np.array([[1, 2], [3, 4]], dtype=config.floatX),),
-        (3, 2, 2),
-    ),
+    (
+        (np.array([0, 2]),),
+        True,
+        multivariate_normal,
+        (
+            np.array(
+                [[-100, -125, -150], [0, 0, 0], [200, 225, 250]],
+                dtype=config.floatX,
+            ),
+            np.eye(3, dtype=config.floatX) * 1e-6,
+        ),
+        (),
+    ),
+    (
+        (np.array([True, False, True]), slice(None)),
+        True,
+        multivariate_normal,
+        (
+            np.array([200, 250], dtype=config.floatX),
+            # Second covariance is invalid, to test it is not chosen
+            np.dstack([np.eye(2), np.eye(2) * 0, np.eye(2)]).T.astype(config.floatX)
+            * 1e-6,
+        ),
+        (3,),
+    ),
 ],
 )
...
...
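One subtlety exercised by the boolean-indexing test cases above: a boolean index changes the resulting size to the number of ``True`` entries, which the lifted ``size`` must reflect. A NumPy sketch:

```python
import numpy as np

# Boolean indexing keeps only the True positions, so the lifted `size`
# entry for that dimension is mask.sum(), not the original length.
mask = np.array([True, False, True, True])
draws = np.zeros((4, 3))                 # stand-in for normal(..., size=(4, 3))
indexed = draws[mask]
assert indexed.shape == (int(mask.sum()), 3)    # == (3, 3)
```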