Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
cc6bed1a
提交
cc6bed1a
authored
2月 28, 2026
作者:
ricardoV94
提交者:
Ricardo Vieira
3月 01, 2026
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Revert "Refactor AdvancedSubtensor"
This reverts commit
db7fa079
.
上级
03afa5bb
隐藏空白字符变更
内嵌
并排
正在显示
26 个修改的文件
包含
1568 行增加
和
1354 行删除
+1568
-1354
destroyhandler.py
pytensor/graph/destroyhandler.py
+2
-2
subtensor.py
pytensor/link/jax/dispatch/subtensor.py
+33
-3
subtensor.py
pytensor/link/mlx/dispatch/subtensor.py
+18
-5
subtensor.py
pytensor/link/numba/dispatch/subtensor.py
+67
-104
subtensor.py
pytensor/link/pytorch/dispatch/subtensor.py
+19
-10
rewriting.py
pytensor/scan/rewriting.py
+24
-4
basic.py
pytensor/tensor/basic.py
+4
-4
basic.py
pytensor/tensor/random/rewriting/basic.py
+17
-9
shape.py
pytensor/tensor/rewriting/shape.py
+3
-5
subtensor.py
pytensor/tensor/rewriting/subtensor.py
+160
-157
subtensor_lift.py
pytensor/tensor/rewriting/subtensor_lift.py
+29
-21
uncanonicalize.py
pytensor/tensor/rewriting/uncanonicalize.py
+52
-34
subtensor.py
pytensor/tensor/subtensor.py
+812
-739
variable.py
pytensor/tensor/variable.py
+73
-48
indexing.py
pytensor/xtensor/rewriting/indexing.py
+3
-4
test_basic.py
tests/graph/rewriting/test_basic.py
+26
-0
test_subtensor.py
tests/link/jax/test_subtensor.py
+0
-31
test_subtensor.py
tests/link/mlx/test_subtensor.py
+21
-0
test_subtensor.py
tests/link/numba/test_subtensor.py
+62
-20
test_elemwise.py
tests/tensor/rewriting/test_elemwise.py
+6
-12
test_subtensor.py
tests/tensor/rewriting/test_subtensor.py
+16
-6
test_subtensor_lift.py
tests/tensor/rewriting/test_subtensor_lift.py
+4
-3
test_blockwise.py
tests/tensor/test_blockwise.py
+4
-8
test_subtensor.py
tests/tensor/test_subtensor.py
+81
-110
test_type_other.py
tests/tensor/test_type_other.py
+29
-12
test_variable.py
tests/tensor/test_variable.py
+3
-3
没有找到文件。
pytensor/graph/destroyhandler.py
浏览文件 @
cc6bed1a
...
...
@@ -771,9 +771,9 @@ class DestroyHandler(Bookkeeper):
}
tolerated
.
add
(
destroyed_idx
)
tolerate_aliased
=
getattr
(
app
.
op
,
"destroyhandler_tolerate_aliased"
,
()
app
.
op
,
"destroyhandler_tolerate_aliased"
,
[]
)
assert
isinstance
(
tolerate_aliased
,
tuple
|
list
)
assert
isinstance
(
tolerate_aliased
,
list
)
ignored
=
{
idx1
for
idx0
,
idx1
in
tolerate_aliased
if
idx0
==
destroyed_idx
}
...
...
pytensor/link/jax/dispatch/subtensor.py
浏览文件 @
cc6bed1a
...
...
@@ -8,6 +8,7 @@ from pytensor.tensor.subtensor import (
Subtensor
,
indices_from_subtensor
,
)
from
pytensor.tensor.type_other
import
MakeSlice
BOOLEAN_MASK_ERROR
=
"""JAX does not support resizing arrays with boolean
...
...
@@ -34,8 +35,10 @@ slice length.
@jax_funcify.register
(
AdvancedSubtensor
)
@jax_funcify.register
(
AdvancedSubtensor1
)
def
jax_funcify_Subtensor
(
op
,
node
,
**
kwargs
):
idx_list
=
getattr
(
op
,
"idx_list"
,
None
)
def
subtensor
(
x
,
*
ilists
):
indices
=
indices_from_subtensor
(
ilists
,
op
.
idx_list
)
indices
=
indices_from_subtensor
(
ilists
,
idx_list
)
if
len
(
indices
)
==
1
:
indices
=
indices
[
0
]
...
...
@@ -45,9 +48,10 @@ def jax_funcify_Subtensor(op, node, **kwargs):
@jax_funcify.register
(
IncSubtensor
)
@jax_funcify.register
(
AdvancedIncSubtensor
)
@jax_funcify.register
(
AdvancedIncSubtensor1
)
def
jax_funcify_IncSubtensor
(
op
,
node
,
**
kwargs
):
idx_list
=
getattr
(
op
,
"idx_list"
,
None
)
if
getattr
(
op
,
"set_instead_of_inc"
,
False
):
def
jax_fn
(
x
,
indices
,
y
):
...
...
@@ -58,7 +62,7 @@ def jax_funcify_IncSubtensor(op, node, **kwargs):
def
jax_fn
(
x
,
indices
,
y
):
return
x
.
at
[
indices
]
.
add
(
y
)
def
incsubtensor
(
x
,
y
,
*
ilist
,
jax_fn
=
jax_fn
,
idx_list
=
op
.
idx_list
):
def
incsubtensor
(
x
,
y
,
*
ilist
,
jax_fn
=
jax_fn
,
idx_list
=
idx_list
):
indices
=
indices_from_subtensor
(
ilist
,
idx_list
)
if
len
(
indices
)
==
1
:
indices
=
indices
[
0
]
...
...
@@ -69,3 +73,29 @@ def jax_funcify_IncSubtensor(op, node, **kwargs):
return
jax_fn
(
x
,
indices
,
y
)
return
incsubtensor
@jax_funcify.register
(
AdvancedIncSubtensor
)
def
jax_funcify_AdvancedIncSubtensor
(
op
,
node
,
**
kwargs
):
if
getattr
(
op
,
"set_instead_of_inc"
,
False
):
def
jax_fn
(
x
,
indices
,
y
):
return
x
.
at
[
indices
]
.
set
(
y
)
else
:
def
jax_fn
(
x
,
indices
,
y
):
return
x
.
at
[
indices
]
.
add
(
y
)
def
advancedincsubtensor
(
x
,
y
,
*
ilist
,
jax_fn
=
jax_fn
):
return
jax_fn
(
x
,
ilist
,
y
)
return
advancedincsubtensor
@jax_funcify.register
(
MakeSlice
)
def
jax_funcify_MakeSlice
(
op
,
**
kwargs
):
def
makeslice
(
*
x
):
return
slice
(
*
x
)
return
makeslice
pytensor/link/mlx/dispatch/subtensor.py
浏览文件 @
cc6bed1a
...
...
@@ -10,14 +10,15 @@ from pytensor.tensor.subtensor import (
Subtensor
,
indices_from_subtensor
,
)
from
pytensor.tensor.type_other
import
MakeSlice
@mlx_funcify.register
(
Subtensor
)
def
mlx_funcify_Subtensor
(
op
,
node
,
**
kwargs
):
idx_list
=
getattr
(
op
,
"idx_list"
,
None
)
def
subtensor
(
x
,
*
ilists
):
indices
=
indices_from_subtensor
(
[
int
(
element
)
for
element
in
ilists
],
op
.
idx_list
)
indices
=
indices_from_subtensor
([
int
(
element
)
for
element
in
ilists
],
idx_list
)
if
len
(
indices
)
==
1
:
indices
=
indices
[
0
]
...
...
@@ -29,8 +30,10 @@ def mlx_funcify_Subtensor(op, node, **kwargs):
@mlx_funcify.register
(
AdvancedSubtensor
)
@mlx_funcify.register
(
AdvancedSubtensor1
)
def
mlx_funcify_AdvancedSubtensor
(
op
,
node
,
**
kwargs
):
idx_list
=
getattr
(
op
,
"idx_list"
,
None
)
def
advanced_subtensor
(
x
,
*
ilists
):
indices
=
indices_from_subtensor
(
ilists
,
op
.
idx_list
)
indices
=
indices_from_subtensor
(
ilists
,
idx_list
)
if
len
(
indices
)
==
1
:
indices
=
indices
[
0
]
...
...
@@ -42,6 +45,8 @@ def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
@mlx_funcify.register
(
IncSubtensor
)
@mlx_funcify.register
(
AdvancedIncSubtensor1
)
def
mlx_funcify_IncSubtensor
(
op
,
node
,
**
kwargs
):
idx_list
=
getattr
(
op
,
"idx_list"
,
None
)
if
getattr
(
op
,
"set_instead_of_inc"
,
False
):
def
mlx_fn
(
x
,
indices
,
y
):
...
...
@@ -58,7 +63,7 @@ def mlx_funcify_IncSubtensor(op, node, **kwargs):
x
[
indices
]
+=
y
return
x
def
incsubtensor
(
x
,
y
,
*
ilist
,
mlx_fn
=
mlx_fn
,
idx_list
=
op
.
idx_list
):
def
incsubtensor
(
x
,
y
,
*
ilist
,
mlx_fn
=
mlx_fn
,
idx_list
=
idx_list
):
indices
=
indices_from_subtensor
(
ilist
,
idx_list
)
if
len
(
indices
)
==
1
:
indices
=
indices
[
0
]
...
...
@@ -90,3 +95,11 @@ def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs):
return
mlx_fn
(
x
,
ilist
,
y
)
return
advancedincsubtensor
@mlx_funcify.register
(
MakeSlice
)
def
mlx_funcify_MakeSlice
(
op
,
**
kwargs
):
def
makeslice
(
*
x
):
return
slice
(
*
x
)
return
makeslice
pytensor/link/numba/dispatch/subtensor.py
浏览文件 @
cc6bed1a
...
...
@@ -10,17 +10,18 @@ from numba import types
from
numba.core.pythonapi
import
box
import
pytensor.link.numba.dispatch.basic
as
numba_basic
from
pytensor.graph
import
Variabl
e
from
pytensor.graph
import
Typ
e
from
pytensor.link.numba.cache
import
(
compile_numba_function_src
,
)
from
pytensor.link.numba.dispatch.basic
import
(
generate_fallback_impl
,
register_funcify_and_cache_key
,
register_funcify_default_op_cache_key
,
)
from
pytensor.link.numba.dispatch.compile_ops
import
numba_deepcopy
from
pytensor.link.numba.dispatch.string_codegen
import
create_tuple_string
from
pytensor.tensor
import
TensorType
,
TensorVariable
from
pytensor.tensor
import
TensorType
from
pytensor.tensor.subtensor
import
(
AdvancedIncSubtensor
,
AdvancedIncSubtensor1
,
...
...
@@ -28,8 +29,8 @@ from pytensor.tensor.subtensor import (
AdvancedSubtensor1
,
IncSubtensor
,
Subtensor
,
indices_from_subtensor
,
)
from
pytensor.tensor.type_other
import
MakeSlice
,
NoneTypeT
def
slice_new
(
self
,
start
,
stop
,
step
):
...
...
@@ -117,6 +118,15 @@ def numba_deepcopy_slice(x):
return
deepcopy_slice
@register_funcify_default_op_cache_key
(
MakeSlice
)
def
numba_funcify_MakeSlice
(
op
,
**
kwargs
):
@numba_basic.numba_njit
def
makeslice
(
*
x
):
return
slice
(
*
x
)
return
makeslice
def
subtensor_op_cache_key
(
op
,
**
extra_fields
):
key_parts
=
[
type
(
op
),
tuple
(
extra_fields
.
items
())]
if
hasattr
(
op
,
"idx_list"
):
...
...
@@ -146,36 +156,35 @@ def subtensor_op_cache_key(op, **extra_fields):
def
numba_funcify_default_subtensor
(
op
,
node
,
**
kwargs
):
"""Create a Python function that assembles and uses an index on an array."""
def
convert_indices
(
indices_iterator
,
entry
):
if
isinstance
(
entry
,
int
):
name
,
var
=
next
(
indices_iterator
)
if
var
.
ndim
==
0
and
isinstance
(
var
.
type
,
TensorType
):
return
f
"{name}.item()"
return
name
def
convert_indices
(
indice_names
,
entry
):
if
indice_names
and
isinstance
(
entry
,
Type
):
return
next
(
indice_names
)
elif
isinstance
(
entry
,
slice
):
return
(
f
"slice({convert_indices(indice
s_iterator
, entry.start)}, "
f
"{convert_indices(indice
s_iterator
, entry.stop)}, "
f
"{convert_indices(indice
s_iterator
, entry.step)})"
f
"slice({convert_indices(indice
_names
, entry.start)}, "
f
"{convert_indices(indice
_names
, entry.stop)}, "
f
"{convert_indices(indice
_names
, entry.step)})"
)
elif
isinstance
(
entry
,
type
(
None
)):
return
"None"
else
:
raise
ValueError
(
f
"Unknown index type: {entry}"
)
raise
ValueError
()
set_or_inc
=
isinstance
(
op
,
IncSubtensor
|
AdvancedIncSubtensor1
|
AdvancedIncSubtensor
)
index_start_idx
=
1
+
int
(
set_or_inc
)
op_indices
=
list
(
node
.
inputs
[
index_start_idx
:])
idx_list
=
op
.
idx_list
idx_list
=
getattr
(
op
,
"idx_list"
,
None
)
idx_names
=
[
f
"idx_{i}"
for
i
in
range
(
len
(
op_indices
))]
input_names
=
[
"x"
,
"y"
,
*
idx_names
]
if
set_or_inc
else
[
"x"
,
*
idx_names
]
indices_iterator
=
iter
(
zip
(
idx_names
,
op_indices
))
indices_creation_src
=
tuple
(
convert_indices
(
indices_iterator
,
idx
)
for
idx
in
idx_list
idx_names_iterator
=
iter
(
idx_names
)
indices_creation_src
=
(
tuple
(
convert_indices
(
idx_names_iterator
,
idx
)
for
idx
in
idx_list
)
if
idx_list
else
tuple
(
input_names
[
index_start_idx
:])
)
if
len
(
indices_creation_src
)
==
1
:
...
...
@@ -231,24 +240,20 @@ def {function_name}({", ".join(input_names)}):
@register_funcify_and_cache_key
(
AdvancedIncSubtensor
)
def
numba_funcify_AdvancedSubtensor
(
op
,
node
,
**
kwargs
):
if
isinstance
(
op
,
AdvancedSubtensor
):
_x
,
*
index_variables
=
node
.
inputs
_x
,
_y
,
idxs
=
node
.
inputs
[
0
],
None
,
node
.
inputs
[
1
:]
else
:
_x
,
_y
,
*
index_variables
=
node
.
inputs
reconstructed_indices
=
indices_from_subtensor
(
index_variables
,
op
.
idx_list
)
adv_idxs
=
[]
for
i
,
idx
in
enumerate
(
reconstructed_indices
):
if
isinstance
(
idx
,
TensorVariable
):
# This is an advanced tensor index
adv_idxs
.
append
(
{
"axis"
:
i
,
"dtype"
:
idx
.
type
.
dtype
,
"bcast"
:
idx
.
type
.
broadcastable
,
"ndim"
:
idx
.
type
.
ndim
,
}
)
_x
,
_y
,
*
idxs
=
node
.
inputs
adv_idxs
=
[
{
"axis"
:
i
,
"dtype"
:
idx
.
type
.
dtype
,
"bcast"
:
idx
.
type
.
broadcastable
,
"ndim"
:
idx
.
type
.
ndim
,
}
for
i
,
idx
in
enumerate
(
idxs
)
if
isinstance
(
idx
.
type
,
TensorType
)
]
must_ignore_duplicates
=
(
isinstance
(
op
,
AdvancedIncSubtensor
)
...
...
@@ -260,10 +265,13 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
)
)
# Special implementation for integer indices that respects duplicates
if
(
not
must_ignore_duplicates
and
len
(
adv_idxs
)
>=
1
and
all
(
adv_idx
[
"dtype"
]
!=
"bool"
for
adv_idx
in
adv_idxs
)
# Implementation does not support newaxis
and
not
any
(
isinstance
(
idx
.
type
,
NoneTypeT
)
for
idx
in
idxs
)
):
return
vector_integer_advanced_indexing
(
op
,
node
,
**
kwargs
)
...
...
@@ -391,6 +399,7 @@ def vector_integer_advanced_indexing(
y_bcast = np.broadcast_to(y_adv_dims_front, (*adv_idx_shape, *basic_idx_shape))
# Ravel the advanced dims (if needed)
# Note that numba reshape only supports C-arrays, so we ravel before reshape
y_bcast = y_bcast
# Index over tuples of raveled advanced indices and update buffer
...
...
@@ -451,90 +460,45 @@ def vector_integer_advanced_indexing(
return x
"""
if
isinstance
(
op
,
AdvancedSubtensor1
|
AdvancedSubtensor
):
x
,
*
i
ndex_variable
s
=
node
.
inputs
x
,
*
i
dx
s
=
node
.
inputs
else
:
x
,
y
,
*
index_variables
=
node
.
inputs
x
,
y
,
*
idxs
=
node
.
inputs
[
out
]
=
node
.
outputs
reconstructed_indices
=
indices_from_subtensor
(
index_variables
,
op
.
idx_list
)
idx_args
=
[
f
"idx{i}"
for
i
in
range
(
len
(
index_variables
))]
var_to_arg
=
dict
(
zip
(
index_variables
,
idx_args
))
idxs
=
[]
def
get_idx_str
(
val
,
is_slice_component
=
False
):
if
val
is
None
:
return
"None"
if
isinstance
(
val
,
Variable
)
and
val
in
var_to_arg
:
arg
=
var_to_arg
[
val
]
if
val
.
ndim
==
0
and
is_slice_component
:
return
f
"{arg}.item()"
return
arg
raise
ValueError
(
f
"Unexpected index value: {val}"
)
for
idx
in
reconstructed_indices
:
if
isinstance
(
idx
,
slice
):
start
=
get_idx_str
(
idx
.
start
,
is_slice_component
=
True
)
stop
=
get_idx_str
(
idx
.
stop
,
is_slice_component
=
True
)
step
=
get_idx_str
(
idx
.
step
,
is_slice_component
=
True
)
idxs
.
append
(
f
"slice({start}, {stop}, {step})"
)
else
:
# It's a direct index variable
idxs
.
append
(
get_idx_str
(
idx
,
is_slice_component
=
False
))
adv_indices_pos
=
tuple
(
i
for
i
,
idx
in
enumerate
(
reconstructed_indices
)
if
not
isinstance
(
idx
,
slic
e
)
i
for
i
,
idx
in
enumerate
(
idxs
)
if
isinstance
(
idx
.
type
,
TensorTyp
e
)
)
assert
adv_indices_pos
# Otherwise it's just basic indexing
basic_indices_pos
=
tuple
(
i
for
i
,
idx
in
enumerate
(
reconstructed_indices
)
if
isinstance
(
idx
,
slic
e
)
i
for
i
,
idx
in
enumerate
(
idxs
)
if
not
isinstance
(
idx
.
type
,
TensorTyp
e
)
)
explicit_basic_indices_pos
=
(
*
basic_indices_pos
,
*
range
(
len
(
idxs
),
x
.
type
.
ndim
))
# Create index signature for generated function: "idx0, idx1, idx2, ..."
idx_signature
=
", "
.
join
(
idx_args
)
# Create index signature and split them among basic and advanced
idx_signature
=
", "
.
join
(
f
"idx{i}"
for
i
in
range
(
len
(
idxs
)))
adv_indices
=
[
f
"idx{i}"
for
i
in
adv_indices_pos
]
basic_indices
=
[
f
"idx{i}"
for
i
in
basic_indices_pos
]
# String representations of advanced and basic indices for codegen
adv_indices
=
[
idxs
[
i
]
for
i
in
adv_indices_pos
]
basic_indices
=
[
idxs
[
i
]
for
i
in
basic_indices_pos
]
# Define transpose axis so that advanced indexing dims are on the front
adv_axis_front_order
=
(
*
adv_indices_pos
,
*
explicit_basic_indices_pos
)
adv_axis_front_transpose_needed
=
adv_axis_front_order
!=
tuple
(
range
(
x
.
ndim
))
adv_idx_ndim
=
max
(
idxs
[
i
]
.
ndim
for
i
in
adv_indices_pos
)
to_tuple
=
create_tuple_string
# alias to make code more readable below
# Helper needed for basic indexing after moving advanced indices to the front
basic_indices_with_none_slices
=
", "
.
join
(
(
*
((
":"
,)
*
len
(
adv_indices
)),
*
basic_indices
)
)
#
Compute number of dimensions in advanced indices (after broadcasting)
if
len
(
adv_indices_pos
)
==
1
:
adv_idx
=
reconstructed_indices
[
adv_indices_pos
[
0
]]
adv_idx_ndim
=
adv_idx
.
ndim
# type: ignore[union-attr]
#
Position of the first advanced index dimension after indexing the array
if
(
np
.
diff
(
adv_indices_pos
)
>
1
)
.
any
()
:
# If not consecutive, it's always at the front
out_adv_axis_pos
=
0
else
:
# Multiple advanced indices - use max ndim (broadcast result ndim)
adv_idx_ndim
=
max
(
reconstructed_indices
[
i
]
.
ndim
for
i
in
adv_indices_pos
)
# type: ignore[union-attr]
# Determine output position of advanced indexed dimensions
# If advanced indices are consecutive, they go in the first advanced index position
# Otherwise they go at the beginning
if
adv_indices_pos
==
tuple
(
range
(
adv_indices_pos
[
0
],
adv_indices_pos
[
-
1
]
+
1
)):
# Consecutive - advanced dims will be at position of first advanced index
# Otherwise wherever the first advanced index is located
out_adv_axis_pos
=
adv_indices_pos
[
0
]
else
:
# Non-consecutive - advanced dims go at the front
out_adv_axis_pos
=
0
# Include trailing dimensions not covered by explicit indices
explicit_basic_indices_pos
=
(
*
basic_indices_pos
,
*
range
(
len
(
reconstructed_indices
),
x
.
type
.
ndim
),
)
# Compute transpose to move advanced indexed dims to the front
adv_axis_front_order
=
(
*
adv_indices_pos
,
*
explicit_basic_indices_pos
)
adv_axis_front_transpose_needed
=
adv_axis_front_order
!=
tuple
(
range
(
x
.
type
.
ndim
))
# Compute basic indices with "None" slices for dimensions that will be indexed by advanced indices
basic_indices_with_none_slices
=
", "
.
join
(
":"
for
_
in
range
(
len
(
adv_indices_pos
))
)
+
(
", "
+
", "
.
join
(
basic_indices
)
if
basic_indices
else
""
)
to_tuple
=
create_tuple_string
# alias to make code more readable below
if
isinstance
(
op
,
AdvancedSubtensor1
|
AdvancedSubtensor
):
# Define transpose axis on the output to restore original meaning
...
...
@@ -593,8 +557,7 @@ def vector_integer_advanced_indexing(
else
:
# Make implicit dims of y explicit to simplify code
# Numba doesn't support `np.expand_dims` with multiple axis, so we use indexing with newaxis
indexed_ndim
=
x
[
tuple
(
reconstructed_indices
)]
.
type
.
ndim
indexed_ndim
=
x
[
tuple
(
idxs
)]
.
type
.
ndim
y_expand_dims
=
[
":"
]
*
y
.
type
.
ndim
y_implicit_dims
=
range
(
indexed_ndim
-
y
.
type
.
ndim
)
for
axis
in
y_implicit_dims
:
...
...
pytensor/link/pytorch/dispatch/subtensor.py
浏览文件 @
cc6bed1a
...
...
@@ -9,6 +9,7 @@ from pytensor.tensor.subtensor import (
Subtensor
,
indices_from_subtensor
,
)
from
pytensor.tensor.type_other
import
MakeSlice
,
SliceType
def
check_negative_steps
(
indices
):
...
...
@@ -46,11 +47,23 @@ def pytorch_funcify_Subtensor(op, node, **kwargs):
return
subtensor
@pytorch_funcify.register
(
MakeSlice
)
def
pytorch_funcify_makeslice
(
op
,
**
kwargs
):
def
makeslice
(
start
,
stop
,
step
):
# Torch does not like numpy integers in indexing slices
return
slice
(
None
if
start
is
None
else
int
(
start
),
None
if
stop
is
None
else
int
(
stop
),
None
if
step
is
None
else
int
(
step
),
)
return
makeslice
@pytorch_funcify.register
(
AdvancedSubtensor1
)
@pytorch_funcify.register
(
AdvancedSubtensor
)
def
pytorch_funcify_AdvSubtensor
(
op
,
node
,
**
kwargs
):
def
advsubtensor
(
x
,
*
indices
):
indices
=
indices_from_subtensor
(
indices
,
op
.
idx_list
)
check_negative_steps
(
indices
)
return
x
[
indices
]
...
...
@@ -89,14 +102,12 @@ def pytorch_funcify_IncSubtensor(op, node, **kwargs):
@pytorch_funcify.register
(
AdvancedIncSubtensor
)
@pytorch_funcify.register
(
AdvancedIncSubtensor1
)
def
pytorch_funcify_AdvancedIncSubtensor
(
op
,
node
,
**
kwargs
):
idx_list
=
op
.
idx_list
inplace
=
op
.
inplace
ignore_duplicates
=
getattr
(
op
,
"ignore_duplicates"
,
False
)
if
op
.
set_instead_of_inc
:
def
adv_set_subtensor
(
x
,
y
,
*
flattened_indices
):
indices
=
indices_from_subtensor
(
flattened_indices
,
idx_list
)
def
adv_set_subtensor
(
x
,
y
,
*
indices
):
check_negative_steps
(
indices
)
if
isinstance
(
op
,
AdvancedIncSubtensor1
):
op
.
_check_runtime_broadcasting
(
node
,
x
,
y
,
indices
)
...
...
@@ -109,8 +120,7 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
elif
ignore_duplicates
:
def
adv_inc_subtensor_no_duplicates
(
x
,
y
,
*
flattened_indices
):
indices
=
indices_from_subtensor
(
flattened_indices
,
idx_list
)
def
adv_inc_subtensor_no_duplicates
(
x
,
y
,
*
indices
):
check_negative_steps
(
indices
)
if
isinstance
(
op
,
AdvancedIncSubtensor1
):
op
.
_check_runtime_broadcasting
(
node
,
x
,
y
,
indices
)
...
...
@@ -122,14 +132,13 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
return
adv_inc_subtensor_no_duplicates
else
:
if
any
(
isinstance
(
entry
,
slice
)
for
entry
in
idx_list
):
if
any
(
isinstance
(
idx
.
type
,
SliceType
)
for
idx
in
node
.
inputs
[
2
:]
):
raise
NotImplementedError
(
"IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"
)
def
adv_inc_subtensor
(
x
,
y
,
*
flattened_indices
):
indices
=
indices_from_subtensor
(
flattened_indices
,
idx_list
)
# Not needed because slices aren't supported in this path
def
adv_inc_subtensor
(
x
,
y
,
*
indices
):
# Not needed because slices aren't supported
# check_negative_steps(indices)
if
not
inplace
:
x
=
x
.
clone
()
...
...
pytensor/scan/rewriting.py
浏览文件 @
cc6bed1a
...
...
@@ -72,9 +72,9 @@ from pytensor.tensor.shape import shape
from
pytensor.tensor.subtensor
import
(
IncSubtensor
,
Subtensor
,
basic_subtensor
,
get_canonical_form_slice
,
get_idx_list
,
get_slice_elements
,
set_subtensor
,
)
from
pytensor.tensor.variable
import
TensorConstant
,
TensorVariable
...
...
@@ -1211,7 +1211,7 @@ def _is_default_scan_buffer(final_buffer: TensorVariable, taps: int) -> bool:
if
not
(
isinstance
(
op
,
IncSubtensor
)
and
op
.
set_instead_of_inc
and
op
.
idx_list
==
(
slice
(
None
,
0
),)
and
op
.
idx_list
==
[
slice
(
None
,
ps
.
int64
)]
):
return
False
...
...
@@ -1389,6 +1389,12 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
else
:
# 2.3.1 extract idx list of subtensor
this_slice
=
get_idx_list
(
cl
.
inputs
,
cl
.
op
.
idx_list
)
if
this_slice
is
None
:
# if unable to extract idx_list
# => outputs needs all its intermediate values
global_nsteps
=
None
slices
[
i
]
=
None
break
# 2.3.2 extract the begin/end of the first dimension
if
i
>=
op_info
.
n_mit_mot
:
...
...
@@ -1481,6 +1487,9 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
break
else
:
this_slice
=
get_idx_list
(
cl
.
inputs
,
cl
.
op
.
idx_list
)
if
this_slice
is
None
:
store_steps
[
i
]
=
0
break
if
isinstance
(
this_slice
[
0
],
slice
):
start
=
this_slice
[
0
]
.
start
...
...
@@ -1702,9 +1711,16 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
)
else
:
fslice
=
sanitize
(
cnf_slice
[
0
])
nw_slice
=
(
fslice
,
*
old_slices
[
1
:])
nw_pos
=
inv_compress_map
[
idx
]
new_o
=
basic_subtensor
(
new_outs
[
nw_pos
],
fslice
,
*
old_slices
[
1
:])
subtens
=
Subtensor
(
nw_slice
)
# slice inputs
sl_ins
=
get_slice_elements
(
nw_slice
,
lambda
entry
:
isinstance
(
entry
,
Variable
)
)
new_o
=
cast
(
TensorVariable
,
subtens
(
new_outs
[
nw_pos
],
*
sl_ins
))
if
new_o
.
ndim
>
0
:
new_o
=
new_o
[::
cnf_slice
[
1
]]
replaced_outs
.
append
(
idx
)
...
...
@@ -1755,7 +1771,11 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
)
nw_slice
=
(
sanitize
(
position
),
*
old_slices
[
1
:])
new_o
=
basic_subtensor
(
new_outs
[
nw_pos
],
*
nw_slice
)
subtens
=
Subtensor
(
nw_slice
)
sl_ins
=
get_slice_elements
(
nw_slice
,
lambda
entry
:
isinstance
(
entry
,
Variable
)
)
new_o
=
cast
(
TensorVariable
,
subtens
(
new_outs
[
nw_pos
],
*
sl_ins
))
if
new_o
.
ndim
>
0
:
new_o
=
new_o
[::
cnf_slice
[
1
]]
old_new
+=
[(
old
,
new_o
)]
...
...
pytensor/tensor/basic.py
浏览文件 @
cc6bed1a
...
...
@@ -29,7 +29,7 @@ from pytensor.graph.fg import FunctionGraph, Output
from
pytensor.graph.op
import
Op
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.graph.rewriting.db
import
EquilibriumDB
from
pytensor.graph.type
import
HasShape
from
pytensor.graph.type
import
HasShape
,
Type
from
pytensor.link.c.op
import
COp
from
pytensor.link.c.params_type
import
ParamsType
from
pytensor.printing
import
Printer
,
min_informative_str
,
pprint
,
set_precedence
...
...
@@ -433,7 +433,7 @@ def _get_underlying_scalar_constant_value(
var
.
ndim
==
1
for
var
in
v
.
owner
.
inputs
[
0
]
.
owner
.
inputs
[
1
:]
):
idx
=
v
.
owner
.
op
.
idx_list
[
0
]
if
isinstance
(
idx
,
int
):
if
isinstance
(
idx
,
Type
):
idx
=
_get_underlying_scalar_constant_value
(
v
.
owner
.
inputs
[
1
],
max_recur
=
max_recur
)
...
...
@@ -467,7 +467,7 @@ def _get_underlying_scalar_constant_value(
and
len
(
v
.
owner
.
op
.
idx_list
)
==
1
):
idx
=
v
.
owner
.
op
.
idx_list
[
0
]
if
isinstance
(
idx
,
int
):
if
isinstance
(
idx
,
Type
):
idx
=
_get_underlying_scalar_constant_value
(
v
.
owner
.
inputs
[
1
],
max_recur
=
max_recur
)
...
...
@@ -488,7 +488,7 @@ def _get_underlying_scalar_constant_value(
op
=
owner
.
op
idx_list
=
op
.
idx_list
idx
=
idx_list
[
0
]
if
isinstance
(
idx
,
int
):
if
isinstance
(
idx
,
Type
):
idx
=
_get_underlying_scalar_constant_value
(
owner
.
inputs
[
1
],
max_recur
=
max_recur
)
...
...
pytensor/tensor/random/rewriting/basic.py
浏览文件 @
cc6bed1a
...
...
@@ -23,7 +23,7 @@ from pytensor.tensor.subtensor import (
indices_from_subtensor
,
)
from
pytensor.tensor.type
import
integer_dtypes
from
pytensor.tensor.type_other
import
NoneTypeT
from
pytensor.tensor.type_other
import
NoneTypeT
,
SliceType
def
is_rv_used_in_graph
(
base_rv
,
node
,
fgraph
):
...
...
@@ -237,15 +237,20 @@ def local_subtensor_rv_lift(fgraph, node):
return
False
# Parse indices
if
isinstance
(
subtensor_op
,
Subtensor
|
AdvancedSubtensor
):
if
isinstance
(
subtensor_op
,
Subtensor
):
indices
=
indices_from_subtensor
(
node
.
inputs
[
1
:],
subtensor_op
.
idx_list
)
else
:
indices
=
node
.
inputs
[
1
:]
# The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
# TODO: This rewrite is aborting with dummy indexing dimensions which aren't a problem
# (e.g., x[[0],] is equivalent to x[0] - can only index one entry, won't lead to duplicates)
if
any
(
is_nd_advanced_idx
(
idx
,
integer_dtypes
)
for
idx
in
indices
):
return
False
# The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
# Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
# If we wanted to support that we could rewrite it as subtensor + dimshuffle
# and make use of the dimshuffle lift rewrite
# TODO: This rewrite is aborting with dummy indexing dimensions which aren't a problem
if
any
(
is_nd_advanced_idx
(
idx
,
integer_dtypes
)
or
isinstance
(
idx
.
type
,
NoneTypeT
)
for
idx
in
indices
):
return
False
# Check that indexing does not act on support dims
batch_ndims
=
rv_op
.
batch_ndim
(
rv_node
)
...
...
@@ -263,7 +268,10 @@ def local_subtensor_rv_lift(fgraph, node):
non_bool_indices
[
batch_ndims
:],
)
for
idx
in
supp_indices
:
if
idx
!=
slice
(
None
):
if
not
(
isinstance
(
idx
.
type
,
SliceType
)
and
all
(
isinstance
(
i
.
type
,
NoneTypeT
)
for
i
in
idx
.
owner
.
inputs
)
):
return
False
n_discarded_idxs
=
len
(
supp_indices
)
indices
=
indices
[:
-
n_discarded_idxs
]
...
...
@@ -323,7 +331,7 @@ def local_subtensor_rv_lift(fgraph, node):
# Broadcasted dim
if
curr_dim
in
bcast_param_dims
:
# Slice indexing, keep degenerate dim by none-slicing
if
isinstance
(
idx
,
slice
):
if
isinstance
(
idx
,
slice
)
or
isinstance
(
idx
.
type
,
SliceType
)
:
batch_indices
.
append
(
slice
(
None
))
# Integer indexing, drop degenerate dim by 0-indexing
else
:
...
...
pytensor/tensor/rewriting/shape.py
浏览文件 @
cc6bed1a
...
...
@@ -17,6 +17,7 @@ from pytensor.graph.rewriting.basic import (
)
from
pytensor.graph.traversal
import
ancestors
from
pytensor.graph.utils
import
InconsistencyError
,
get_variable_trace_string
from
pytensor.scalar
import
ScalarType
from
pytensor.tensor.basic
import
(
MakeVector
,
as_tensor_variable
,
...
...
@@ -841,16 +842,13 @@ def _is_shape_i_of_x(
if
isinstance
(
var
.
owner
.
op
,
Shape_i
):
return
(
var
.
owner
.
op
.
i
==
i
)
and
(
var
.
owner
.
inputs
[
0
]
==
x
)
# type: ignore
# Match Subtensor((
int,))(Shape(input), i) - single integer index into shape
# Match Subtensor((
ScalarType,))(Shape(input), i)
if
isinstance
(
var
.
owner
.
op
,
Subtensor
):
idx_entry
=
(
var
.
owner
.
op
.
idx_list
[
0
]
if
len
(
var
.
owner
.
op
.
idx_list
)
==
1
else
None
)
return
(
# Check we have integer indexing operation
# (and not slice or multiple indexing)
len
(
var
.
owner
.
op
.
idx_list
)
==
1
and
isinstance
(
idx_entry
,
int
)
and
isinstance
(
var
.
owner
.
op
.
idx_list
[
0
],
ScalarType
)
# Check we are indexing on the shape of x
and
var
.
owner
.
inputs
[
0
]
.
owner
is
not
None
and
isinstance
(
var
.
owner
.
inputs
[
0
]
.
owner
.
op
,
Shape
)
...
...
pytensor/tensor/rewriting/subtensor.py
浏览文件 @
cc6bed1a
import
itertools
import
sys
import
warnings
import
numpy
as
np
...
...
@@ -16,7 +15,7 @@ from pytensor.graph.rewriting.basic import (
node_rewriter
,
)
from
pytensor.raise_op
import
Assert
from
pytensor.scalar
import
Add
,
ScalarConstant
from
pytensor.scalar
import
Add
,
ScalarConstant
,
ScalarType
from
pytensor.scalar
import
constant
as
scalar_constant
from
pytensor.tensor.basic
import
(
Alloc
,
...
...
@@ -32,7 +31,6 @@ from pytensor.tensor.basic import (
full
,
get_scalar_constant_value
,
get_underlying_scalar_constant_value
,
moveaxis
,
register_infer_shape
,
switch
,
)
...
...
@@ -74,11 +72,10 @@ from pytensor.tensor.subtensor import (
AdvancedSubtensor1
,
IncSubtensor
,
Subtensor
,
_non_consecutive_adv_indexing
,
advanced_inc_subtensor1
,
advanced_subtensor
,
advanced_subtensor1
,
as_index_constant
,
basic_subtensor
,
get_canonical_form_slice
,
get_constant_idx
,
get_idx_list
,
...
...
@@ -87,6 +84,7 @@ from pytensor.tensor.subtensor import (
indices_from_subtensor
,
)
from
pytensor.tensor.type
import
TensorType
from
pytensor.tensor.type_other
import
NoneTypeT
,
SliceType
from
pytensor.tensor.variable
import
TensorConstant
,
TensorVariable
...
...
@@ -156,10 +154,8 @@ def transform_take(a, indices, axis):
if
len
(
shape_parts
)
>
1
:
shape
=
pytensor
.
tensor
.
concatenate
(
shape_parts
)
elif
len
(
shape_parts
)
==
1
:
shape
=
shape_parts
[
0
]
else
:
shape
=
()
shape
=
shape_parts
[
0
]
ndim
=
a
.
ndim
+
indices
.
ndim
-
1
...
...
@@ -167,11 +163,23 @@ def transform_take(a, indices, axis):
def
is_full_slice
(
x
):
warnings
.
warn
(
"The function is deprecated, use x==slice(None) instead."
,
DeprecationWarning
,
)
return
x
==
slice
(
None
)
"""Determine if `x` is a ``slice(None)`` or a symbolic equivalent."""
if
isinstance
(
x
,
slice
):
return
x
==
slice
(
None
)
if
isinstance
(
x
,
Variable
)
and
isinstance
(
x
.
type
,
SliceType
):
if
x
.
owner
is
None
:
if
isinstance
(
x
,
Constant
):
return
x
.
data
==
slice
(
None
)
else
:
# Root slice variable
return
False
# Symbolic MakeSlice
# Ignores start = 0, step = 1 cases
return
all
(
isinstance
(
i
.
type
,
NoneTypeT
)
for
i
in
x
.
owner
.
inputs
)
return
False
def
get_advsubtensor_axis
(
indices
):
...
...
@@ -186,13 +194,13 @@ def get_advsubtensor_axis(indices):
found_idx
=
False
axis
=
0
for
idx
in
indices
:
if
not
found_idx
and
i
dx
==
slice
(
None
):
if
not
found_idx
and
i
s_full_slice
(
idx
):
# Preceding full slices
axis
+=
1
elif
found_idx
and
not
i
dx
==
slice
(
None
):
elif
found_idx
and
not
i
s_full_slice
(
idx
):
# We don't handle multiple indices
return
elif
found_idx
and
i
dx
==
slice
(
None
):
elif
found_idx
and
i
s_full_slice
(
idx
):
# Trailing full slices
continue
else
:
...
...
@@ -219,8 +227,9 @@ def local_replace_AdvancedSubtensor(fgraph, node):
if
not
isinstance
(
node
.
op
,
AdvancedSubtensor
):
return
indexed_var
,
*
index_variables
=
node
.
inputs
indices
=
indices_from_subtensor
(
index_variables
,
node
.
op
.
idx_list
)
indexed_var
=
node
.
inputs
[
0
]
indices
=
node
.
inputs
[
1
:]
axis
=
get_advsubtensor_axis
(
indices
)
if
axis
is
None
or
indices
[
axis
]
.
dtype
==
"bool"
:
...
...
@@ -244,8 +253,9 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
# `AdvancedIncSubtensor1` does not ignore duplicate index values
return
res
,
val
,
*
index_variables
=
node
.
inputs
indices
=
indices_from_subtensor
(
index_variables
,
node
.
op
.
idx_list
)
res
=
node
.
inputs
[
0
]
val
=
node
.
inputs
[
1
]
indices
=
node
.
inputs
[
2
:]
axis
=
get_advsubtensor_axis
(
indices
)
...
...
@@ -418,7 +428,11 @@ def local_subtensor_merge(fgraph, node):
merged_slices
+=
slices1
[
pos_1
:]
merged_slices
=
tuple
(
as_index_constant
(
s
)
for
s
in
merged_slices
)
out
=
basic_subtensor
(
x
,
*
merged_slices
)
subtens
=
Subtensor
(
merged_slices
)
sl_ins
=
get_slice_elements
(
merged_slices
,
lambda
x
:
isinstance
(
x
,
Variable
))
# Do not call make_node for test_value
out
=
subtens
(
x
,
*
sl_ins
)
# Copy over previous output stacktrace
# and stacktrace from previous slicing operation.
...
...
@@ -449,8 +463,9 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
remove_dim
=
[]
node_inputs_idx
=
1
for
dim
,
elem
in
enumerate
(
idx
):
if
isinstance
(
elem
,
int
):
# The idx is a integer position.
if
isinstance
(
elem
,
ScalarType
):
# The idx is a ScalarType, ie a Type. This means the actual index
# is contained in node.inputs[1]
dim_index
=
node
.
inputs
[
node_inputs_idx
]
if
isinstance
(
dim_index
,
ScalarConstant
):
dim_index
=
dim_index
.
value
...
...
@@ -462,6 +477,9 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
elif
isinstance
(
elem
,
slice
):
if
elem
!=
slice
(
None
):
return
elif
isinstance
(
elem
,
int
|
np
.
integer
):
if
elem
in
(
0
,
-
1
)
and
node
.
inputs
[
0
]
.
broadcastable
[
dim
]:
remove_dim
.
append
(
dim
)
else
:
raise
TypeError
(
"case not expected"
)
...
...
@@ -488,29 +506,26 @@ def local_subtensor_inc_subtensor(fgraph, node):
if
not
x
.
owner
.
op
.
set_instead_of_inc
:
return
x_inc
,
y_inc
,
*
inc_index_variables
=
x
.
owner
.
inputs
_sub_x
,
*
sub_index_variables
=
node
.
inputs
if
(
inc_index_variables
==
sub_index_variables
and
x
.
owner
.
op
.
idx_list
==
node
.
op
.
idx_list
):
if
x
.
owner
.
inputs
[
2
:]
==
node
.
inputs
[
1
:]
and
tuple
(
x
.
owner
.
op
.
idx_list
)
==
tuple
(
node
.
op
.
idx_list
):
out
=
node
.
outputs
[
0
]
y
=
x
.
owner
.
inputs
[
1
]
# If the dtypes differ, cast y into x.dtype
if
x
.
dtype
!=
y
_inc
.
dtype
:
y
_inc
=
y_inc
.
astype
(
x
.
dtype
)
if
x
.
dtype
!=
y
.
dtype
:
y
=
y
.
astype
(
x
.
dtype
)
if
(
out
.
type
.
dtype
==
y
_inc
.
type
.
dtype
and
out
.
type
.
broadcastable
==
y
_inc
.
type
.
broadcastable
out
.
type
.
dtype
==
y
.
type
.
dtype
and
out
.
type
.
broadcastable
==
y
.
type
.
broadcastable
):
# if x[idx] and y have the same type, directly return y
return
[
y
_inc
]
return
[
y
]
else
:
# The difference is related to broadcasting pattern
assert
out
.
broadcastable
!=
y
_inc
.
broadcastable
assert
out
.
broadcastable
!=
y
.
broadcastable
# We have to alloc y to the shape of x[idx]
x_subtensor
=
node
.
op
(
x
_inc
,
*
inc_index_variables
)
return
[
alloc
(
y
_inc
,
*
x_subtensor
.
shape
)]
x_subtensor
=
node
.
op
(
x
.
owner
.
inputs
[
0
],
*
x
.
owner
.
inputs
[
2
:]
)
return
[
alloc
(
y
,
*
x_subtensor
.
shape
)]
else
:
return
...
...
@@ -814,9 +829,9 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2):
raise
ValueError
(
"slice1 should be of type `slice`"
)
# Simple case where one of the slices is useless
if
slice1
==
slice
(
None
):
if
is_full_slice
(
slice1
):
return
slice2
elif
slice2
==
slice
(
None
):
elif
is_full_slice
(
slice2
):
return
slice1
sl1
,
reverse1
=
get_canonical_form_slice
(
slice1
,
len1
)
...
...
@@ -1075,7 +1090,6 @@ compile.optdb.register(
def
local_inplace_AdvancedIncSubtensor
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
AdvancedIncSubtensor
)
and
not
node
.
op
.
inplace
:
new_op
=
type
(
node
.
op
)(
node
.
op
.
idx_list
,
inplace
=
True
,
set_instead_of_inc
=
node
.
op
.
set_instead_of_inc
,
ignore_duplicates
=
node
.
op
.
ignore_duplicates
,
...
...
@@ -1262,7 +1276,9 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
"""
if
isinstance
(
node
.
op
,
IncSubtensor
|
AdvancedIncSubtensor
|
AdvancedIncSubtensor1
):
x
,
y
,
*
index_variables
=
node
.
inputs
x
=
node
.
inputs
[
0
]
y
=
node
.
inputs
[
1
]
i
=
node
.
inputs
[
2
:]
if
y
.
owner
is
not
None
and
isinstance
(
y
.
owner
.
op
,
Alloc
):
# `z` is the input of the Alloc op, i.e. at.alloc(z, <shape>)
...
...
@@ -1281,11 +1297,11 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
# Get the subtensor of `x` indexed by `i` in order to compare
# shapes later.
if
isinstance
(
node
.
op
,
IncSubtensor
):
xi
=
Subtensor
(
node
.
op
.
idx_list
)(
x
,
*
i
ndex_variables
)
xi
=
Subtensor
(
node
.
op
.
idx_list
)(
x
,
*
i
)
elif
isinstance
(
node
.
op
,
AdvancedIncSubtensor
):
xi
=
AdvancedSubtensor
(
node
.
op
.
idx_list
)(
x
,
*
index_variables
)
xi
=
advanced_subtensor
(
x
,
*
i
)
elif
isinstance
(
node
.
op
,
AdvancedIncSubtensor1
):
xi
=
advanced_subtensor1
(
x
,
*
i
ndex_variables
)
xi
=
advanced_subtensor1
(
x
,
*
i
)
else
:
raise
Exception
(
"Should never happen!"
)
...
...
@@ -1345,7 +1361,7 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
msg
=
"`x[i]` and `y` do not have the same shape."
z
=
Assert
(
msg
)(
z
,
*
cond
)
r
=
node
.
op
(
x
,
z
,
*
i
ndex_variables
)
r
=
node
.
op
(
x
,
z
,
*
i
)
# Copy over stacktrace from previous output, since
# we don't expect problems when removing the intermediate
# alloc operation and so we still want to point at the line
...
...
@@ -1477,7 +1493,8 @@ def local_uint_constant_indices(fgraph, node):
x
,
*
indices
=
node
.
inputs
y
=
None
new_indices
=
list
(
indices_from_subtensor
(
indices
,
node
.
op
.
idx_list
))
idx_list
=
getattr
(
node
.
op
,
"idx_list"
,
None
)
new_indices
=
list
(
indices_from_subtensor
(
indices
,
idx_list
))
has_new_index
=
False
for
i
,
index
in
enumerate
(
new_indices
):
...
...
@@ -1527,7 +1544,14 @@ def local_uint_constant_indices(fgraph, node):
if
not
has_new_index
:
return
False
new_indices
=
get_slice_elements
(
new_indices
)
if
isinstance
(
op
,
Subtensor
|
IncSubtensor
):
# Basic index Ops contain information about the dtype of the indices, so wee have to recreate them
props
=
op
.
_props_dict
()
props
[
"idx_list"
]
=
new_indices
op
=
type
(
op
)(
**
props
)
# Basic index Ops don't expect slices, but the respective start/step/stop
new_indices
=
get_slice_elements
(
new_indices
)
new_args
=
(
x
,
*
new_indices
)
if
y
is
None
else
(
x
,
y
,
*
new_indices
)
new_out
=
op
(
*
new_args
)
copy_stack_trace
(
node
.
outputs
[
0
],
new_out
)
...
...
@@ -1587,18 +1611,27 @@ def local_blockwise_inc_subtensor(fgraph, node):
core_op
=
node
.
op
.
core_op
x
,
y
,
*
idxs
=
node
.
inputs
[
out
]
=
node
.
outputs
advanced
=
isinstance
(
core_op
,
AdvancedIncSubtensor
)
if
advanced
and
any
(
idx
.
type
.
dtype
==
"bool"
for
idx
in
idxs
):
# Get out if we have boolean indices as they cross dimension boundaries
# / can't be safely broadcasted depending on their runtime content
return
None
if
isinstance
(
core_op
,
AdvancedIncSubtensor
):
if
any
(
(
# Blockwise requires all inputs to be tensors so it is not possible
# to wrap an AdvancedIncSubtensor with slice / newaxis inputs, but we check again just in case
# If this is ever supported we need to pay attention to special behavior of numpy when advanced indices
# are separated by basic indices
isinstance
(
idx
,
SliceType
|
NoneTypeT
)
# Also get out if we have boolean indices as they cross dimension boundaries
# / can't be safely broadcasted depending on their runtime content
or
(
idx
.
type
.
dtype
==
"bool"
)
)
for
idx
in
idxs
):
return
None
batch_ndim
=
node
.
op
.
batch_ndim
(
node
)
idxs_core_ndim
=
[
len
(
inp_sig
)
for
inp_sig
in
node
.
op
.
inputs_sig
[
2
:]]
max_idx_core_ndim
=
max
(
idxs_core_ndim
,
default
=
0
)
# Broadcast buffer to batch_shape
#
Step 1.
Broadcast buffer to batch_shape
if
x
.
type
.
broadcastable
!=
out
.
type
.
broadcastable
:
batch_shape
=
[
1
]
*
batch_ndim
for
inp
in
node
.
inputs
:
...
...
@@ -1615,61 +1648,58 @@ def local_blockwise_inc_subtensor(fgraph, node):
x
=
broadcast_to
(
x
,
(
*
batch_shape
,
*
x
.
shape
[
batch_ndim
:]))
assert
x
.
type
.
broadcastable
==
out
.
type
.
broadcastable
# Massage indices so they respect blockwise semantics while using regular indexing
core_idxs
=
[]
for
idx_entry
in
core_op
.
idx_list
:
if
isinstance
(
idx_entry
,
slice
):
# Squeeze away dummy dimensions so we can convert to slice
new_entries
=
[
None
,
None
,
None
]
for
i
,
slice_idx_entry
in
enumerate
(
(
idx_entry
.
start
,
idx_entry
.
stop
,
idx_entry
.
step
)
):
if
slice_idx_entry
is
None
:
continue
else
:
new_entries
[
i
]
=
new_entry
=
idxs
[
slice_idx_entry
]
.
squeeze
()
if
new_entry
.
ndim
>
0
:
# If the slice entry has dimensions after the squeeze we can't convert it to a slice
# We could try to convert to equivalent integer indices, but nothing guarantees
# that the slice is "square".
return
None
squeezed_index
=
slice
(
*
new_entries
)
else
:
if
advanced
:
# For AdvancedIncSubtensor we have tensor integer indices,
# We need to expand batch indexes on the right, so they don't interact with core index dimensions
# We still squeeze on the left in case that allows us to use simpler indices
squeezed_index
=
_squeeze_left
(
shape_padright
(
idxs
[
idx_entry
],
max_idx_core_ndim
-
idxs_core_ndim
[
idx_entry
]
),
stop_at_dim
=
batch_ndim
,
)
# Step 2. Massage indices so they respect blockwise semantics
if
isinstance
(
core_op
,
IncSubtensor
):
# For basic IncSubtensor there are two cases:
# 1. Slice entries -> We need to squeeze away dummy dimensions so we can convert back to slice
# 2. Integers -> Can be used as is, but we try to squeeze away dummy batch dimensions
# in case we can end up with a basic IncSubtensor again
core_idxs
=
[]
counter
=
0
for
idx
in
core_op
.
idx_list
:
if
isinstance
(
idx
,
slice
):
# Squeeze away dummy dimensions so we can convert to slice
new_entries
=
[
None
,
None
,
None
]
for
i
,
entry
in
enumerate
((
idx
.
start
,
idx
.
stop
,
idx
.
step
)):
if
entry
is
None
:
continue
else
:
new_entries
[
i
]
=
new_entry
=
idxs
[
counter
]
.
squeeze
()
counter
+=
1
if
new_entry
.
ndim
>
0
:
# If the slice entry has dimensions after the squeeze we can't convert it to a slice
# We could try to convert to equivalent integer indices, but nothing guarantees
# that the slice is "square".
return
None
core_idxs
.
append
(
slice
(
*
new_entries
))
else
:
# For basic IncSubtensor integers indices can be used as is, but we try to squeeze away dummy
# batch dimensions in case we can end up with a basic IncSubtensor again
squeezed_index
=
_squeeze_left
(
idxs
[
idx_entry
])
core_idxs
.
append
(
squeezed_index
)
core_idxs
.
append
(
_squeeze_left
(
idxs
[
counter
]))
counter
+=
1
else
:
# For AdvancedIncSubtensor we have tensor integer indices,
# We need to expand batch indexes on the right, so they don't interact with core index dimensions
# We still squeeze on the left in case that allows us to use simpler indices
core_idxs
=
[
_squeeze_left
(
shape_padright
(
idx
,
max_idx_core_ndim
-
idx_core_ndim
),
stop_at_dim
=
batch_ndim
,
)
for
idx
,
idx_core_ndim
in
zip
(
idxs
,
idxs_core_ndim
)
]
#
Create new indices for the batch dimensions
has_batched_indices
=
not
all
(
#
Step 3. Create new indices for the new batch dimension of x
if
not
all
(
all
(
idx
.
type
.
broadcastable
[:
batch_ndim
])
for
idx
in
idxs
if
not
isinstance
(
idx
,
slice
)
)
if
has_batched_indices
:
# If indices have batch dimensions, we need to align them element-wise with the respective batch dimensions of x
# We achieve this by creating `arange` indices and adding expand_dims for correct broadcasting.
# Example:
# x = pt.zeros(5); idx = [0, 1, 0]; out = x[idx].set(y)
# batch_x = pt.zeros((2, 5)); batch_idx = [[0, 1, 0], [1, 1, 2]]
# batch_out = batch_x[[0, 1][:, None], batch_idx].set(y)
# If instead batch_x = pt.zeros((2, 2, 5))
# batch_out = batch_x[[0, 1][:, None, None], [0, 1][None, 1, None], batch_idx]
# Note: For simplicity we use arange for all batch dimensions of x,
# even if not all may have corresponding batch index dimensions
):
# If indices have batch dimensions in the indices, they will interact with the new dimensions of x
# We build vectorized indexing with new arange indices that do not interact with core indices or each other
# (i.e., they broadcast)
# Note: due to how numpy handles non-consecutive advanced indexing (transposing it to the front),
# we don't want to create a mix of slice(None), and arange() indices for the new batch dimension,
# even if not all batch dimensions have corresponding batch indices.
batch_slices
=
[
shape_padright
(
arange
(
x_batch_shape
,
dtype
=
"int64"
),
n
)
for
(
x_batch_shape
,
n
)
in
zip
(
...
...
@@ -1685,49 +1715,29 @@ def local_blockwise_inc_subtensor(fgraph, node):
new_idxs
=
(
*
batch_slices
,
*
core_idxs
)
x_view
=
x
[
new_idxs
]
# Introduce any implicit expand_dims on core dimension of y
#
Step 4.
Introduce any implicit expand_dims on core dimension of y
missing_y_core_ndim
=
x_view
.
type
.
ndim
-
y
.
type
.
ndim
implicit_axes
=
tuple
(
range
(
batch_ndim
,
batch_ndim
+
missing_y_core_ndim
))
y
=
expand_dims
(
y
,
implicit_axes
)
# Transpose y if needed
if
has_batched_indices
:
# By introducing arange slices we may caused a transposition of the advanced group to the front
# If this was not already happening in the core graph, we'll need to transpose y to align it correctly
if
max_idx_core_ndim
and
not
(
advanced
and
_non_consecutive_adv_indexing
(
core_idxs
)
):
integer_pos
=
[
i
for
i
,
entry
in
enumerate
(
core_op
.
idx_list
)
if
isinstance
(
entry
,
int
)
]
slice_pos
=
[
i
for
i
,
entry
in
enumerate
(
core_op
.
idx_list
)
if
isinstance
(
entry
,
slice
)
]
if
slice_pos
and
integer_pos
and
(
slice_pos
[
0
]
<
integer_pos
[
-
1
]):
y
=
moveaxis
(
y
,
[
batch_ndim
+
integer_pos
[
0
]
+
i
for
i
in
range
(
max_idx_core_ndim
)],
[
batch_ndim
+
i
for
i
in
range
(
max_idx_core_ndim
)],
)
else
:
# Conversely if we tried to use `slice(None)` for the batch dimensions but there was already transposition
# in the core case, we'll need to move the batch slices of y to after the advanced indexing group
if
advanced
and
_non_consecutive_adv_indexing
(
core_idxs
):
y
=
moveaxis
(
y
,
[
i
for
i
in
range
(
batch_ndim
)],
# noqa: C416
[
max_idx_core_ndim
+
i
for
i
in
range
(
batch_ndim
)],
)
# Remove useless left-batch dimensions of y (if any)
y
=
_squeeze_left
(
y
,
stop_at_dim
=
batch_ndim
)
if
core_op
.
set_instead_of_inc
:
new_out
=
x
[
new_idxs
]
.
set
(
y
)
y
=
_squeeze_left
(
expand_dims
(
y
,
implicit_axes
),
stop_at_dim
=
batch_ndim
)
if
isinstance
(
core_op
,
IncSubtensor
):
# Check if we can still use a basic IncSubtensor
if
isinstance
(
x_view
.
owner
.
op
,
Subtensor
):
new_props
=
core_op
.
_props_dict
()
new_props
[
"idx_list"
]
=
x_view
.
owner
.
op
.
idx_list
new_core_op
=
type
(
core_op
)(
**
new_props
)
symbolic_idxs
=
x_view
.
owner
.
inputs
[
1
:]
new_out
=
new_core_op
(
x
,
y
,
*
symbolic_idxs
)
else
:
# We need to use AdvancedSet/IncSubtensor
if
core_op
.
set_instead_of_inc
:
new_out
=
x
[
new_idxs
]
.
set
(
y
)
else
:
new_out
=
x
[
new_idxs
]
.
inc
(
y
)
else
:
new_out
=
x
[
new_idxs
]
.
inc
(
y
)
# AdvancedIncSubtensor takes symbolic indices/slices directly, no need to create a new op
symbolic_idxs
=
x_view
.
owner
.
inputs
[
1
:]
new_out
=
core_op
(
x
,
y
,
*
symbolic_idxs
)
copy_stack_trace
(
out
,
new_out
)
return
[
new_out
]
...
...
@@ -1744,12 +1754,10 @@ def bool_idx_to_nonzero(fgraph, node):
else
:
x
,
y
,
*
idxs
=
node
.
inputs
idxs
=
indices_from_subtensor
(
idxs
,
node
.
op
.
idx_list
)
bool_pos
=
{
i
for
i
,
idx
in
enumerate
(
idxs
)
if
isinstance
(
idx
,
TensorVariable
)
and
idx
.
dtype
==
"bool"
if
(
isinstance
(
idx
.
type
,
TensorType
)
and
idx
.
dtype
==
"bool"
)
}
if
not
bool_pos
:
...
...
@@ -1763,13 +1771,9 @@ def bool_idx_to_nonzero(fgraph, node):
new_idxs
.
append
(
idx
)
if
isinstance
(
node
.
op
,
AdvancedSubtensor
):
new_out
=
x
[
tuple
(
new_idxs
)]
new_out
=
node
.
op
(
x
,
*
new_idxs
)
else
:
new_out
=
(
x
[
tuple
(
new_idxs
)]
.
set
(
y
)
if
node
.
op
.
set_instead_of_inc
else
x
[
tuple
(
new_idxs
)]
.
inc
(
y
)
)
new_out
=
node
.
op
(
x
,
y
,
*
new_idxs
)
return
[
copy_stack_trace
(
node
.
outputs
[
0
],
new_out
)]
...
...
@@ -1818,8 +1822,7 @@ def extract_diag_of_diagonal_set_subtensor(fgraph, node):
):
return
None
x
,
y
,
*
idx_variables
=
diag_x
.
owner
.
inputs
idxs
=
indices_from_subtensor
(
idx_variables
,
diag_x
.
owner
.
op
.
idx_list
)
x
,
y
,
*
idxs
=
diag_x
.
owner
.
inputs
if
not
(
x
.
type
.
ndim
>=
2
...
...
@@ -1835,7 +1838,7 @@ def extract_diag_of_diagonal_set_subtensor(fgraph, node):
# Check all non-axis indices are full slices
axis
=
{
op
.
axis1
,
op
.
axis2
}
if
not
all
(
i
dx
==
slice
(
None
)
for
i
,
idx
in
enumerate
(
idxs
)
if
i
not
in
axis
):
if
not
all
(
i
s_full_slice
(
idx
)
for
i
,
idx
in
enumerate
(
idxs
)
if
i
not
in
axis
):
return
None
# Check axis indices are arange we would expect from setting on the diagonal
...
...
pytensor/tensor/rewriting/subtensor_lift.py
浏览文件 @
cc6bed1a
...
...
@@ -8,6 +8,7 @@ from pytensor import Variable
from
pytensor.compile
import
optdb
from
pytensor.graph
import
Constant
,
FunctionGraph
,
node_rewriter
,
vectorize_graph
from
pytensor.graph.rewriting.basic
import
NodeRewriter
,
copy_stack_trace
from
pytensor.scalar
import
basic
as
ps
from
pytensor.tensor.basic
import
(
Alloc
,
Join
,
...
...
@@ -30,7 +31,7 @@ from pytensor.tensor.rewriting.basic import (
register_stabilize
,
)
from
pytensor.tensor.rewriting.elemwise
import
local_dimshuffle_lift
from
pytensor.tensor.rewriting.subtensor
import
register_useless
from
pytensor.tensor.rewriting.subtensor
import
is_full_slice
,
register_useless
from
pytensor.tensor.shape
import
(
Shape
,
SpecifyShape
,
...
...
@@ -49,6 +50,7 @@ from pytensor.tensor.subtensor import (
indices_from_subtensor
,
)
from
pytensor.tensor.type
import
TensorType
from
pytensor.tensor.type_other
import
NoneTypeT
,
SliceType
from
pytensor.tensor.variable
import
TensorVariable
...
...
@@ -69,7 +71,7 @@ def _axis_is_indexed_by_basic_index(
)
->
bool
:
if
isinstance
(
axis
,
int
):
axis
=
(
axis
,)
return
any
(
ax
<
len
(
idxs
)
and
not
i
dxs
[
ax
]
==
slice
(
None
)
for
ax
in
axis
)
return
any
(
ax
<
len
(
idxs
)
and
not
i
s_full_slice
(
idxs
[
ax
]
)
for
ax
in
axis
)
def
_lift_subtensor_non_axis
(
...
...
@@ -81,7 +83,7 @@ def _lift_subtensor_non_axis(
old_subtensor_variable
:
TensorVariable
,
)
->
None
|
list
[
TensorVariable
]:
# Apply generic subtensor lift rewrite along "non-axis" dimensions
real_indices
=
[
idx
for
idx
in
idx_tuple
if
not
i
dx
==
slice
(
None
)]
real_indices
=
[
idx
for
idx
in
idx_tuple
if
not
i
s_full_slice
(
idx
)]
if
len
(
real_indices
)
>
1
and
variable
.
type
.
ndim
>
1
:
# Split the subtensor
idx_to_keep
=
idx_tuple
[
axis
]
...
...
@@ -204,7 +206,7 @@ def local_subtensor_of_batch_dims(fgraph, node):
if
len
(
idx_tuple
)
>
batch_ndim
:
# Indexing on core dimensions of Blockwise. We split the indices and lift the batch ones only
batch_indices
,
core_indices
=
idx_tuple
[:
batch_ndim
],
idx_tuple
[
batch_ndim
:]
if
all
(
i
dx
==
slice
(
None
)
for
idx
in
batch_indices
):
if
all
(
i
s_full_slice
(
idx
)
for
idx
in
batch_indices
):
# No batch indices, nothing to do
return
None
elem_with_batch_indices
=
elem
[
batch_indices
]
...
...
@@ -238,7 +240,7 @@ def local_subtensor_of_batch_dims(fgraph, node):
strict
=
False
,
)
):
if
dim_idx
==
slice
(
None
):
if
is_full_slice
(
dim_idx
):
# Full slice can be safely applied to all inputs
continue
...
...
@@ -427,7 +429,7 @@ def local_subtensor_of_expand_dims(fgraph, node):
if
i
in
expanded_axes
:
if
isinstance
(
idx_item
,
slice
):
# Slice could be keeping or dropping this dimension
if
i
dx_item
==
slice
(
None
):
if
i
s_full_slice
(
idx_item
):
# A None slice, always keeps the dimension.
# We skip the index, and later introduce the needed expand_dim
continue
...
...
@@ -646,7 +648,10 @@ def local_subtensor_SpecifyShape_lift(fgraph, node):
indices
=
get_idx_list
(
node
.
inputs
,
node
.
op
.
idx_list
)
if
any
(
isinstance
(
index
,
slice
)
for
index
in
indices
):
if
any
(
isinstance
(
index
,
slice
)
or
isinstance
(
getattr
(
index
,
"type"
,
None
),
SliceType
)
for
index
in
indices
):
return
False
new_obj_arg
=
obj_arg
[
indices
]
...
...
@@ -697,12 +702,15 @@ def local_subtensor_make_vector(fgraph, node):
(
idx
,)
=
idxs
if
isinstance
(
idx
,
int
):
idx
=
node
.
inputs
[
1
]
if
isinstance
(
idx
,
ps
.
ScalarType
|
TensorType
):
old_idx
,
idx
=
idx
,
node
.
inputs
[
1
]
assert
idx
.
type
.
is_super
(
old_idx
)
elif
isinstance
(
node
.
op
,
AdvancedSubtensor1
):
idx
=
node
.
inputs
[
1
]
if
isinstance
(
idx
,
Variable
):
if
isinstance
(
idx
,
int
|
np
.
integer
):
return
[
x
.
owner
.
inputs
[
idx
]]
elif
isinstance
(
idx
,
Variable
):
if
idx
.
ndim
==
0
:
try
:
v
=
get_underlying_scalar_constant_value
(
...
...
@@ -825,6 +833,8 @@ def local_subtensor_shape_constant(fgraph, node):
except
NotScalarConstantError
:
return
False
assert
idx_val
!=
np
.
newaxis
if
not
isinstance
(
shape_arg
.
type
,
TensorType
):
return
False
...
...
@@ -861,24 +871,22 @@ def local_subtensor_of_adv_subtensor(fgraph, node):
# AdvancedSubtensor involves a full_copy, so we don't want to do it twice
return
None
x
,
*
adv_index_vars
=
adv_subtensor
.
owner
.
inputs
adv_idxs
=
indices_from_subtensor
(
adv_index_vars
,
adv_subtensor
.
owner
.
op
.
idx_list
)
x
,
*
adv_idxs
=
adv_subtensor
.
owner
.
inputs
# Advanced indexing is a minefield, avoid all cases except for consecutive integer indices
if
(
not
all
(
(
(
isinstance
(
adv_idx
,
TensorVariable
)
and
adv_idx
.
type
.
dtype
!=
"bool"
)
or
(
isinstance
(
adv_idx
,
slice
)
and
adv_idx
==
slice
(
None
))
)
for
adv_idx
in
adv_idxs
if
any
(
(
isinstance
(
adv_idx
.
type
,
NoneTypeT
)
or
(
isinstance
(
adv_idx
.
type
,
TensorType
)
and
adv_idx
.
type
.
dtype
==
"bool"
)
or
(
isinstance
(
adv_idx
.
type
,
SliceType
)
and
not
is_full_slice
(
adv_idx
))
)
for
adv_idx
in
adv_idxs
)
or
_non_consecutive_adv_indexing
(
adv_idxs
):
return
None
for
first_adv_idx_dim
,
adv_idx
in
enumerate
(
adv_idxs
):
# We already made sure there were only None slices besides integer indexes
if
isinstance
(
adv_idx
,
TensorVariabl
e
):
if
isinstance
(
adv_idx
.
type
,
TensorTyp
e
):
break
else
:
# no-break
# Not sure if this should ever happen, but better safe than sorry
...
...
@@ -901,7 +909,7 @@ def local_subtensor_of_adv_subtensor(fgraph, node):
copy_stack_trace
([
basic_subtensor
,
adv_subtensor
],
x_indexed
)
x_after_index_lift
=
expand_dims
(
x_indexed
,
dropped_dims
)
x_after_adv_idx
=
adv_subtensor
.
owner
.
op
(
x_after_index_lift
,
*
adv_i
ndex_var
s
)
x_after_adv_idx
=
adv_subtensor
.
owner
.
op
(
x_after_index_lift
,
*
adv_i
dx
s
)
copy_stack_trace
([
basic_subtensor
,
adv_subtensor
],
x_after_adv_idx
)
new_out
=
squeeze
(
x_after_adv_idx
[
basic_idxs_kept
],
dropped_dims
)
...
...
pytensor/tensor/rewriting/uncanonicalize.py
浏览文件 @
cc6bed1a
...
...
@@ -38,7 +38,7 @@ from pytensor.tensor.elemwise import CAReduce, DimShuffle
from
pytensor.tensor.math
import
Min
,
neg
from
pytensor.tensor.rewriting.basic
import
register_uncanonicalize
from
pytensor.tensor.shape
import
Reshape
,
reshape
from
pytensor.tensor.subtensor
import
Subtensor
,
indices_from_subtensor
from
pytensor.tensor.subtensor
import
Subtensor
@register_uncanonicalize
...
...
@@ -193,42 +193,60 @@ def local_dimshuffle_subtensor(fgraph, node):
if
not
all
(
broadcastable
[
i
]
for
i
in
missing_dims
):
return
False
# create a new index tuple for a new Subtensor
# Reconstruct the full indices from the subtensor node, then replace
# dimensions that are being dropped by dimshuffle with scalar index 0
x
=
input_
.
owner
.
inputs
[
0
]
indices
=
list
(
indices_from_subtensor
(
input_
.
owner
.
inputs
[
1
:],
input_
.
owner
.
op
.
idx_list
)
)
# create a new idx_list for a new Subtensor object
# have to loop on idx_list and inputs
# inputs has the length of sum of non None elements of idx_list
# (check in slice!).
# len(missing_dims) can be < len(idx_list), this happens if
# tensor was indexed such as x[scalar, :, :], check that as well
new_idx_list
=
list
(
input_
.
owner
.
op
.
idx_list
)
new_inputs
=
[
input_
.
owner
.
inputs
[
0
]]
zero
=
constant
(
0
)
# Track which output dimension each index corresponds to
# Scalar indices remove dimensions, slices keep them
output_dim
=
0
for
i
,
idx
in
enumerate
(
indices
):
j
=
0
slice_i
=
-
1
subtensor_removed_dims
=
0
for
i
,
idx
in
enumerate
(
input_
.
owner
.
op
.
idx_list
):
if
isinstance
(
idx
,
slice
):
# This slice produces an output dimension
if
output_dim
in
missing_dims
:
#
This output dimension is being dropped, so replace slice with scalar
slice_i
+=
1
if
slice_i
in
missing_dims
:
#
Missing dim is a slice(None), remove by indexing by 0
if
idx
==
slice
(
None
):
indices
[
i
]
=
zero
new_idx_list
[
i
]
=
zero
new_inputs
+=
[
zero
]
# Missing dim is an ordinary slice with known output dim length of 1
# Remove by indexing by start
else
:
# Use the start of the slice (or 0 if None)
indices
[
i
]
=
idx
.
start
if
idx
.
start
is
not
None
else
zero
output_dim
+=
1
# Scalar indices don't contribute to output dimensions
# Handle trailing dimensions that weren't explicitly indexed
for
input_dim
in
range
(
len
(
indices
),
x
.
ndim
):
if
output_dim
in
missing_dims
:
# This unindexed dimension is being dropped, index with 0
indices
.
append
(
zero
)
if
idx
.
start
is
None
:
start
=
zero
else
:
start
=
input_
.
owner
.
inputs
[
1
+
j
]
j
+=
1
new_idx_list
[
i
]
=
start
new_inputs
+=
[
start
]
# Ignore useless stop and step input if there is one
for
slice_attr
in
(
"stop"
,
"step"
):
if
getattr
(
idx
,
slice_attr
)
is
not
None
:
j
+=
1
# Keep non-dropped slice inputs
else
:
for
slice_attr
in
(
"start"
,
"stop"
,
"step"
):
if
getattr
(
idx
,
slice_attr
)
is
not
None
:
new_inputs
+=
[
input_
.
owner
.
inputs
[
1
+
j
]]
j
+=
1
# Keep non-dropped non-slice inputs
else
:
# This unindexed dimension is kept, index with slice(None)
indices
.
append
(
slice
(
None
))
output_dim
+=
1
return
[
x
[
tuple
(
indices
)]]
new_inputs
+=
[
input_
.
owner
.
inputs
[
1
+
j
]]
j
+=
1
subtensor_removed_dims
+=
1
# Verify the trailing dimensions the subtensor didn't look at.
for
idx
in
range
(
len
(
input_
.
owner
.
op
.
idx_list
),
new_inputs
[
0
]
.
ndim
):
if
(
idx
-
subtensor_removed_dims
)
in
missing_dims
:
while
len
(
new_idx_list
)
<
idx
:
new_idx_list
.
append
(
slice
(
None
))
new_idx_list
.
append
(
zero
)
new_inputs
.
append
(
zero
)
return
[
Subtensor
(
new_idx_list
)(
*
new_inputs
)]
return
False
pytensor/tensor/subtensor.py
浏览文件 @
cc6bed1a
import
logging
import
sys
import
warnings
from
collections.abc
import
Callable
,
Sequence
from
collections.abc
import
Callable
,
Iterable
,
Sequence
from
itertools
import
chain
,
groupby
,
zip_longest
from
typing
import
TypeVar
,
cast
,
overload
from
typing
import
cast
,
overload
import
numpy
as
np
from
numpy.lib.array_utils
import
normalize_axis_tuple
...
...
@@ -15,6 +15,7 @@ from pytensor.gradient import disconnected_type
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
from
pytensor.graph.op
import
Op
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.graph.type
import
Type
from
pytensor.graph.utils
import
MethodNotDefined
from
pytensor.link.c.op
import
COp
from
pytensor.link.c.params_type
import
ParamsType
...
...
@@ -37,114 +38,117 @@ from pytensor.tensor.basic import (
)
from
pytensor.tensor.blockwise
import
vectorize_node_fallback
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.exceptions
import
AdvancedIndexingError
,
NotScalarConstantError
from
pytensor.tensor.math
import
add
,
clip
from
pytensor.tensor.shape
import
(
Reshape
,
Shape_i
,
specify_broadcastable
,
)
from
pytensor.tensor.shape
import
Reshape
,
Shape_i
,
specify_broadcastable
from
pytensor.tensor.type
import
(
TensorType
,
bscalar
,
complex_dtypes
,
cscalar
,
discrete_dtypes
,
dscalar
,
fscalar
,
integer_dtypes
,
iscalar
,
lscalar
,
tensor
,
ubscalar
,
uiscalar
,
ulscalar
,
uwscalar
,
wscalar
,
zscalar
,
)
from
pytensor.tensor.type_other
import
(
MakeSlice
,
NoneConst
,
NoneSliceConst
,
NoneTypeT
,
SliceConstant
,
SliceType
,
make_slice
,
)
from
pytensor.tensor.type_other
import
NoneTypeT
from
pytensor.tensor.variable
import
TensorConstant
,
TensorVariable
from
pytensor.utils
import
unzip
_logger
=
logging
.
getLogger
(
"pytensor.tensor.subtensor"
)
T
=
TypeVar
(
"T"
)
def
flatten_index_variables
(
idx_vars
:
Sequence
[
T
|
None
|
slice
],
)
->
tuple
[
list
[
int
|
slice
],
list
[
T
]]:
counter
=
0
idx_list
:
list
[
int
|
slice
]
=
[]
flat_vars
=
[]
for
idx_var
in
idx_vars
:
if
isinstance
(
idx_var
,
slice
):
slice_idx_list
:
list
[
None
|
int
]
=
[]
for
arg_entry
in
(
idx_var
.
start
,
idx_var
.
stop
,
idx_var
.
step
):
if
arg_entry
is
None
or
(
isinstance
(
arg_entry
,
Variable
)
and
isinstance
(
arg_entry
.
type
,
NoneTypeT
)
):
slice_idx_list
.
append
(
None
)
else
:
flat_vars
.
append
(
arg_entry
)
slice_idx_list
.
append
(
counter
)
counter
+=
1
idx_list
.
append
(
slice
(
*
slice_idx_list
))
else
:
flat_vars
.
append
(
idx_var
)
idx_list
.
append
(
counter
)
counter
+=
1
return
idx_list
,
flat_vars
def
unflatten_index_variables
(
flat_indices
:
Sequence
[
T
],
idx_list
:
Sequence
[
slice
|
int
],
)
->
tuple
[
slice
|
T
,
...
]:
indices
:
list
[
T
|
slice
]
=
[]
for
idx_entry
in
idx_list
:
if
isinstance
(
idx_entry
,
int
):
indices
.
append
(
flat_indices
[
idx_entry
])
elif
isinstance
(
idx_entry
,
slice
):
start
,
stop
,
step
=
idx_entry
.
start
,
idx_entry
.
stop
,
idx_entry
.
step
indices
.
append
(
slice
(
None
if
idx_entry
.
start
is
None
else
flat_indices
[
start
],
None
if
idx_entry
.
stop
is
None
else
flat_indices
[
stop
],
None
if
idx_entry
.
step
is
None
else
flat_indices
[
step
],
)
)
else
:
raise
ValueError
(
f
"idx_entry must be int or slice, got {type(idx_entry)}"
)
return
tuple
(
indices
)
invalid_scal_types
=
(
ps
.
float64
,
ps
.
float32
,
ps
.
float16
)
scal_types
=
(
ps
.
int64
,
ps
.
int32
,
ps
.
int16
,
ps
.
int8
,
ps
.
uint64
,
ps
.
uint32
,
ps
.
uint16
,
ps
.
uint8
,
)
tensor_types
=
(
lscalar
,
iscalar
,
wscalar
,
bscalar
,
ulscalar
,
uiscalar
,
uwscalar
,
ubscalar
,
)
invalid_tensor_types
=
(
fscalar
,
dscalar
,
cscalar
,
zscalar
,
)
def
indices_from_subtensor
(
op_indices
:
Sequence
[
Variable
],
idx_list
:
tuple
[
slice
|
int
,
...
]
,
op_indices
:
Iterable
[
ScalarConstant
],
idx_list
:
list
[
Type
|
slice
|
Variable
]
|
None
,
)
->
tuple
[
slice
|
Variable
,
...
]:
"""Recreate the index tuple from which a ``*Subtensor**`` ``Op`` was created.
Parameters
----------
==========
op_indices
The flattened indices obtained from ``x.inputs``, when ``x`` is a ``*Subtensor*`` node.
The flattened indices obtained from ``x.inputs``, when ``x`` is a
``*Subtensor*`` node.
idx_list
The values describing each dimension's index. This is obtained from
``op.idx_list``. Entries can be:
- Integer positions (indices into op_indices)
- slice objects with int/None components
Returns
-------
tuple[slice | Variable, ...]
A tuple containing a mix of ``slice`` objects and ``Variable`` objects.
Each element corresponds to one indexing dimension:
- ``slice`` objects for slice-based indexing (e.g., ``x[1:3]``)
- ``Variable`` objects for scalar or array-based indexing
Callers should handle both types when iterating over the result.
The values describing the types of each dimension's index. This is
obtained from ``op.idx_list``, when ``op`` is a ``*Subtensor*``
``Op``.
Example
-------
=======
array, *op_indices = subtensor_node.inputs
indices = indices_from_subtensor(op_indices, subtensor_node.op.idx_list)
idx_list = getattr(subtensor_node.op, "idx_list", None)
indices = indices_from_subtensor(op_indices, idx_list)
"""
return
unflatten_index_variables
(
op_indices
,
idx_list
)
def
convert_indices
(
indices
,
entry
):
"""Reconstruct ``*Subtensor*`` index input parameter entries."""
if
indices
and
isinstance
(
entry
,
Type
):
rval
=
indices
.
pop
(
0
)
return
rval
elif
isinstance
(
entry
,
slice
):
return
slice
(
convert_indices
(
indices
,
entry
.
start
),
convert_indices
(
indices
,
entry
.
stop
),
convert_indices
(
indices
,
entry
.
step
),
)
else
:
return
entry
op_indices
=
list
(
op_indices
)
return
(
tuple
(
convert_indices
(
op_indices
,
idx
)
for
idx
in
idx_list
)
if
idx_list
else
tuple
(
op_indices
)
)
def
as_index_constant
(
...
...
@@ -178,7 +182,7 @@ def as_index_literal(idx: None) -> None: ...
@overload
def
as_index_literal
(
idx
:
slice
)
->
slice
:
...
def
as_index_literal
(
idx
:
slice
|
SliceConstant
)
->
slice
:
...
@overload
...
...
@@ -190,7 +194,14 @@ def as_index_literal(idx: Variable): ...
def
as_index_literal
(
idx
:
None
|
int
|
np
.
integer
|
slice
|
ScalarConstant
|
TensorConstant
|
Variable
,
idx
:
None
|
int
|
np
.
integer
|
slice
|
SliceConstant
|
ScalarConstant
|
TensorConstant
|
Variable
,
)
->
int
|
np
.
integer
|
slice
|
None
:
"""Convert a symbolic index element to its Python equivalent.
...
...
@@ -213,6 +224,9 @@ def as_index_literal(
if
not
isinstance
(
idx
,
Variable
):
raise
TypeError
(
f
"Not an index element: {idx}"
)
if
isinstance
(
idx
.
type
,
NoneTypeT
):
return
None
if
isinstance
(
idx
,
ScalarConstant
):
return
cast
(
int
,
idx
.
data
)
...
...
@@ -226,6 +240,13 @@ def as_index_literal(
if
isinstance
(
idx
,
TensorConstant
):
return
cast
(
int
,
idx
.
data
.
item
())
if
isinstance
(
idx
,
SliceConstant
):
return
cast
(
slice
,
idx
.
data
)
if
isinstance
(
idx
.
type
,
SliceType
):
assert
idx
.
owner
is
not
None
return
slice
(
*
map
(
as_index_literal
,
idx
.
owner
.
inputs
))
# Other kinds of variables are not supported
raise
NotScalarConstantError
()
...
...
@@ -254,8 +275,10 @@ def get_canonical_form_slice(
)
->
tuple
[
slice
|
TensorVariable
,
int
|
TensorVariable
]:
"""Convert indices or slices to canonical form.
Handles slice objects with ScalarVariable (including ScalarConstant) or None components.
Vector indices and advanced indexing operations are handled separately by AdvancedSubtensor.
Scalar integer indices or python Slices with Scalar/None attributes
used in basic Subtensor Ops are supported.
Symbolic slices (of SliceType) or vector indices
used in advanced Subtensor Ops are not supported.
Given a slice [start:stop:step] transform it into a canonical form
that respects the conventions imposed by python and numpy.
...
...
@@ -469,20 +492,16 @@ def get_canonical_form_slice(
return
slice
(
nw_start
,
nw_stop
,
nw_step
),
1
def
slice_len
(
slc
,
n
):
"""Compute the length of a slice for an array of a given length.
We're essentially computing `len(range(*slc.indices(n)))`.
def
range_len
(
slc
):
"""Length of a `range` object.
Adapted from CPython.
"""
from
pytensor.tensor
import
and_
,
gt
,
lt
,
switch
# TODO: Do we need to do this or should we expect `slc` to already be canonicalized?
canon_slc
,
_
=
get_canonical_form_slice
(
slc
,
n
)
start
,
stop
,
step
=
tuple
(
as_index_constant
(
a
)
for
a
in
[
canon_slc
.
start
,
canon_slc
.
stop
,
canon_
slc
.
step
]
as_index_constant
(
a
)
for
a
in
[
slc
.
start
,
slc
.
stop
,
slc
.
step
]
)
return
switch
(
and_
(
gt
(
step
,
0
),
lt
(
start
,
stop
)),
...
...
@@ -495,6 +514,31 @@ def slice_len(slc, n):
)
def
slice_len
(
slc
,
n
):
"""Compute the length of a slice for an array of a given length.
We're essentially computing `len(range(*slc.indices(n)))`.
"""
# TODO: Do we need to do this or should we expect `slc` to
# already be canonicalized?
canon_slc
,
_
=
get_canonical_form_slice
(
slc
,
n
)
return
range_len
(
canon_slc
)
def
is_basic_idx
(
idx
):
"""Determine if an index is of the NumPy basic type.
XXX: This only checks a single index, so an integer is *not* considered a
basic index, because--depending on the other indices its used with--an
integer can indicate advanced indexing.
"""
return
isinstance
(
idx
,
slice
|
type
(
None
))
or
isinstance
(
getattr
(
idx
,
"type"
,
None
),
SliceType
|
NoneTypeT
)
def
basic_shape
(
shape
,
indices
):
r"""Computes the shape resulting from basic NumPy indexing.
...
...
@@ -513,8 +557,25 @@ def basic_shape(shape, indices):
for
n
,
idx
in
zip
(
shape
[:
len
(
indices
)],
indices
,
strict
=
True
):
if
isinstance
(
idx
,
slice
):
res_shape
+=
(
slice_len
(
idx
,
n
),)
elif
isinstance
(
getattr
(
idx
,
"type"
,
None
),
SliceType
):
if
idx
.
owner
is
None
:
if
not
isinstance
(
idx
,
Constant
):
# This is an input slice, we can't reason symbolically on it.
# We don't even know if we will get None entries or integers
res_shape
+=
(
None
,)
continue
else
:
sl
:
slice
=
idx
.
data
slice_inputs
=
(
sl
.
start
,
sl
.
stop
,
sl
.
step
)
elif
isinstance
(
idx
.
owner
.
op
,
MakeSlice
):
slice_inputs
=
idx
.
owner
.
inputs
else
:
raise
ValueError
(
f
"Unexpected Slice producing Op {idx.owner.op}"
)
res_shape
+=
(
slice_len
(
slice
(
*
slice_inputs
),
n
),)
elif
idx
is
None
:
res_shape
+=
(
ps
.
ScalarConstant
(
ps
.
int64
,
1
),)
elif
isinstance
(
getattr
(
idx
,
"type"
,
None
),
NoneTypeT
):
res_shape
+=
(
ps
.
ScalarConstant
(
ps
.
int64
,
1
),)
else
:
raise
ValueError
(
f
"Invalid index type: {idx}"
)
return
res_shape
...
...
@@ -532,12 +593,14 @@ def group_indices(indices):
"""
idx_groups
=
[]
dim_num
=
-
1
for
basic
,
grp_indices
in
groupby
(
indices
,
key
=
lambda
x
:
isinstance
(
x
,
slice
)
):
for
basic
,
grp_indices
in
groupby
(
indices
,
key
=
is_basic_idx
):
enum_grp_indices
=
[]
for
idx
in
grp_indices
:
# We "zip" the dimension number to each index, which means we can't
# count indices that add new axes
if
idx
is
not
None
:
if
(
idx
is
not
None
)
and
not
isinstance
(
getattr
(
idx
,
"type"
,
None
),
NoneTypeT
):
dim_num
+=
1
enum_grp_indices
.
append
((
dim_num
,
idx
))
...
...
@@ -584,7 +647,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
idx_groups
=
sorted
(
idx_groups
,
key
=
lambda
x
:
x
[
0
])
idx_groups
=
groupby
(
chain
.
from_iterable
(
d_idx
for
_
,
d_idx
in
idx_groups
),
key
=
lambda
x
:
is
instance
(
x
[
1
],
slice
),
key
=
lambda
x
:
is
_basic_idx
(
x
[
1
]
),
)
for
basic
,
grp_dim_indices
in
idx_groups
:
...
...
@@ -644,6 +707,72 @@ def get_slice_elements(
return
ret
def
index_vars_to_types
(
entry
,
slice_ok
=
True
):
r"""Change references to `Variable`s into references to `Type`s.
The `Subtensor.idx_list` field is unique to each `Subtensor` instance. It
is not unique to each `Apply` node, so it should not refer to specific
`Variable`s.
TODO WRITEME: This function also accepts an `entry` already being a `Type`;
when would that happen?
"""
if
(
isinstance
(
entry
,
np
.
ndarray
|
Variable
)
and
hasattr
(
entry
,
"dtype"
)
and
entry
.
dtype
==
"bool"
):
raise
AdvancedIndexingError
(
"Invalid index type or slice for Subtensor"
)
if
isinstance
(
entry
,
Variable
)
and
(
entry
.
type
in
invalid_scal_types
or
entry
.
type
in
invalid_tensor_types
):
raise
TypeError
(
"Expected an integer"
)
if
isinstance
(
entry
,
Variable
)
and
entry
.
type
in
scal_types
:
return
entry
.
type
elif
isinstance
(
entry
,
Type
)
and
entry
in
scal_types
:
return
entry
if
(
isinstance
(
entry
,
Variable
)
and
entry
.
type
in
tensor_types
and
all
(
entry
.
type
.
broadcastable
)
):
return
ps
.
get_scalar_type
(
entry
.
type
.
dtype
)
elif
isinstance
(
entry
,
Type
)
and
entry
in
tensor_types
and
all
(
entry
.
broadcastable
):
return
ps
.
get_scalar_type
(
entry
.
dtype
)
elif
slice_ok
and
isinstance
(
entry
,
slice
):
a
=
entry
.
start
b
=
entry
.
stop
c
=
entry
.
step
if
a
is
not
None
:
slice_a
=
index_vars_to_types
(
a
,
False
)
else
:
slice_a
=
None
if
b
is
not
None
and
b
!=
sys
.
maxsize
:
# The special "maxsize" case is probably not needed here,
# as slices containing maxsize are not generated by
# __getslice__ anymore.
slice_b
=
index_vars_to_types
(
b
,
False
)
else
:
slice_b
=
None
if
c
is
not
None
:
slice_c
=
index_vars_to_types
(
c
,
False
)
else
:
slice_c
=
None
return
slice
(
slice_a
,
slice_b
,
slice_c
)
elif
isinstance
(
entry
,
int
|
np
.
integer
):
raise
TypeError
()
else
:
raise
AdvancedIndexingError
(
"Invalid index type or slice for Subtensor"
)
def
get_constant_idx
(
idx_list
,
inputs
,
allow_partial
=
False
,
only_process_constants
=
False
,
elemwise
=
True
):
...
...
@@ -674,7 +803,7 @@ def get_constant_idx(
>>> a = matrix("a")
>>> b = a[v, 1:3]
>>> b.owner.op.idx_list
(
0, slice(1, 2
, None))
(
ScalarType(int64), slice(ScalarType(int64), ScalarType(int64)
, None))
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs, allow_partial=True)
[v, slice(np.int64(1), np.int64(3), None)]
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs)
...
...
@@ -706,11 +835,15 @@ def get_constant_idx(
return
list
(
map
(
conv
,
real_idx
))
def
as_scalar_index_variable
(
idx
)
->
ps
.
ScalarVariable
:
idx
=
ps
.
as_scalar
(
idx
)
if
idx
.
type
.
dtype
not
in
integer_dtypes
:
raise
TypeError
(
"basic indices must be integers"
)
return
idx
# type: ignore[no-any-return]
def
as_nontensor_scalar
(
a
:
Variable
)
->
ps
.
ScalarVariable
:
"""Convert a value to a `ScalarType` variable."""
# Since ps.as_scalar does not know about tensor types (it would
# create a circular import) , this method converts either a
# TensorVariable or a ScalarVariable to a scalar.
if
isinstance
(
a
,
Variable
)
and
isinstance
(
a
.
type
,
TensorType
):
return
pytensor
.
tensor
.
scalar_from_tensor
(
a
)
else
:
return
ps
.
as_scalar
(
a
)
def
slice_static_length
(
slc
,
dim_length
):
...
...
@@ -731,71 +864,17 @@ def slice_static_length(slc, dim_length):
return
len
(
range
(
*
slice
(
*
entries
)
.
indices
(
dim_length
)))
class
BaseSubtensor
:
"""Base class for Subtensor operations that handles idx_list and hash/equality."""
def
__init__
(
self
,
idx_list
:
Sequence
[
int
|
slice
]):
index_counter
=
-
1
for
idx_entry
in
idx_list
:
if
isinstance
(
idx_entry
,
int
):
if
idx_entry
!=
(
index_counter
+
1
):
raise
ValueError
(
f
"idx_list entries should have consecutive integers, got {idx_list}"
)
index_counter
=
idx_entry
elif
isinstance
(
idx_entry
,
slice
):
for
slice_idx_entry
in
(
idx_entry
.
start
,
idx_entry
.
stop
,
idx_entry
.
step
,
):
if
slice_idx_entry
is
not
None
:
if
not
isinstance
(
slice_idx_entry
,
int
):
raise
ValueError
(
f
"idx_list slice entries must be None or integer, got {slice_idx_entry} in {idx_entry}"
)
if
slice_idx_entry
!=
(
index_counter
+
1
):
raise
ValueError
(
f
"idx_list entries should have consecutive integers, got {idx_list}"
)
index_counter
=
slice_idx_entry
else
:
raise
ValueError
(
f
"idx_list entries must be int or slice, got {idx_entry}"
)
self
.
n_index_vars
=
index_counter
+
1
self
.
idx_list
=
tuple
(
idx_list
)
def
_hashable_idx_list
(
self
):
"""Return a hashable version of idx_list (slices converted to tuples).
Slices are not hashable in Python < 3.12, so we convert them to tuples.
"""
return
tuple
(
(
slice
,
entry
.
start
,
entry
.
stop
,
entry
.
step
)
if
isinstance
(
entry
,
slice
)
else
entry
for
entry
in
self
.
idx_list
)
def
__hash__
(
self
):
# Temporary workaround: slices are hashable in Python 3.12+
props_values
=
tuple
(
self
.
_hashable_idx_list
()
if
prop
==
"idx_list"
else
getattr
(
self
,
prop
)
for
prop
in
self
.
__props__
)
return
hash
((
type
(
self
),
props_values
))
class
Subtensor
(
BaseSubtensor
,
COp
):
class
Subtensor
(
COp
):
"""Basic NumPy indexing operator."""
check_input
=
False
view_map
=
{
0
:
[
0
]}
_f16_ok
=
True
__props__
=
(
"idx_list"
,)
__hash__
=
BaseSubtensor
.
__hash__
def
__init__
(
self
,
idx_list
):
# TODO: Provide the type of `self.idx_list`
self
.
idx_list
=
tuple
(
map
(
index_vars_to_types
,
idx_list
))
def
make_node
(
self
,
x
,
*
inputs
):
"""
...
...
@@ -808,16 +887,23 @@ class Subtensor(BaseSubtensor, COp):
"""
x
=
as_tensor_variable
(
x
)
inputs
=
tuple
(
as_scalar_index_variable
(
a
)
for
a
in
inputs
)
inputs
=
tuple
(
as_nontensor_scalar
(
a
)
for
a
in
inputs
)
idx_list
=
list
(
self
.
idx_list
)
if
len
(
idx_list
)
>
x
.
type
.
ndim
:
raise
IndexError
(
"too many indices for array"
)
input_
position
s
=
get_slice_elements
(
idx_list
,
lambda
entry
:
isinstance
(
entry
,
int
)
input_
type
s
=
get_slice_elements
(
idx_list
,
lambda
entry
:
isinstance
(
entry
,
Type
)
)
assert
len
(
inputs
)
==
len
(
input_positions
)
assert
len
(
inputs
)
==
len
(
input_types
)
for
input
,
expected_type
in
zip
(
inputs
,
input_types
,
strict
=
True
):
if
not
expected_type
.
is_super
(
input
.
type
):
raise
TypeError
(
f
"Incompatible types for Subtensor template. Expected {input.type}, got {expected_type}."
)
padded
=
[
*
indices_from_subtensor
(
inputs
,
self
.
idx_list
),
...
...
@@ -838,10 +924,13 @@ class Subtensor(BaseSubtensor, COp):
def
perform
(
self
,
node
,
inputs
,
out_
):
(
out
,)
=
out_
x
,
*
index_variables
=
inputs
x
=
inputs
[
0
]
cdata
=
get_idx_list
(
inputs
,
self
.
idx_list
)
if
len
(
cdata
)
==
1
:
cdata
=
cdata
[
0
]
cdata
=
unflatten_index_variables
(
index_variables
,
self
.
idx_list
)
out
[
0
]
=
np
.
asarray
(
x
.
__getitem__
(
tuple
(
cdata
)))
out
[
0
]
=
np
.
asarray
(
x
.
__getitem__
(
cdata
))
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
def
_is_constant
(
const
,
x
):
...
...
@@ -889,7 +978,8 @@ class Subtensor(BaseSubtensor, COp):
def
grad
(
self
,
inputs
,
grads
):
(
gz
,)
=
grads
x
,
*
index_variables
=
inputs
x
=
inputs
[
0
]
rest
=
inputs
[
1
:]
if
x
.
dtype
in
discrete_dtypes
:
first
=
x
.
zeros_like
(
dtype
=
config
.
floatX
)
else
:
...
...
@@ -898,28 +988,43 @@ class Subtensor(BaseSubtensor, COp):
# We have an optimization that will convert this to a
# set subtensor here at:
# pytensor/tensor/opt.py:local_incsubtensor_of_zeros_to_setsubtensor()
first
=
IncSubtensor
(
self
.
idx_list
)(
x
.
zeros_like
(),
gz
,
*
index_variables
)
return
[
first
,
*
(
disconnected_type
()
for
_
in
range
(
len
(
index_variables
)))]
first
=
IncSubtensor
(
self
.
idx_list
)(
x
.
zeros_like
(),
gz
,
*
rest
)
return
[
first
,
*
(
disconnected_type
()
for
_
in
range
(
len
(
rest
)))]
def
connection_pattern
(
self
,
node
):
_x
,
*
index_variables
=
node
.
inputs
rval
=
[[
True
],
*
([
False
]
for
_
in
index_variables
)]
rval
=
[[
True
],
*
([
False
]
for
_
in
node
.
inputs
[
1
:])]
return
rval
def
__hash__
(
self
):
msg
=
[]
for
entry
in
self
.
idx_list
:
if
isinstance
(
entry
,
slice
):
msg
+=
[(
entry
.
start
,
entry
.
stop
,
entry
.
step
)]
else
:
msg
+=
[
entry
]
idx_list
=
tuple
(
msg
)
# backport
# idx_list = tuple((entry.start, entry.stop, entry.step)
# if isinstance(entry, slice)
# else entry
# for entry in self.idx_list)
return
hash
(
idx_list
)
@staticmethod
def
str_from_slice
(
entry
):
if
entry
.
step
is
not
None
:
if
entry
.
step
:
return
":"
.
join
(
(
"start"
if
entry
.
start
is
not
None
else
""
,
"stop"
if
entry
.
stop
is
not
None
else
""
,
"start"
if
entry
.
start
else
""
,
"stop"
if
entry
.
stop
else
""
,
"step"
,
)
)
if
entry
.
stop
is
not
None
:
return
f
"{'start' if entry.start
is not None
else ''}:stop"
if
entry
.
start
is
not
None
:
if
entry
.
stop
:
return
f
"{'start' if entry.start else ''}:stop"
if
entry
.
start
:
return
"start:"
return
":"
...
...
@@ -1002,7 +1107,12 @@ class Subtensor(BaseSubtensor, COp):
return
pos
[
1
]
def
init_entry
(
entry
,
depth
=
0
):
if
isinstance
(
entry
,
int
):
if
isinstance
(
entry
,
np
.
integer
|
int
):
init_cmds
.
append
(
f
"subtensor_spec[{spec_pos()}] = {entry};"
)
inc_spec_pos
(
1
)
if
depth
==
0
:
is_slice
.
append
(
0
)
elif
isinstance
(
entry
,
Type
):
init_cmds
.
append
(
f
"subtensor_spec[{spec_pos()}] = {inputs[input_pos()]};"
)
...
...
@@ -1265,58 +1375,7 @@ class Subtensor(BaseSubtensor, COp):
# (they should be defaulted to zeros_like by the global R_op)
if
eval_points
[
0
]
is
None
:
return
[
None
]
_x
,
*
index_variables
=
inputs
return
self
(
eval_points
[
0
],
*
index_variables
,
return_list
=
True
)
def
basic_subtensor
(
x
,
*
index_variables
):
idx_list
,
flat_index_vars
=
flatten_index_variables
(
index_variables
)
return
Subtensor
(
idx_list
)(
x
,
*
flat_index_vars
)
@_get_vector_length.register
(
Subtensor
)
# type: ignore
def
_get_vector_length_Subtensor
(
op
,
var
):
# If we take a slice, we know how many elements it will result in
# TODO: We can cover more `*Subtensor` cases.
try
:
indices
=
get_idx_list
(
var
.
owner
.
inputs
,
var
.
owner
.
op
.
idx_list
)
start
=
(
None
if
indices
[
0
]
.
start
is
None
else
get_scalar_constant_value
(
indices
[
0
]
.
start
)
)
stop
=
(
None
if
indices
[
0
]
.
stop
is
None
else
get_scalar_constant_value
(
indices
[
0
]
.
stop
)
)
step
=
(
None
if
indices
[
0
]
.
step
is
None
else
get_scalar_constant_value
(
indices
[
0
]
.
step
)
)
if
start
==
stop
:
return
0
arg_len
=
get_vector_length
(
var
.
owner
.
inputs
[
0
])
return
len
(
range
(
*
slice
(
start
,
stop
,
step
)
.
indices
(
arg_len
)))
except
(
ValueError
,
NotScalarConstantError
):
raise
ValueError
(
f
"Length of {var} cannot be determined"
)
@_vectorize_node.register
(
Subtensor
)
def
vectorize_subtensor
(
op
:
Subtensor
,
node
,
batch_x
,
*
batch_idxs
):
"""Rewrite subtensor with non-batched indexes as another Subtensor with prepended empty slices."""
# TODO: Vectorize Subtensor with non-slice batched indexes as AdvancedSubtensor
if
any
(
batch_inp
.
type
.
ndim
>
0
for
batch_inp
in
batch_idxs
):
return
vectorize_node_fallback
(
op
,
node
,
batch_x
,
*
batch_idxs
)
old_x
,
*
_
=
node
.
inputs
batch_ndims
=
batch_x
.
type
.
ndim
-
old_x
.
type
.
ndim
new_idx_list
=
(
slice
(
None
),)
*
batch_ndims
+
op
.
idx_list
return
Subtensor
(
new_idx_list
)
.
make_node
(
batch_x
,
*
batch_idxs
)
return
self
(
eval_points
[
0
],
*
inputs
[
1
:],
return_list
=
True
)
class
SubtensorPrinter
(
Printer
):
...
...
@@ -1328,28 +1387,25 @@ class SubtensorPrinter(Printer):
input
=
inputs
.
pop
(
0
)
sidxs
=
[]
getattr
(
pstate
,
"precedence"
,
None
)
def
process_slice_component
(
comp
):
"""Process a slice component, returning string representation."""
if
comp
is
None
:
return
""
elif
isinstance
(
comp
,
int
):
with
set_precedence
(
pstate
):
return
pstate
.
pprinter
.
process
(
inputs
.
pop
(
0
))
else
:
return
str
(
comp
)
for
entry
in
idxs
:
if
isinstance
(
entry
,
int
):
if
isinstance
(
entry
,
ps
.
ScalarType
):
with
set_precedence
(
pstate
):
sidxs
.
append
(
pstate
.
pprinter
.
process
(
inputs
.
pop
(
0
)))
sidxs
.
append
(
pstate
.
pprinter
.
process
(
inputs
.
pop
()))
elif
isinstance
(
entry
,
slice
):
msg1
=
process_slice_component
(
entry
.
start
)
msg2
=
process_slice_component
(
entry
.
stop
)
if
entry
.
start
is
None
or
entry
.
start
==
0
:
msg1
=
""
else
:
msg1
=
entry
.
start
if
entry
.
stop
is
None
or
entry
.
stop
==
sys
.
maxsize
:
msg2
=
""
else
:
msg2
=
entry
.
stop
if
entry
.
step
is
None
:
msg3
=
""
else
:
msg3
=
f
":{
process_slice_component(entry.step)
}"
msg3
=
f
":{
entry.step
}"
sidxs
.
append
(
f
"{msg1}:{msg2}{msg3}"
)
...
...
@@ -1362,97 +1418,336 @@ class SubtensorPrinter(Printer):
pprint
.
assign
(
Subtensor
,
SubtensorPrinter
())
class
IncSubtensor
(
BaseSubtensor
,
COp
):
"""
Increment a subtensor.
This is like numpy's
@_vectorize_node.register
(
Subtensor
)
def
vectorize_subtensor
(
op
:
Subtensor
,
node
,
batch_x
,
*
batch_idxs
):
"""Rewrite subtensor with non-batched indexes as another Subtensor with prepended empty slices."""
x[i,j,k] += y
# TODO: Vectorize Subtensor with non-slice batched indexes as AdvancedSubtensor
if
any
(
batch_inp
.
type
.
ndim
>
0
for
batch_inp
in
batch_idxs
):
return
vectorize_node_fallback
(
op
,
node
,
batch_x
,
*
batch_idxs
)
It is used internally to implement the gradient on SubTensor.
old_x
,
*
_
=
node
.
inputs
batch_ndims
=
batch_x
.
type
.
ndim
-
old_x
.
type
.
ndim
new_idx_list
=
(
slice
(
None
),)
*
batch_ndims
+
op
.
idx_list
return
Subtensor
(
new_idx_list
)
.
make_node
(
batch_x
,
*
batch_idxs
)
Parameters
----------
set_instead_of_inc
If True set the subtensor to the value instead of incrementing it by
that value.
def
set_subtensor
(
x
,
y
,
inplace
=
False
,
tolerate_inplace_aliasing
=
False
):
"""
Return x with the given subtensor overwritten by y.
check_input
=
False
__props__
=
(
"idx_list"
,
"inplace"
,
"set_instead_of_inc"
,
"destroyhandler_tolerate_aliased"
,
)
__hash__
=
BaseSubtensor
.
__hash__
def
__init__
(
self
,
idx_list
,
inplace
=
False
,
set_instead_of_inc
=
False
,
destroyhandler_tolerate_aliased
=
None
,
):
if
destroyhandler_tolerate_aliased
is
None
:
destroyhandler_tolerate_aliased
=
()
super
()
.
__init__
(
idx_list
)
self
.
inplace
=
inplace
if
inplace
:
self
.
destroy_map
=
{
0
:
[
0
]}
self
.
destroyhandler_tolerate_aliased
=
tuple
(
destroyhandler_tolerate_aliased
)
self
.
set_instead_of_inc
=
set_instead_of_inc
Parameters
----------
x
Symbolic variable for the lvalue of = operation.
y
Symbolic variable for the rvalue of = operation.
tolerate_inplace_aliasing
See inc_subtensor for documentation.
def
__str__
(
self
):
name
=
"SetSubtensor"
if
self
.
set_instead_of_inc
else
"IncSubtensor"
return
f
"{name}{{{Subtensor.str_from_indices(self.idx_list)}}}"
Examples
--------
To replicate the numpy expression ``r[10:] = 5``, type
def
make_node
(
self
,
x
,
y
,
*
inputs
):
"""
Parameters
----------
x
The tensor to increment.
y
The value to increment by.
inputs
The indeces/slices list to increment in combination with idx_list.
.. code-block:: python
E.g. self._idx_list = (0, slice(1, None, None), 2, slice(3, None, 4))
tell to use inputs[0] as the first dim.
"""
x
,
y
=
map
(
as_tensor_variable
,
[
x
,
y
])
if
y
.
ndim
>
x
.
ndim
:
raise
ValueError
(
f
"Trying to increment a {int(x.ndim)}-dimensional "
f
"subtensor with a {int(y.ndim)}-dimensional value."
)
inputs
=
tuple
(
map
(
as_scalar_index_variable
,
inputs
))
from pytensor.tensor import set_subtensor, vector
idx_list
=
list
(
self
.
idx_list
)
if
len
(
idx_list
)
>
x
.
type
.
ndim
:
raise
IndexError
(
"too many indices for array"
)
r = vector("r")
new_r = set_subtensor(r[10:], 5)
if
len
(
inputs
)
!=
self
.
n_index_vars
:
raise
ValueError
(
"Not enough inputs to fill in the Subtensor template."
,
inputs
,
idx_list
)
Consider using :meth:`pytensor.tensor.variable.TensorVariable.set` instead.
return
Apply
(
self
,
(
x
,
y
,
*
inputs
),
[
x
.
type
()])
"""
return
inc_subtensor
(
x
,
y
,
inplace
,
set_instead_of_inc
=
True
,
tolerate_inplace_aliasing
=
tolerate_inplace_aliasing
,
)
def
decl_view
(
self
):
return
"PyArrayObject * zview = NULL;"
def
perform
(
self
,
node
,
inputs
,
output_storage
):
def
inc_subtensor
(
x
,
y
,
inplace
=
False
,
set_instead_of_inc
=
False
,
tolerate_inplace_aliasing
=
False
,
ignore_duplicates
=
False
,
):
"""Update the value of an indexed array by a given amount.
This is equivalent to ``x[indices] += y`` or ``np.add.at(x, indices, y)``,
depending on the value of `ignore_duplicates`.
Parameters
----------
x
The symbolic result of a Subtensor operation.
y
The amount by which to increment the array.
inplace
Don't use. PyTensor will do in-place operations itself, when possible.
set_instead_of_inc
If True, do a set_subtensor instead.
tolerate_inplace_aliasing:
Allow `x` and `y` to be views of a single underlying array even while
working in-place. For correct results, `x` and `y` must not be overlapping
views; if they overlap, the result of this `Op` will generally be
incorrect. This value has no effect if ``inplace=False``.
ignore_duplicates
This determines whether ``x[indices] += y`` is used or
``np.add.at(x, indices, y)``.
Examples
--------
To replicate the expression ``r[10:] += 5``:
.. code-block:: python
from pytensor.tensor import ivector, inc_subtensor
r = ivector("r")
new_r = inc_subtensor(r[10:], 5)
To replicate the expression ``r[[0, 1, 0]] += 5``:
.. code-block:: python
r = ivector("r")
new_r = inc_subtensor(r[[0, 1, 0]], 5, ignore_duplicates=True)
Consider using :meth:`pytensor.tensor.variable.TensorVariable.inc` instead.
"""
# First of all, y cannot have a higher dimension than x,
# nor have non-broadcastable dimensions where x is broadcastable.
x
=
as_tensor_variable
(
x
)
y
=
as_tensor_variable
(
y
)
if
y
.
ndim
>
x
.
ndim
:
raise
TypeError
(
f
"Trying to increment a {int(x.ndim)}-dimensional "
f
"subtensor with a {int(y.ndim)}-dimensional value."
)
dim_offset
=
x
.
ndim
-
y
.
ndim
for
dim
in
range
(
y
.
ndim
):
if
x
.
broadcastable
[
dim
+
dim_offset
]
and
not
y
.
broadcastable
[
dim
]:
# It is acceptable to try to increment a subtensor with a
# broadcastable dim with a tensor that is not broadcastable
# on that dimension. However, its length must then be 1.
# We insert a SpecifyShape Op to make sure it is the case.
y
=
specify_broadcastable
(
y
,
dim
)
if
x
.
owner
is
None
:
raise
TypeError
(
"x must be the result of a subtensor operation"
)
# retrieve idx_list from x.owner
if
isinstance
(
x
.
owner
.
op
,
Subtensor
):
if
tolerate_inplace_aliasing
:
destroyhandler_tolerate_aliased
=
[[
0
,
1
]]
else
:
destroyhandler_tolerate_aliased
=
[]
the_op
=
IncSubtensor
(
x
.
owner
.
op
.
idx_list
,
inplace
,
set_instead_of_inc
,
destroyhandler_tolerate_aliased
=
destroyhandler_tolerate_aliased
,
)
real_x
=
x
.
owner
.
inputs
[
0
]
real_idxargs
=
x
.
owner
.
inputs
[
1
:]
return
the_op
(
real_x
,
y
,
*
real_idxargs
)
elif
isinstance
(
x
.
owner
.
op
,
AdvancedSubtensor1
):
real_x
=
x
.
owner
.
inputs
[
0
]
ilist
=
x
.
owner
.
inputs
[
1
]
if
ignore_duplicates
:
the_op
=
AdvancedIncSubtensor
(
inplace
,
set_instead_of_inc
=
set_instead_of_inc
,
ignore_duplicates
=
True
)
else
:
the_op
=
AdvancedIncSubtensor1
(
inplace
,
set_instead_of_inc
=
set_instead_of_inc
)
return
the_op
(
real_x
,
y
,
ilist
)
elif
isinstance
(
x
.
owner
.
op
,
AdvancedSubtensor
):
real_x
=
x
.
owner
.
inputs
[
0
]
ilist
=
x
.
owner
.
inputs
[
1
:]
the_op
=
AdvancedIncSubtensor
(
inplace
,
set_instead_of_inc
=
set_instead_of_inc
,
ignore_duplicates
=
ignore_duplicates
,
)
return
the_op
(
real_x
,
y
,
*
ilist
)
elif
isinstance
(
x
.
owner
.
op
,
DimShuffle
):
inner_x
=
x
.
owner
.
inputs
[
0
]
# In the dimshuffle case, there are in fact two dimshuffles:
# one to make the indexed dimension the last one,
# and one to put it back where it was. So, in the case where we have
# inc_subtensor(x[:,i], y), the graph is actually
# inc_subtensor((x.T)[i].T, y).
# We could get all the way to x, and then get rid of the dimshuffles
# completely, but the problem is that advanced_inc_subtensor1 can only
# work on the first (outer-most, left-most) dimension of x,
# just like advanced_subtensor1.
# So we call advanced_inc_subtensor1(x.T, i, y.T) (as we also need to
# transpose y if it is not a scalar or a vector), but then we need to
# return something that has the same shape as x, not as x.T (inner_x).
# So re-apply the outer dimshuffle on the new inc_subtensor,
# and return advanced_inc_subtensor1(x.T, i, y.T).T.
# Get the dimshuffle pattern to apply to y.
x_order
=
x
.
owner
.
op
.
new_order
y_order
=
[
"x"
]
*
x
.
ndim
for
i
,
v
in
enumerate
(
x_order
):
if
v
!=
"x"
and
(
v
-
dim_offset
)
>=
0
:
y_order
[
v
-
dim_offset
]
=
i
inner_incsubtensor
=
inc_subtensor
(
inner_x
,
y
.
dimshuffle
(
y_order
),
inplace
=
inplace
,
set_instead_of_inc
=
set_instead_of_inc
,
tolerate_inplace_aliasing
=
tolerate_inplace_aliasing
,
ignore_duplicates
=
ignore_duplicates
,
)
# The broadcastable pattern of inner_x may not be the same as
# the one of x, so we have to build a new dimshuffle here,
# instead of reusing x.owner.op().
return
inner_incsubtensor
.
dimshuffle
(
x
.
owner
.
op
.
new_order
)
elif
isinstance
(
x
.
owner
.
op
,
Reshape
):
# This case happens when the indices are not arranged as a vector, but
# as a higher-dimensional array. This is handled by the subtensor
# by flattening this list, taking the subtensor, then reshaping the
# result.
inner_x
=
x
.
owner
.
inputs
[
0
]
# Try to apply inc_subtensor on inner_x.
# If it works, there is no need to reshape, as the inc_subtensor
# will have the same shape as inner_x, which is what we want.
# We also explicitly duplicate y to its broadcasted shape
# before we partially flatten it to inner_x dimension. This is
# not strictly needed in all cases, but it is easier this way.
if
y
.
ndim
>
0
:
# This if is needed to prevent some useless warning about
# old code bug.
expanded_y
=
alloc
(
y
,
*
[
x
.
shape
[
i
]
for
i
in
range
(
x
.
ndim
)])
flattened_y
=
expanded_y
.
reshape
(
inner_x
.
shape
)
else
:
flattened_y
=
y
inner_incsubtensor
=
inc_subtensor
(
inner_x
,
flattened_y
,
inplace
=
inplace
,
set_instead_of_inc
=
set_instead_of_inc
,
tolerate_inplace_aliasing
=
tolerate_inplace_aliasing
,
ignore_duplicates
=
ignore_duplicates
,
)
return
inner_incsubtensor
else
:
raise
TypeError
(
"x must be the result of a subtensor operation"
)
class
IncSubtensor
(
COp
):
"""
Increment a subtensor.
This is like numpy's
x[i,j,k] += y
It is used internally to implement the gradient on SubTensor.
Parameters
----------
set_instead_of_inc
If True set the subtensor to the value instead of incrementing it by
that value.
"""
check_input
=
False
__props__
=
(
"idx_list"
,
"inplace"
,
"set_instead_of_inc"
)
def
__init__
(
self
,
idx_list
,
inplace
=
False
,
set_instead_of_inc
=
False
,
destroyhandler_tolerate_aliased
=
None
,
):
if
destroyhandler_tolerate_aliased
is
None
:
destroyhandler_tolerate_aliased
=
[]
self
.
idx_list
=
list
(
map
(
index_vars_to_types
,
idx_list
))
self
.
inplace
=
inplace
if
inplace
:
self
.
destroy_map
=
{
0
:
[
0
]}
self
.
destroyhandler_tolerate_aliased
=
list
(
destroyhandler_tolerate_aliased
)
self
.
set_instead_of_inc
=
set_instead_of_inc
def
__hash__
(
self
):
idx_list
=
tuple
(
(
entry
.
start
,
entry
.
stop
,
entry
.
step
)
if
isinstance
(
entry
,
slice
)
else
entry
for
entry
in
self
.
idx_list
)
return
hash
((
type
(
self
),
idx_list
,
self
.
inplace
,
self
.
set_instead_of_inc
))
def
__str__
(
self
):
name
=
"SetSubtensor"
if
self
.
set_instead_of_inc
else
"IncSubtensor"
return
f
"{name}{{{Subtensor.str_from_indices(self.idx_list)}}}"
def
make_node
(
self
,
x
,
y
,
*
inputs
):
"""
Parameters
----------
x
The tensor to increment.
y
The value to increment by.
inputs: TODO WRITEME
"""
x
,
y
=
map
(
as_tensor_variable
,
[
x
,
y
])
if
y
.
ndim
>
x
.
ndim
:
raise
ValueError
(
f
"Trying to increment a {int(x.ndim)}-dimensional "
f
"subtensor with a {int(y.ndim)}-dimensional value."
)
inputs
=
tuple
(
map
(
as_nontensor_scalar
,
inputs
))
idx_list
=
list
(
self
.
idx_list
)
if
len
(
idx_list
)
>
x
.
type
.
ndim
:
raise
IndexError
(
"too many indices for array"
)
input_types
=
get_slice_elements
(
idx_list
,
lambda
entry
:
isinstance
(
entry
,
Type
)
)
if
len
(
inputs
)
!=
len
(
input_types
):
raise
IndexError
(
"Not enough inputs to fill in the Subtensor template."
,
inputs
,
idx_list
)
for
input
,
expected_type
in
zip
(
inputs
,
input_types
,
strict
=
True
):
if
not
expected_type
.
is_super
(
input
.
type
):
raise
TypeError
(
f
"Wrong type for Subtensor template. Expected {input.type}, got {expected_type}."
)
return
Apply
(
self
,
(
x
,
y
,
*
inputs
),
[
x
.
type
()])
def
decl_view
(
self
):
return
"PyArrayObject * zview = NULL;"
def
perform
(
self
,
node
,
inputs
,
output_storage
):
x
,
y
,
*
flat_indices
=
inputs
flat_indices_iterator
=
iter
(
flat_indices
)
indices
=
tuple
(
(
next
(
flat_indices_iterator
)
if
isinstance
(
entry
,
int
)
if
isinstance
(
entry
,
Type
)
else
slice
(
None
if
entry
.
start
is
None
else
next
(
flat_indices_iterator
),
None
if
entry
.
stop
is
None
else
next
(
flat_indices_iterator
),
...
...
@@ -1697,18 +1992,17 @@ class IncSubtensor(BaseSubtensor, COp):
return
[
None
]
# Again we ignore eval points for indices because incsubtensor is
# not differentiable wrt to those
_x
,
_y
,
*
index_variables
=
inputs
return
self
(
eval_points
[
0
],
eval_points
[
1
],
*
index_variables
,
return_list
=
True
)
return
self
(
eval_points
[
0
],
eval_points
[
1
],
*
inputs
[
2
:],
return_list
=
True
)
def
connection_pattern
(
self
,
node
):
_x
,
_y
,
*
index_variables
=
node
.
inputs
rval
=
[[
True
],
[
True
],
*
([
False
]
for
_
in
index_variables
)]
rval
=
[[
True
],
[
True
],
*
([
False
]
for
_
in
node
.
inputs
[
2
:])]
return
rval
def
grad
(
self
,
inputs
,
grads
):
(
g_output
,)
=
grads
x
,
y
,
*
index_variables
=
inputs
x
,
y
=
inputs
[:
2
]
idx_list
=
inputs
[
2
:]
if
x
.
dtype
in
discrete_dtypes
:
# The output dtype is the same as x
...
...
@@ -1722,25 +2016,25 @@ class IncSubtensor(BaseSubtensor, COp):
else
:
if
self
.
set_instead_of_inc
:
gx
=
set_subtensor
(
Subtensor
(
idx_list
=
self
.
idx_list
)(
g_output
,
*
i
ndex_variables
),
Subtensor
(
idx_list
=
self
.
idx_list
)(
g_output
,
*
i
dx_list
),
pytensor
.
tensor
.
zeros_like
(
y
),
)
else
:
gx
=
g_output
gy
=
Subtensor
(
idx_list
=
self
.
idx_list
)(
g_output
,
*
i
ndex_variables
)
gy
=
Subtensor
(
idx_list
=
self
.
idx_list
)(
g_output
,
*
i
dx_list
)
gy
=
_sum_grad_over_bcasted_dims
(
y
,
gy
)
return
[
gx
,
gy
,
*
(
disconnected_type
()
for
_
in
range
(
len
(
i
ndex_variables
)))]
return
[
gx
,
gy
,
*
(
disconnected_type
()
for
_
in
range
(
len
(
i
dx_list
)))]
class
IncSubtensorPrinter
(
SubtensorPrinter
):
def
process
(
self
,
r
,
pstate
):
x
,
y
,
*
index_variable
s
=
r
.
owner
.
inputs
x
,
_y
,
*
idx_arg
s
=
r
.
owner
.
inputs
res
=
self
.
_process
(
r
.
owner
.
op
.
idx_list
,
[
x
,
*
i
ndex_variable
s
],
pstate
)
res
=
self
.
_process
(
r
.
owner
.
op
.
idx_list
,
[
x
,
*
i
dx_arg
s
],
pstate
)
with
set_precedence
(
pstate
,
1000
):
y_str
=
pstate
.
pprinter
.
process
(
y
,
pstate
)
y_str
=
pstate
.
pprinter
.
process
(
r
.
owner
.
inputs
[
1
]
,
pstate
)
if
r
.
owner
.
op
.
set_instead_of_inc
:
res
=
f
"set_subtensor({res}, {y_str})"
...
...
@@ -1801,13 +2095,9 @@ class AdvancedSubtensor1(COp):
# sparse_grad doesn't go in here since it only affects the output
# of the grad() method.
__props__
=
()
idx_list
=
(
0
,)
_f16_ok
=
True
check_input
=
False
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
__init__
(
self
,
sparse_grad
=
False
):
self
.
sparse_grad
=
sparse_grad
...
...
@@ -1831,8 +2121,7 @@ class AdvancedSubtensor1(COp):
output_storage
[
0
][
0
]
=
x
.
take
(
i
,
axis
=
0
,
out
=
None
)
def
connection_pattern
(
self
,
node
):
_x
,
*
index_variables
=
node
.
inputs
rval
=
[[
True
],
*
([
False
]
for
_
in
index_variables
)]
rval
=
[[
True
],
*
([
False
]
for
_
in
node
.
inputs
[
1
:])]
return
rval
...
...
@@ -1862,8 +2151,7 @@ class AdvancedSubtensor1(COp):
def
R_op
(
self
,
inputs
,
eval_points
):
if
eval_points
[
0
]
is
None
:
return
[
None
]
_x
,
*
index_variables
=
inputs
return
self
.
make_node
(
eval_points
[
0
],
*
index_variables
)
.
outputs
return
self
.
make_node
(
eval_points
[
0
],
*
inputs
[
1
:])
.
outputs
def
infer_shape
(
self
,
fgraph
,
node
,
ishapes
):
x
,
ilist
=
ishapes
...
...
@@ -1957,17 +2245,13 @@ class AdvancedSubtensor1(COp):
advanced_subtensor1
=
AdvancedSubtensor1
()
class
AdvancedIncSubtensor1
(
BaseSubtensor
,
COp
):
class
AdvancedIncSubtensor1
(
COp
):
"""
Increments a subtensor using advanced slicing (list of index).
"""
__props__
=
(
"inplace"
,
"set_instead_of_inc"
,
)
idx_list
=
(
0
,)
__props__
=
(
"inplace"
,
"set_instead_of_inc"
)
check_input
=
False
params_type
=
ParamsType
(
inplace
=
ps
.
bool
,
set_instead_of_inc
=
ps
.
bool
)
...
...
@@ -1983,20 +2267,8 @@ class AdvancedIncSubtensor1(BaseSubtensor, COp):
if
inplace
:
self
.
destroy_map
=
{
0
:
[
0
]}
def
__hash__
(
self
):
return
hash
(
(
type
(
self
),
self
.
inplace
,
self
.
set_instead_of_inc
,
)
)
def
clone_inplace
(
self
):
return
self
.
__class__
(
inplace
=
True
,
set_instead_of_inc
=
self
.
set_instead_of_inc
,
)
return
self
.
__class__
(
inplace
=
True
,
set_instead_of_inc
=
self
.
set_instead_of_inc
)
def
__str__
(
self
):
if
self
.
inplace
:
...
...
@@ -2222,8 +2494,7 @@ class AdvancedIncSubtensor1(BaseSubtensor, COp):
def
R_op
(
self
,
inputs
,
eval_points
):
if
None
in
eval_points
[:
2
]:
return
[
None
]
_x
,
_y
,
*
index_variables
=
inputs
return
self
.
make_node
(
eval_points
[
0
],
eval_points
[
1
],
*
index_variables
)
.
outputs
return
self
.
make_node
(
eval_points
[
0
],
eval_points
[
1
],
*
inputs
[
2
:])
.
outputs
def
connection_pattern
(
self
,
node
):
rval
=
[[
True
],
[
True
],
[
False
]]
...
...
@@ -2256,8 +2527,15 @@ advanced_inc_subtensor1 = AdvancedIncSubtensor1()
advanced_set_subtensor1
=
AdvancedIncSubtensor1
(
set_instead_of_inc
=
True
)
def
as_tensor_index_variable
(
idx
):
"""Convert index to Variable form for advanced indexing."""
def
as_index_variable
(
idx
):
if
idx
is
None
:
return
NoneConst
.
clone
()
if
isinstance
(
idx
,
slice
):
return
make_slice
(
idx
)
if
isinstance
(
idx
,
Variable
)
and
isinstance
(
idx
.
type
,
SliceType
):
return
idx
if
isinstance
(
idx
,
Variable
)
and
isinstance
(
idx
.
type
,
NoneTypeT
):
return
idx
idx
=
as_tensor_variable
(
idx
)
if
idx
.
type
.
dtype
not
in
discrete_dtypes
:
raise
TypeError
(
"index must be integers or a boolean mask"
)
...
...
@@ -2269,45 +2547,53 @@ def as_tensor_index_variable(idx):
return
idx
class
AdvancedSubtensor
(
BaseSubtensor
,
COp
):
"""Implements NumPy's advanced indexing."""
__props__
=
(
"idx_list"
,)
__hash__
=
BaseSubtensor
.
__hash__
def
c_code_cache_version
(
self
):
hv
=
Subtensor
.
helper_c_code_cache_version
()
if
hv
:
return
(
3
,
hv
)
def
check_advanced_indexing_dimensions
(
input
,
idx_list
):
"""
This function checks if the index list in idx_list is correct.
If there are any boolean masks, we check if the mask has the
same shape as the input. This is enforced in NumPy 0.13.0 and
newer, but not by earlier versions. If the size is not the same,
this method raises an IndexError.
"""
dim_seen
=
0
for
index
in
idx_list
:
if
index
is
np
.
newaxis
:
# skip, does not count as an input dimension
pass
elif
isinstance
(
index
,
np
.
ndarray
)
and
index
.
dtype
==
"bool"
:
for
i
in
range
(
index
.
ndim
):
if
index
.
shape
[
i
]
!=
input
.
shape
[
dim_seen
+
i
]:
raise
IndexError
(
"boolean index did not match indexed array "
f
"along dimension {int(dim_seen + i)}; dimension is "
f
"{int(input.shape[dim_seen + i])} but "
f
"corresponding boolean dimension is {index.shape[i]}"
)
dim_seen
+=
index
.
ndim
else
:
return
()
dim_seen
+=
1
def
make_node
(
self
,
x
,
*
index_variables
):
if
len
(
index_variables
)
!=
self
.
n_index_vars
:
raise
ValueError
(
f
"Expected {self.n_index_vars} inputs, got {len(index_variables)}"
)
x
=
as_tensor_variable
(
x
)
index_variables
=
tuple
(
as_tensor_index_variable
(
a
)
for
a
in
index_variables
)
class
AdvancedSubtensor
(
Op
):
"""Implements NumPy's advanced indexing."""
idx_list
=
self
.
idx_list
if
len
(
idx_list
)
>
x
.
type
.
ndim
:
raise
IndexError
(
"too many indices for array"
)
__props__
=
()
reconstructed_indices
=
unflatten_index_variables
(
index_variables
,
idx_list
)
def
make_node
(
self
,
x
,
*
indices
):
x
=
as_tensor_variable
(
x
)
indices
=
tuple
(
map
(
as_index_variable
,
indices
))
explicit_indices
=
[]
for
idx
in
reconstructed_indices
:
if
isinstance
(
idx
,
slice
):
explicit_indices
.
append
(
idx
)
elif
hasattr
(
idx
,
"dtype"
)
and
idx
.
dtype
==
"bool"
:
new_axes
=
[]
for
idx
in
indices
:
if
isinstance
(
idx
.
type
,
TensorType
)
and
idx
.
dtype
==
"bool"
:
if
idx
.
type
.
ndim
==
0
:
raise
NotImplementedError
(
"Indexing with scalar booleans not supported"
)
axis
=
len
(
explicit_indices
)
# Check static shape aligned
axis
=
len
(
explicit_indices
)
-
len
(
new_axes
)
indexed_shape
=
x
.
type
.
shape
[
axis
:
axis
+
idx
.
type
.
ndim
]
for
j
,
(
indexed_length
,
indexer_length
)
in
enumerate
(
zip
(
indexed_shape
,
idx
.
type
.
shape
)
...
...
@@ -2325,27 +2611,48 @@ class AdvancedSubtensor(BaseSubtensor, COp):
if
isinstance
(
idx
,
Constant
):
nonzero_indices
=
[
tensor_constant
(
i
)
for
i
in
idx
.
data
.
nonzero
()]
else
:
# Note: Sometimes we could infer a shape error by reasoning about the largest possible size of nonzero
# and seeing that other integer indices cannot possible match it
nonzero_indices
=
idx
.
nonzero
()
explicit_indices
.
extend
(
nonzero_indices
)
else
:
if
isinstance
(
idx
.
type
,
NoneTypeT
):
new_axes
.
append
(
len
(
explicit_indices
))
explicit_indices
.
append
(
idx
)
if
len
(
explicit_indices
)
>
x
.
type
.
ndim
:
if
(
len
(
explicit_indices
)
-
len
(
new_axes
)
)
>
x
.
type
.
ndim
:
raise
IndexError
(
f
"too many indices for array: tensor is {x.type.ndim}-dimensional, but {len(explicit_indices)} were indexed"
f
"too many indices for array: tensor is {x.type.ndim}-dimensional, but {len(explicit_indices)
- len(new_axes)
} were indexed"
)
# Perform basic and advanced indexing shape inference separately
(no newaxis)
# Perform basic and advanced indexing shape inference separately
basic_group_shape
=
[]
advanced_indices
=
[]
adv_group_axis
=
None
last_adv_group_axis
=
None
if
new_axes
:
expanded_x_shape_list
=
list
(
x
.
type
.
shape
)
for
new_axis
in
new_axes
:
expanded_x_shape_list
.
insert
(
new_axis
,
1
)
expanded_x_shape
=
tuple
(
expanded_x_shape_list
)
else
:
expanded_x_shape
=
x
.
type
.
shape
for
i
,
(
idx
,
dim_length
)
in
enumerate
(
zip_longest
(
explicit_indices
,
x
.
type
.
shape
,
fillvalue
=
slice
(
None
)
)
zip_longest
(
explicit_indices
,
expanded_x_shape
,
fillvalue
=
NoneSliceConst
)
):
if
isinstance
(
idx
,
slice
):
basic_group_shape
.
append
(
slice_static_length
(
idx
,
dim_length
))
else
:
# TensorType (advanced index)
if
isinstance
(
idx
.
type
,
NoneTypeT
):
basic_group_shape
.
append
(
1
)
# New-axis
elif
isinstance
(
idx
.
type
,
SliceType
):
if
isinstance
(
idx
,
Constant
):
basic_group_shape
.
append
(
slice_static_length
(
idx
.
data
,
dim_length
))
elif
idx
.
owner
is
not
None
and
isinstance
(
idx
.
owner
.
op
,
MakeSlice
):
basic_group_shape
.
append
(
slice_static_length
(
slice
(
*
idx
.
owner
.
inputs
),
dim_length
)
)
else
:
# Symbolic root slice (owner is None), or slice operation we don't understand
basic_group_shape
.
append
(
None
)
else
:
# TensorType
# Keep track of advanced group axis
if
adv_group_axis
is
None
:
# First time we see an advanced index
...
...
@@ -2380,15 +2687,14 @@ class AdvancedSubtensor(BaseSubtensor, COp):
return
Apply
(
self
,
[
x
,
*
ind
ex_variabl
es
],
[
x
,
*
ind
ic
es
],
[
tensor
(
dtype
=
x
.
type
.
dtype
,
shape
=
tuple
(
indexed_shape
))],
)
def
R_op
(
self
,
inputs
,
eval_points
):
if
eval_points
[
0
]
is
None
:
return
[
None
]
_x
,
*
index_variables
=
inputs
return
self
.
make_node
(
eval_points
[
0
],
*
index_variables
)
.
outputs
return
self
.
make_node
(
eval_points
[
0
],
*
inputs
[
1
:])
.
outputs
def
infer_shape
(
self
,
fgraph
,
node
,
ishapes
):
def
is_bool_index
(
idx
):
...
...
@@ -2397,32 +2703,30 @@ class AdvancedSubtensor(BaseSubtensor, COp):
or
getattr
(
idx
,
"dtype"
,
None
)
==
"bool"
)
_x
,
*
index_variables
=
node
.
inputs
full_indices
=
unflatten_index_variables
(
index_variables
,
self
.
idx_list
)
indices
=
node
.
inputs
[
1
:]
index_shapes
=
[]
for
idx
in
full_indices
:
if
isinstance
(
idx
,
slice
):
for
idx
,
ishape
in
zip
(
indices
,
ishapes
[
1
:],
strict
=
True
):
# Mixed bool indexes are converted to nonzero entries
shape0_op
=
Shape_i
(
0
)
if
is_bool_index
(
idx
):
index_shapes
.
extend
((
shape0_op
(
nz_dim
),)
for
nz_dim
in
nonzero
(
idx
))
# The `ishapes` entries for `SliceType`s will be None, and
# we need to give `indexed_result_shape` the actual slices.
elif
isinstance
(
getattr
(
idx
,
"type"
,
None
),
SliceType
):
index_shapes
.
append
(
idx
)
else
:
shape0_op
=
Shape_i
(
0
)
if
is_bool_index
(
idx
):
index_shapes
.
extend
((
shape0_op
(
nz_dim
),)
for
nz_dim
in
nonzero
(
idx
))
else
:
input_shape_idx
=
(
index_variables
.
index
(
idx
)
+
1
)
# +1 because ishapes[0] is x
index_shapes
.
append
(
ishapes
[
input_shape_idx
])
index_shapes
.
append
(
ishape
)
res_shape
=
list
(
indexed_result_shape
(
ishapes
[
0
],
index_shapes
,
indices_are_shapes
=
True
)
)
for
i
,
res_dim_length
in
enumerate
(
res_shape
):
if
res_dim_length
is
None
:
# This can happen when we have a Slice provided by the user (not a constant nor the result of MakeSlice)
# We must compute the Op to find its shape
res_shape
[
i
]
=
Shape_i
(
i
)(
node
.
out
)
adv_indices
=
[
idx
for
idx
in
full_indices
if
not
isinstance
(
idx
,
slice
)]
adv_indices
=
[
idx
for
idx
in
indices
if
not
is_basic_idx
(
idx
)]
bool_indices
=
[
idx
for
idx
in
adv_indices
if
is_bool_index
(
idx
)]
# Special logic when the only advanced index group is of bool type.
...
...
@@ -2433,7 +2737,7 @@ class AdvancedSubtensor(BaseSubtensor, COp):
# Because there are no more advanced index groups, there is exactly
# one output dim per index variable up to the bool group.
# Note: Scalar integer indexing counts as advanced indexing.
start_dim
=
full_
indices
.
index
(
bool_index
)
start_dim
=
indices
.
index
(
bool_index
)
res_shape
[
start_dim
]
=
bool_index
.
sum
()
assert
node
.
outputs
[
0
]
.
ndim
==
len
(
res_shape
)
...
...
@@ -2441,31 +2745,25 @@ class AdvancedSubtensor(BaseSubtensor, COp):
def
perform
(
self
,
node
,
inputs
,
out_
):
(
out
,)
=
out_
x
,
*
index_variables
=
inputs
full_indices
=
unflatten_index_variables
(
index_variables
,
self
.
idx_list
)
rval
=
x
.
__getitem__
(
tuple
(
full_indices
))
check_advanced_indexing_dimensions
(
inputs
[
0
],
inputs
[
1
:])
rval
=
inputs
[
0
]
.
__getitem__
(
tuple
(
inputs
[
1
:]))
# When there are no arrays, we are not actually doing advanced
# indexing, so __getitem__ will not return a copy.
# Since no view_map is set, we need to copy the returned value
if
not
any
(
isinstance
(
idx
,
np
.
ndarray
)
and
idx
.
ndim
>
0
for
idx
in
full_indices
isinstance
(
v
.
type
,
TensorType
)
and
v
.
ndim
>
0
for
v
in
node
.
inputs
[
1
:]
):
rval
=
rval
.
copy
()
out
[
0
]
=
rval
def
connection_pattern
(
self
,
node
):
_x
,
*
index_variables
=
node
.
inputs
rval
=
[[
True
],
*
([
False
]
for
_
in
index_variables
)]
rval
=
[[
True
],
*
([
False
]
for
_
in
node
.
inputs
[
1
:])]
return
rval
def
grad
(
self
,
inputs
,
grads
):
(
gz
,)
=
grads
x
,
*
index_variables
=
inputs
x
=
inputs
[
0
]
if
x
.
dtype
in
discrete_dtypes
:
# The output dtype is the same as x
gx
=
x
.
zeros_like
(
dtype
=
config
.
floatX
)
...
...
@@ -2473,10 +2771,10 @@ class AdvancedSubtensor(BaseSubtensor, COp):
raise
NotImplementedError
(
"No support for complex grad yet"
)
else
:
gx
=
x
.
zeros_like
()
rest
=
inputs
[
1
:]
return
[
AdvancedIncSubtensor
(
self
.
idx_list
)(
gx
,
gz
,
*
index_variables
),
*
(
disconnected_type
()
for
_
in
range
(
len
(
index_variables
))),
advanced_inc_subtensor
(
gx
,
gz
,
*
rest
),
*
(
disconnected_type
()
for
_
in
range
(
len
(
rest
))),
]
@staticmethod
...
...
@@ -2493,7 +2791,7 @@ class AdvancedSubtensor(BaseSubtensor, COp):
This function checks if the advanced indexing is non-consecutive,
in which case the advanced index dimensions are placed on the left of the
output array, regardless of their original position.
output array, regardless of their o
p
riginal position.
See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
...
...
@@ -2508,21 +2806,11 @@ class AdvancedSubtensor(BaseSubtensor, COp):
bool
True if the advanced indexing is non-consecutive, False otherwise.
"""
indices
=
indices_from_subtensor
(
node
.
inputs
[
1
:],
node
.
op
.
idx_list
)
return
_non_consecutive_adv_indexing
(
i
ndice
s
)
_
,
*
idxs
=
node
.
inputs
return
_non_consecutive_adv_indexing
(
i
dx
s
)
class
AdvancedSubtensorPrinter
(
SubtensorPrinter
):
def
process
(
self
,
r
,
pstate
):
return
self
.
_process
(
r
.
owner
.
op
.
idx_list
,
r
.
owner
.
inputs
,
pstate
)
pprint
.
assign
(
AdvancedSubtensor
,
AdvancedSubtensorPrinter
())
def
advanced_subtensor
(
x
,
*
index_variables
):
idx_list
,
flat_index_vars
=
flatten_index_variables
(
index_variables
)
return
AdvancedSubtensor
(
idx_list
)(
x
,
*
flat_index_vars
)
advanced_subtensor
=
AdvancedSubtensor
()
@_vectorize_node.register
(
AdvancedSubtensor
)
...
...
@@ -2542,33 +2830,30 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs):
# which would put the indexed results to the left of the batch dimensions!
# TODO: Not all cases must be handled by Blockwise, but the logic is complex
return
vectorize_node_fallback
(
op
,
node
,
batch_x
,
*
batch_idxs
)
# Blockwise doesn't accept None or Slices types so we raise informative error here
# TODO: Implement these internally, so Blockwise is always a safe fallback
if
any
(
not
isinstance
(
idx
,
TensorVariable
)
for
idx
in
idxs
):
raise
NotImplementedError
(
"Vectorized AdvancedSubtensor with batched indexes or non-consecutive advanced indexing "
"and slices or newaxis is currently not supported."
)
else
:
return
vectorize_node_fallback
(
op
,
node
,
batch_x
,
*
batch_idxs
)
# Otherwise we just need to add None slices for every new batch dim
x_batch_ndim
=
batch_x
.
type
.
ndim
-
x
.
type
.
ndim
new_idx_list
=
(
slice
(
None
),)
*
x_batch_ndim
+
op
.
idx_list
return
type
(
op
)(
new_idx_list
)
.
make_node
(
batch_x
,
*
batch_idxs
)
empty_slices
=
(
slice
(
None
),)
*
x_batch_ndim
return
op
.
make_node
(
batch_x
,
*
empty_slices
,
*
batch_idxs
)
class
AdvancedIncSubtensor
(
BaseSubtensor
,
Op
):
class
AdvancedIncSubtensor
(
Op
):
"""Increments a subtensor using advanced indexing."""
__props__
=
(
"idx_list"
,
"inplace"
,
"set_instead_of_inc"
,
"ignore_duplicates"
,
)
__hash__
=
BaseSubtensor
.
__hash__
__props__
=
(
"inplace"
,
"set_instead_of_inc"
,
"ignore_duplicates"
)
def
__init__
(
self
,
idx_list
,
inplace
=
False
,
set_instead_of_inc
=
False
,
ignore_duplicates
=
False
,
self
,
inplace
=
False
,
set_instead_of_inc
=
False
,
ignore_duplicates
=
False
):
super
()
.
__init__
(
idx_list
)
self
.
set_instead_of_inc
=
set_instead_of_inc
self
.
inplace
=
inplace
if
inplace
:
...
...
@@ -2582,27 +2867,25 @@ class AdvancedIncSubtensor(BaseSubtensor, Op):
else
"AdvancedIncSubtensor"
)
def
make_node
(
self
,
x
,
y
,
*
index_variables
):
if
len
(
index_variables
)
!=
self
.
n_index_vars
:
raise
ValueError
(
f
"Expected {self.n_index_vars} tensor inputs but got {len(index_variables)}"
)
index_variables
=
tuple
(
as_tensor_index_variable
(
idx
)
for
idx
in
index_variables
)
def
make_node
(
self
,
x
,
y
,
*
inputs
):
x
=
as_tensor_variable
(
x
)
y
=
as_tensor_variable
(
y
)
new_inputs
=
[]
for
inp
in
inputs
:
if
isinstance
(
inp
,
list
|
tuple
):
inp
=
as_tensor_variable
(
inp
)
new_inputs
.
append
(
inp
)
return
Apply
(
self
,
[
x
,
y
,
*
index_variables
]
,
(
x
,
y
,
*
new_inputs
)
,
[
x
.
type
()],
)
def
perform
(
self
,
node
,
inputs
,
out_
):
x
,
y
,
*
ind
ex_variabl
es
=
inputs
x
,
y
,
*
ind
ic
es
=
inputs
full_indices
=
unflatten_index_variables
(
index_variables
,
self
.
idx_list
)
check_advanced_indexing_dimensions
(
x
,
indices
)
(
out
,)
=
out_
if
not
self
.
inplace
:
...
...
@@ -2611,29 +2894,28 @@ class AdvancedIncSubtensor(BaseSubtensor, Op):
out
[
0
]
=
x
if
self
.
set_instead_of_inc
:
out
[
0
][
tuple
(
full_
indices
)]
=
y
out
[
0
][
tuple
(
indices
)]
=
y
elif
self
.
ignore_duplicates
:
out
[
0
][
tuple
(
full_
indices
)]
+=
y
out
[
0
][
tuple
(
indices
)]
+=
y
else
:
np
.
add
.
at
(
out
[
0
],
tuple
(
full_
indices
),
y
)
np
.
add
.
at
(
out
[
0
],
tuple
(
indices
),
y
)
def
infer_shape
(
self
,
fgraph
,
node
,
ishapes
):
return
[
ishapes
[
0
]]
def
connection_pattern
(
self
,
node
):
_x
,
_y
,
*
index_variables
=
node
.
inputs
rval
=
[[
True
],
[
True
],
*
([
False
]
for
_
in
index_variables
)]
rval
=
[[
True
],
[
True
],
*
([
False
]
for
_
in
node
.
inputs
[
2
:])]
return
rval
def
R_op
(
self
,
inputs
,
eval_points
):
if
None
in
eval_points
[:
2
]:
return
[
None
]
_x
,
_y
,
*
index_variables
=
inputs
return
self
.
make_node
(
eval_points
[
0
],
eval_points
[
1
],
*
index_variables
)
.
outputs
return
self
.
make_node
(
eval_points
[
0
],
eval_points
[
1
],
*
inputs
[
2
:])
.
outputs
def
grad
(
self
,
inpt
,
output_gradients
):
x
,
y
,
*
index_variables
=
inpt
x
,
y
=
inpt
[:
2
]
idxs
=
inpt
[
2
:]
(
outgrad
,)
=
output_gradients
if
x
.
dtype
in
discrete_dtypes
:
# The output dtype is the same as x
...
...
@@ -2646,22 +2928,21 @@ class AdvancedIncSubtensor(BaseSubtensor, Op):
raise
NotImplementedError
(
"No support for complex grad yet"
)
else
:
if
self
.
set_instead_of_inc
:
gx
=
(
type
(
self
)(
self
.
idx_list
,
set_instead_of_inc
=
True
)
.
make_node
(
outgrad
,
y
.
zeros_like
(),
*
index_variables
)
.
outputs
[
0
]
)
gx
=
advanced_set_subtensor
(
outgrad
,
y
.
zeros_like
(),
*
idxs
)
else
:
gx
=
outgrad
gy
=
(
AdvancedSubtensor
(
self
.
idx_list
)
.
make_node
(
outgrad
,
*
index_variables
)
.
outputs
[
0
]
)
gy
=
advanced_subtensor
(
outgrad
,
*
idxs
)
# Make sure to sum gy over the dimensions of y that have been
# added or broadcasted
gy
=
_sum_grad_over_bcasted_dims
(
y
,
gy
)
return
[
gx
,
gy
,
*
(
disconnected_type
()
for
_
in
range
(
len
(
index_variables
)))]
return
[
gx
,
gy
,
*
(
disconnected_type
()
for
_
in
range
(
len
(
idxs
)))]
@staticmethod
def
non_contiguous_adv_indexing
(
node
:
Apply
)
->
bool
:
warnings
.
warn
(
"Method was renamed to `non_consecutive_adv_indexing`"
,
FutureWarning
)
return
AdvancedIncSubtensor
.
non_consecutive_adv_indexing
(
node
)
@staticmethod
def
non_consecutive_adv_indexing
(
node
:
Apply
)
->
bool
:
...
...
@@ -2670,7 +2951,7 @@ class AdvancedIncSubtensor(BaseSubtensor, Op):
This function checks if the advanced indexing is non-consecutive,
in which case the advanced index dimensions are placed on the left of the
output array, regardless of their original position.
output array, regardless of their o
p
riginal position.
See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
...
...
@@ -2685,257 +2966,16 @@ class AdvancedIncSubtensor(BaseSubtensor, Op):
bool
True if the advanced indexing is non-consecutive, False otherwise.
"""
indices
=
indices_from_subtensor
(
node
.
inputs
[
2
:],
node
.
op
.
idx_list
)
return
_non_consecutive_adv_indexing
(
indices
)
def
advanced_inc_subtensor
(
x
,
y
,
*
args
,
**
kwargs
):
idx_list
,
flat_index_vars
=
flatten_index_variables
(
args
)
return
AdvancedIncSubtensor
(
idx_list
,
**
kwargs
)(
x
,
y
,
*
flat_index_vars
)
_
,
_
,
*
idxs
=
node
.
inputs
return
_non_consecutive_adv_indexing
(
idxs
)
def
advanced_set_subtensor
(
x
,
y
,
*
args
,
**
kwargs
):
return
advanced_inc_subtensor
(
x
,
y
,
*
args
,
set_instead_of_inc
=
True
,
**
kwargs
)
class
AdvancedIncSubtensorPrinter
(
SubtensorPrinter
):
def
process
(
self
,
r
,
pstate
):
x
,
y
,
*
index_variables
=
r
.
owner
.
inputs
res
=
self
.
_process
(
r
.
owner
.
op
.
idx_list
,
[
x
,
*
index_variables
],
pstate
)
with
set_precedence
(
pstate
,
1000
):
y_str
=
pstate
.
pprinter
.
process
(
y
,
pstate
)
if
r
.
owner
.
op
.
set_instead_of_inc
:
res
=
f
"set_subtensor({res}, {y_str})"
else
:
res
=
f
"inc_subtensor({res}, {y_str})"
return
res
pprint
.
assign
(
AdvancedIncSubtensor
,
AdvancedIncSubtensorPrinter
())
def
set_subtensor
(
x
,
y
,
inplace
=
False
,
tolerate_inplace_aliasing
=
False
):
"""
Return x with the given subtensor overwritten by y.
Parameters
----------
x
Symbolic variable for the lvalue of = operation.
y
Symbolic variable for the rvalue of = operation.
tolerate_inplace_aliasing
See inc_subtensor for documentation.
Examples
--------
To replicate the numpy expression ``r[10:] = 5``, type
.. code-block:: python
from pytensor.tensor import set_subtensor, vector
r = vector("r")
new_r = set_subtensor(r[10:], 5)
Consider using :meth:`pytensor.tensor.variable.TensorVariable.set` instead.
"""
return
inc_subtensor
(
x
,
y
,
inplace
,
set_instead_of_inc
=
True
,
tolerate_inplace_aliasing
=
tolerate_inplace_aliasing
,
)
def
inc_subtensor
(
x
,
y
,
inplace
=
False
,
set_instead_of_inc
=
False
,
tolerate_inplace_aliasing
=
False
,
ignore_duplicates
=
False
,
):
"""Update the value of an indexed array by a given amount.
This is equivalent to ``x[indices] += y`` or ``np.add.at(x, indices, y)``,
depending on the value of `ignore_duplicates`.
Parameters
----------
x
The symbolic result of a Subtensor operation.
y
The amount by which to increment the array.
inplace
Don't use. PyTensor will do in-place operations itself, when possible.
set_instead_of_inc
If True, do a set_subtensor instead.
tolerate_inplace_aliasing:
Allow `x` and `y` to be views of a single underlying array even while
working in-place. For correct results, `x` and `y` must not be overlapping
views; if they overlap, the result of this `Op` will generally be
incorrect. This value has no effect if ``inplace=False``.
ignore_duplicates
This determines whether ``x[indices] += y`` is used or
``np.add.at(x, indices, y)``.
Examples
--------
To replicate the expression ``r[10:] += 5``:
.. code-block:: python
from pytensor.tensor import ivector, inc_subtensor
r = ivector("r")
new_r = inc_subtensor(r[10:], 5)
To replicate the expression ``r[[0, 1, 0]] += 5``:
.. code-block:: python
r = ivector("r")
new_r = inc_subtensor(r[[0, 1, 0]], 5, ignore_duplicates=True)
Consider using :meth:`pytensor.tensor.variable.TensorVariable.inc` instead.
"""
# First of all, y cannot have a higher dimension than x,
# nor have non-broadcastable dimensions where x is broadcastable.
x
=
as_tensor_variable
(
x
)
y
=
as_tensor_variable
(
y
)
if
y
.
ndim
>
x
.
ndim
:
raise
TypeError
(
f
"Trying to increment a {int(x.ndim)}-dimensional "
f
"subtensor with a {int(y.ndim)}-dimensional value."
)
dim_offset
=
x
.
ndim
-
y
.
ndim
for
dim
in
range
(
y
.
ndim
):
if
x
.
broadcastable
[
dim
+
dim_offset
]
and
not
y
.
broadcastable
[
dim
]:
# It is acceptable to try to increment a subtensor with a
# broadcastable dim with a tensor that is not broadcastable
# on that dimension. However, its length must then be 1.
# We insert a SpecifyShape Op to make sure it is the case.
y
=
specify_broadcastable
(
y
,
dim
)
if
x
.
owner
is
None
:
raise
TypeError
(
"x must be the result of a subtensor operation"
)
# retrieve idx_list from x.owner
if
isinstance
(
x
.
owner
.
op
,
Subtensor
):
if
tolerate_inplace_aliasing
:
destroyhandler_tolerate_aliased
=
[[
0
,
1
]]
else
:
destroyhandler_tolerate_aliased
=
[]
the_op
=
IncSubtensor
(
x
.
owner
.
op
.
idx_list
,
inplace
,
set_instead_of_inc
,
destroyhandler_tolerate_aliased
=
destroyhandler_tolerate_aliased
,
)
real_x
,
*
index_variables
=
x
.
owner
.
inputs
return
the_op
(
real_x
,
y
,
*
index_variables
)
elif
isinstance
(
x
.
owner
.
op
,
AdvancedSubtensor1
):
real_x
=
x
.
owner
.
inputs
[
0
]
ilist
=
x
.
owner
.
inputs
[
1
]
if
ignore_duplicates
:
the_op
=
AdvancedIncSubtensor
(
(
0
,),
inplace
,
set_instead_of_inc
=
set_instead_of_inc
,
ignore_duplicates
=
True
,
)
else
:
the_op
=
AdvancedIncSubtensor1
(
inplace
,
set_instead_of_inc
=
set_instead_of_inc
)
return
the_op
(
real_x
,
y
,
ilist
)
elif
isinstance
(
x
.
owner
.
op
,
AdvancedSubtensor
):
real_x
,
*
index_variables
=
x
.
owner
.
inputs
the_op
=
AdvancedIncSubtensor
(
x
.
owner
.
op
.
idx_list
,
inplace
,
set_instead_of_inc
=
set_instead_of_inc
,
ignore_duplicates
=
ignore_duplicates
,
)
return
the_op
(
real_x
,
y
,
*
index_variables
)
elif
isinstance
(
x
.
owner
.
op
,
DimShuffle
):
inner_x
=
x
.
owner
.
inputs
[
0
]
# In the dimshuffle case, there are in fact two dimshuffles:
# one to make the indexed dimension the last one,
# and one to put it back where it was. So, in the case where we have
# inc_subtensor(x[:,i], y), the graph is actually
# inc_subtensor((x.T)[i].T, y).
# We could get all the way to x, and then get rid of the dimshuffles
# completely, but the problem is that advanced_inc_subtensor1 can only
# work on the first (outer-most, left-most) dimension of x,
# just like advanced_subtensor1.
# So we call advanced_inc_subtensor1(x.T, i, y.T) (as we also need to
# transpose y if it is not a scalar or a vector), but then we need to
# return something that has the same shape as x, not as x.T (inner_x).
# So re-apply the outer dimshuffle on the new inc_subtensor,
# and return advanced_inc_subtensor1(x.T, i, y.T).T.
# Get the dimshuffle pattern to apply to y.
x_order
=
x
.
owner
.
op
.
new_order
y_order
=
[
"x"
]
*
x
.
ndim
for
i
,
v
in
enumerate
(
x_order
):
if
v
!=
"x"
and
(
v
-
dim_offset
)
>=
0
:
y_order
[
v
-
dim_offset
]
=
i
inner_incsubtensor
=
inc_subtensor
(
inner_x
,
y
.
dimshuffle
(
y_order
),
inplace
=
inplace
,
set_instead_of_inc
=
set_instead_of_inc
,
tolerate_inplace_aliasing
=
tolerate_inplace_aliasing
,
ignore_duplicates
=
ignore_duplicates
,
)
# The broadcastable pattern of inner_x may not be the same as
# the one of x, so we have to build a new dimshuffle here,
# instead of reusing x.owner.op().
return
inner_incsubtensor
.
dimshuffle
(
x
.
owner
.
op
.
new_order
)
elif
isinstance
(
x
.
owner
.
op
,
Reshape
):
# This case happens when the indices are not arranged as a vector, but
# as a higher-dimensional array. This is handled by the subtensor
# by flattening this list, taking the subtensor, then reshaping the
# result.
inner_x
=
x
.
owner
.
inputs
[
0
]
# Try to apply inc_subtensor on inner_x.
# If it works, there is no need to reshape, as the inc_subtensor
# will have the same shape as inner_x, which is what we want.
# We also explicitly duplicate y to its broadcasted shape
# before we partially flatten it to inner_x dimension. This is
# not strictly needed in all cases, but it is easier this way.
if
y
.
ndim
>
0
:
# This if is needed to prevent some useless warning about
# old code bug.
expanded_y
=
alloc
(
y
,
*
[
x
.
shape
[
i
]
for
i
in
range
(
x
.
ndim
)])
flattened_y
=
expanded_y
.
reshape
(
inner_x
.
shape
)
else
:
flattened_y
=
y
inner_incsubtensor
=
inc_subtensor
(
inner_x
,
flattened_y
,
inplace
=
inplace
,
set_instead_of_inc
=
set_instead_of_inc
,
tolerate_inplace_aliasing
=
tolerate_inplace_aliasing
,
ignore_duplicates
=
ignore_duplicates
,
)
return
inner_incsubtensor
else
:
raise
TypeError
(
"x must be the result of a subtensor operation"
)
advanced_inc_subtensor
=
AdvancedIncSubtensor
()
advanced_set_subtensor
=
AdvancedIncSubtensor
(
set_instead_of_inc
=
True
)
advanced_inc_subtensor_nodup
=
AdvancedIncSubtensor
(
ignore_duplicates
=
True
)
advanced_set_subtensor_nodup
=
AdvancedIncSubtensor
(
set_instead_of_inc
=
True
,
ignore_duplicates
=
True
)
def
take
(
a
,
indices
,
axis
=
None
,
mode
=
"raise"
):
...
...
@@ -2981,6 +3021,39 @@ def take(a, indices, axis=None, mode="raise"):
return
a
[
full_indices
]
@_get_vector_length.register
(
Subtensor
)
# type: ignore
def
_get_vector_length_Subtensor
(
op
,
var
):
# If we take a slice, we know how many elements it will result in
# TODO: We can cover more `*Subtensor` cases.
try
:
indices
=
pytensor
.
tensor
.
subtensor
.
get_idx_list
(
var
.
owner
.
inputs
,
var
.
owner
.
op
.
idx_list
)
start
=
(
None
if
indices
[
0
]
.
start
is
None
else
get_scalar_constant_value
(
indices
[
0
]
.
start
)
)
stop
=
(
None
if
indices
[
0
]
.
stop
is
None
else
get_scalar_constant_value
(
indices
[
0
]
.
stop
)
)
step
=
(
None
if
indices
[
0
]
.
step
is
None
else
get_scalar_constant_value
(
indices
[
0
]
.
step
)
)
if
start
==
stop
:
return
0
arg_len
=
get_vector_length
(
var
.
owner
.
inputs
[
0
])
return
len
(
range
(
*
slice
(
start
,
stop
,
step
)
.
indices
(
arg_len
)))
except
(
ValueError
,
NotScalarConstantError
):
raise
ValueError
(
f
"Length of {var} cannot be determined"
)
def
slice_at_axis
(
sl
:
slice
,
axis
:
int
)
->
tuple
[
slice
,
...
]:
"""
Construct tuple of slices to slice an array in the given dimension.
...
...
pytensor/tensor/variable.py
浏览文件 @
cc6bed1a
...
...
@@ -15,8 +15,9 @@ from pytensor.scalar import (
ComplexError
,
)
from
pytensor.tensor
import
_get_vector_length
from
pytensor.tensor.exceptions
import
AdvancedIndexingError
from
pytensor.tensor.type
import
TensorType
from
pytensor.tensor.type_other
import
None
TypeT
from
pytensor.tensor.type_other
import
None
Const
from
pytensor.tensor.utils
import
hash_from_ndarray
...
...
@@ -454,14 +455,15 @@ class _tensor_py_operators:
elif
not
isinstance
(
args
,
tuple
):
args
=
(
args
,)
# Count the dimensions, check for bools and find ellipses.
ellipses
=
[]
index_dim_count
=
0
for
i
,
arg
in
enumerate
(
args
):
if
arg
is
None
or
(
isinstance
(
arg
,
Variable
)
and
isinstance
(
arg
.
type
,
NoneTypeT
)
):
if
arg
is
np
.
newaxis
or
arg
is
NoneConst
:
# no increase in index_dim_count
pass
elif
arg
is
Ellipsis
:
# no increase in index_dim_count
ellipses
.
append
(
i
)
elif
(
isinstance
(
arg
,
np
.
ndarray
|
Variable
)
...
...
@@ -503,41 +505,6 @@ class _tensor_py_operators:
self
.
ndim
-
index_dim_count
)
if
any
(
arg
is
None
or
(
isinstance
(
arg
,
Variable
)
and
isinstance
(
arg
.
type
,
NoneTypeT
))
for
arg
in
args
):
expansion_axes
=
[]
new_args
=
[]
# Track dims consumed by args and inserted `None`s after ellipsis
counter
=
0
nones
=
0
for
arg
in
args
:
if
arg
is
None
or
(
isinstance
(
arg
,
Variable
)
and
isinstance
(
arg
.
type
,
NoneTypeT
)
):
expansion_axes
.
append
(
counter
+
nones
)
# Expand here
nones
+=
1
new_args
.
append
(
slice
(
None
))
else
:
new_args
.
append
(
arg
)
consumed
=
1
if
hasattr
(
arg
,
"dtype"
)
and
arg
.
dtype
==
"bool"
:
consumed
=
arg
.
ndim
counter
+=
consumed
expanded
=
pt
.
expand_dims
(
self
,
expansion_axes
)
if
all
(
isinstance
(
arg
,
slice
)
and
arg
.
start
is
None
and
arg
.
stop
is
None
and
arg
.
step
is
None
for
arg
in
new_args
):
return
expanded
return
expanded
[
tuple
(
new_args
)]
def
is_empty_array
(
val
):
return
(
isinstance
(
val
,
tuple
|
list
)
and
len
(
val
)
==
0
)
or
(
isinstance
(
val
,
np
.
ndarray
)
and
val
.
size
==
0
...
...
@@ -553,16 +520,74 @@ class _tensor_py_operators:
for
inp
in
args
)
if
all
(
(
isinstance
(
arg
,
slice
|
int
|
float
|
np
.
number
)
or
(
hasattr
(
arg
,
"ndim"
)
and
arg
.
ndim
==
0
and
arg
.
dtype
!=
"bool"
)
)
for
arg
in
args
):
return
pt
.
subtensor
.
basic_subtensor
(
self
,
*
args
)
else
:
# Determine if advanced indexing is needed or not. The logic is
# already in `index_vars_to_types`: if it succeeds, standard indexing is
# used; if it fails with `AdvancedIndexingError`, advanced indexing is
# used
advanced
=
False
for
i
,
arg
in
enumerate
(
args
):
if
includes_bool
(
arg
):
advanced
=
True
break
if
arg
is
not
np
.
newaxis
and
arg
is
not
NoneConst
:
try
:
pt
.
subtensor
.
index_vars_to_types
(
arg
)
except
AdvancedIndexingError
:
if
advanced
:
break
else
:
advanced
=
True
if
advanced
:
return
pt
.
subtensor
.
advanced_subtensor
(
self
,
*
args
)
else
:
if
np
.
newaxis
in
args
or
NoneConst
in
args
:
# `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new
# broadcastable dimension at this location". Since PyTensor adds
# new broadcastable dimensions via the `DimShuffle` `Op`, the
# following code uses said `Op` to add one of the new axes and
# then uses recursion to apply any other indices and add any
# remaining new axes.
counter
=
0
pattern
=
[]
new_args
=
[]
for
arg
in
args
:
if
arg
is
np
.
newaxis
or
arg
is
NoneConst
:
pattern
.
append
(
"x"
)
new_args
.
append
(
slice
(
None
,
None
,
None
))
else
:
pattern
.
append
(
counter
)
counter
+=
1
new_args
.
append
(
arg
)
pattern
.
extend
(
list
(
range
(
counter
,
self
.
ndim
)))
view
=
self
.
dimshuffle
(
pattern
)
full_slices
=
True
for
arg
in
new_args
:
# We can't do arg == slice(None, None, None) as in
# Python 2.7, this call __lt__ if we have a slice
# with some symbolic variable.
if
not
(
isinstance
(
arg
,
slice
)
and
(
arg
.
start
is
None
or
arg
.
start
is
NoneConst
)
and
(
arg
.
stop
is
None
or
arg
.
stop
is
NoneConst
)
and
(
arg
.
step
is
None
or
arg
.
step
is
NoneConst
)
):
full_slices
=
False
if
full_slices
:
return
view
else
:
return
view
.
__getitem__
(
tuple
(
new_args
))
else
:
return
pt
.
subtensor
.
Subtensor
(
args
)(
self
,
*
pt
.
subtensor
.
get_slice_elements
(
args
,
lambda
entry
:
isinstance
(
entry
,
Variable
)
),
)
def
__setitem__
(
self
,
key
,
value
):
raise
TypeError
(
...
...
pytensor/xtensor/rewriting/indexing.py
浏览文件 @
cc6bed1a
...
...
@@ -2,10 +2,9 @@ from itertools import zip_longest
from
pytensor
import
as_symbolic
from
pytensor.graph
import
Constant
,
node_rewriter
from
pytensor.tensor
import
arange
,
specify_shape
from
pytensor.tensor
import
TensorType
,
arange
,
specify_shape
from
pytensor.tensor.subtensor
import
_non_consecutive_adv_indexing
,
inc_subtensor
from
pytensor.tensor.type_other
import
NoneTypeT
,
SliceType
from
pytensor.tensor.variable
import
TensorVariable
from
pytensor.xtensor.basic
import
tensor_from_xtensor
,
xtensor_from_tensor
from
pytensor.xtensor.indexing
import
Index
,
IndexUpdate
,
index
from
pytensor.xtensor.rewriting.utils
import
register_lower_xtensor
...
...
@@ -107,7 +106,7 @@ def _lower_index(node):
# We can use basic indexing directly if no other index acts on this dimension
# This is an optimization that avoids creating an unnecessary arange tensor
# and facilitates the use of the specialized AdvancedSubtensor1 when possible
aligned_idxs
.
append
(
to_basic_idx
(
idx
)
)
aligned_idxs
.
append
(
idx
)
basic_idx_axis
.
append
(
out_dims
.
index
(
x_dim
))
else
:
# Otherwise we need to convert the basic index into an equivalent advanced indexing
...
...
@@ -132,7 +131,7 @@ def _lower_index(node):
if
basic_idx_axis
:
aligned_idxs
=
[
idx
.
squeeze
(
axis
=
basic_idx_axis
)
if
(
isinstance
(
idx
,
TensorVariabl
e
)
and
idx
.
type
.
ndim
>
0
)
if
(
isinstance
(
idx
.
type
,
TensorTyp
e
)
and
idx
.
type
.
ndim
>
0
)
else
idx
for
idx
in
aligned_idxs
]
...
...
tests/graph/rewriting/test_basic.py
浏览文件 @
cc6bed1a
...
...
@@ -26,7 +26,9 @@ from pytensor.graph.rewriting.unify import LiteralString, OpPattern
from
pytensor.raise_op
import
assert_op
from
pytensor.tensor.math
import
Dot
,
add
,
dot
,
exp
from
pytensor.tensor.rewriting.basic
import
constant_folding
from
pytensor.tensor.subtensor
import
AdvancedSubtensor
from
pytensor.tensor.type
import
matrix
,
values_eq_approx_always_true
,
vector
from
pytensor.tensor.type_other
import
MakeSlice
,
SliceConstant
,
slicetype
from
tests.graph.utils
import
(
MyOp
,
MyType
,
...
...
@@ -627,6 +629,21 @@ def test_pre_constant_merge():
assert
res
==
[
o2
]
assert
o2
.
owner
.
inputs
[
2
]
is
c2
# What is this supposed to test?
ms
=
MakeSlice
()(
1
)
res
=
pre_constant_merge
(
empty_fgraph
,
[
ms
])
assert
res
==
[
ms
]
const_slice
=
SliceConstant
(
type
=
slicetype
,
data
=
slice
(
1
,
None
,
2
))
assert
isinstance
(
const_slice
,
Constant
)
adv
=
AdvancedSubtensor
()(
matrix
(),
[
2
,
3
],
const_slice
)
res
=
pre_constant_merge
(
empty_fgraph
,
adv
)
assert
res
==
[
adv
]
def
test_pre_greedy_node_rewriter
():
empty_fgraph
=
FunctionGraph
([],
[])
...
...
@@ -662,6 +679,15 @@ def test_pre_greedy_node_rewriter():
assert
cst
.
owner
.
inputs
[
0
]
is
o1
assert
cst
.
owner
.
inputs
[
4
]
is
cst
.
owner
.
inputs
[
0
]
# What exactly is this supposed to test?
ms
=
MakeSlice
()(
1
)
cst
=
pre_greedy_node_rewriter
(
empty_fgraph
,
[
constant_folding
],
ms
)
assert
isinstance
(
cst
,
SliceConstant
)
# Make sure constant of slice signature is hashable.
assert
isinstance
(
hash
(
cst
.
signature
()),
int
)
@pytest.mark.parametrize
(
"tracks"
,
[
True
,
False
])
@pytest.mark.parametrize
(
"out_pattern"
,
[(
op2
,
"x"
),
"x"
,
1.0
])
...
...
tests/link/jax/test_subtensor.py
浏览文件 @
cc6bed1a
...
...
@@ -225,37 +225,6 @@ def test_jax_IncSubtensor():
compare_jax_and_py
([],
[
out_pt
],
[])
@pytest.mark.parametrize
(
"func"
,
(
pt_subtensor
.
advanced_inc_subtensor1
,
pt_subtensor
.
advanced_set_subtensor1
)
)
def
test_jax_AdvancedIncSubtensor1_runtime_broadcast
(
func
):
"""Test that JAX backend checks for runtime broadcasting in AdvancedIncSubtensor1.
JAX silently broadcasts when using .at[].set() or .at[].add(), but PyTensor
requires explicit broadcastable dimensions. This test ensures we raise the same
error as the Python/C backend when runtime broadcasting would occur.
"""
from
pytensor
import
function
y
=
pt
.
matrix
(
"y"
,
dtype
=
"float64"
,
shape
=
(
None
,
None
))
x
=
pt
.
zeros
((
10
,
5
))
idxs
=
np
.
repeat
(
np
.
arange
(
10
),
2
)
# 20 indices
out
=
func
(
x
,
y
,
idxs
)
f
=
function
([
y
],
out
,
mode
=
"JAX"
)
# Should work with correctly sized y
f
(
np
.
ones
((
20
,
5
)))
# Should raise for runtime broadcasting on first dimension
with
pytest
.
raises
(
ValueError
,
match
=
"Runtime broadcasting not allowed"
):
f
(
np
.
ones
((
1
,
5
)))
# Should raise for runtime broadcasting on second dimension
with
pytest
.
raises
(
ValueError
,
match
=
"Runtime broadcasting not allowed"
):
f
(
np
.
ones
((
20
,
1
)))
def
test_jax_IncSubtensor_boolean_indexing_reexpressible
():
"""Setting or incrementing values with boolean indexing.
...
...
tests/link/mlx/test_subtensor.py
浏览文件 @
cc6bed1a
...
...
@@ -187,6 +187,27 @@ def test_mlx_inplace_variants():
compare_mlx_and_py
([],
[
out_pt
],
[])
@pytest.mark.xfail
(
reason
=
"MLX slice indices must be integers or None, dynamic slices not supported"
)
def
test_mlx_MakeSlice
():
"""Test MakeSlice operation."""
# Test slice creation
start
=
pt
.
iscalar
(
"start"
)
stop
=
pt
.
iscalar
(
"stop"
)
step
=
pt
.
iscalar
(
"step"
)
# Create a slice using MakeSlice
slice_op
=
pt_subtensor
.
MakeSlice
()
slice_pt
=
slice_op
(
start
,
stop
,
step
)
# Use simple constant array instead of arange
x_pt
=
pt
.
constant
(
np
.
arange
(
10
,
dtype
=
np
.
float32
))
out_pt
=
x_pt
[
slice_pt
]
compare_mlx_and_py
([
start
,
stop
,
step
],
[
out_pt
],
[
1
,
8
,
2
])
def
test_mlx_subtensor_edge_cases
():
"""Test edge cases and boundary conditions."""
# Empty slices - use constant array
...
...
tests/link/numba/test_subtensor.py
浏览文件 @
cc6bed1a
...
...
@@ -3,7 +3,9 @@ import contextlib
import
numpy
as
np
import
pytest
import
pytensor.scalar
as
ps
import
pytensor.tensor
as
pt
from
pytensor
import
Mode
,
as_symbolic
from
pytensor.tensor
import
as_tensor
from
pytensor.tensor.subtensor
import
(
AdvancedIncSubtensor
,
...
...
@@ -18,16 +20,51 @@ from pytensor.tensor.subtensor import (
inc_subtensor
,
set_subtensor
,
)
from
tests.link.numba.test_basic
import
(
compare_numba_and_py
,
numba_inplace_mode
,
numba_mode
,
)
from
tests.link.numba.test_basic
import
compare_numba_and_py
,
numba_mode
rng
=
np
.
random
.
default_rng
(
sum
(
map
(
ord
,
"Numba subtensors"
)))
@pytest.mark.parametrize
(
"step"
,
[
None
,
1
,
2
,
-
2
,
"x"
],
ids
=
lambda
x
:
f
"step={x}"
)
@pytest.mark.parametrize
(
"stop"
,
[
None
,
10
,
"x"
],
ids
=
lambda
x
:
f
"stop={x}"
)
@pytest.mark.parametrize
(
"start"
,
[
None
,
0
,
3
,
"x"
],
ids
=
lambda
x
:
f
"start={x}"
)
def
test_slice
(
start
,
stop
,
step
):
x
=
ps
.
int64
(
"x"
)
sym_slice
=
as_symbolic
(
slice
(
x
if
start
==
"x"
else
start
,
x
if
stop
==
"x"
else
stop
,
x
if
step
==
"x"
else
step
,
)
)
no_opt_mode
=
Mode
(
linker
=
"numba"
,
optimizer
=
None
)
evaled_slice
=
sym_slice
.
eval
({
x
:
-
5
},
on_unused_input
=
"ignore"
,
mode
=
no_opt_mode
)
assert
isinstance
(
evaled_slice
,
slice
)
if
start
==
"x"
:
assert
evaled_slice
.
start
==
-
5
elif
start
is
None
and
(
evaled_slice
.
step
is
None
or
evaled_slice
.
step
>
0
):
# Numba can convert to 0 (and sometimes does) in this case
assert
evaled_slice
.
start
in
(
None
,
0
)
else
:
assert
evaled_slice
.
start
==
start
if
stop
==
"x"
:
assert
evaled_slice
.
stop
==
-
5
else
:
assert
evaled_slice
.
stop
==
stop
if
step
==
"x"
:
assert
evaled_slice
.
step
==
-
5
elif
step
is
None
:
# Numba can convert to 1 (and sometimes does) in this case
assert
evaled_slice
.
step
in
(
None
,
1
)
else
:
assert
evaled_slice
.
step
==
step
@pytest.mark.parametrize
(
"x, indices"
,
[
...
...
@@ -145,11 +182,6 @@ def test_AdvancedSubtensor1_out_of_bounds():
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))),
([[
1
,
2
],
[
2
,
1
]],
slice
(
1
,
None
),
[[
0
,
0
],
[
0
,
0
]]),
),
# Newaxis with vector indexing
(
as_tensor
(
np
.
arange
(
4
*
4
)
.
reshape
((
4
,
4
))),
(
None
,
[
0
,
1
,
2
],
[
0
,
1
,
2
]),
),
],
)
@pytest.mark.filterwarnings
(
"error"
)
# Raise if we did not expect objmode to be needed
...
...
@@ -415,13 +447,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
False
,
False
,
),
(
np
.
arange
(
4
*
4
)
.
reshape
((
4
,
4
)),
np
.
array
(
5
),
# Broadcasted scalar value
(
None
,
[
0
,
1
,
2
],
[
0
,
1
,
2
]),
# Newaxis with vector indexing
False
,
False
,
),
],
)
@pytest.mark.parametrize
(
"inplace"
,
(
False
,
True
))
...
...
@@ -435,9 +460,7 @@ def test_AdvancedIncSubtensor(
inplace
,
):
# Need rewrite to support certain forms of advanced indexing without object mode
# Use inplace_mode when testing inplace operations to preserve inplace flag
base_mode
=
numba_inplace_mode
if
inplace
else
numba_mode
mode
=
base_mode
.
including
(
"specialize"
)
mode
=
numba_mode
.
including
(
"specialize"
)
x_pt
=
pt
.
as_tensor
(
x
)
.
type
(
"x"
)
y_pt
=
pt
.
as_tensor
(
y
)
.
type
(
"y"
)
...
...
@@ -491,3 +514,22 @@ def test_AdvancedIncSubtensor(
x_orig
=
x
.
copy
()
fn
(
x
,
y
)
assert
not
np
.
all
(
x
==
x_orig
)
def
test_advanced_indexing_with_newaxis_fallback_obj_mode
():
# This should be automatically solved with https://github.com/pymc-devs/pytensor/issues/1564
# After which we can add these parametrizations to the relevant tests above
x
=
pt
.
matrix
(
"x"
)
out
=
x
[
None
,
[
0
,
1
,
2
],
[
0
,
1
,
2
]]
with
pytest
.
warns
(
UserWarning
,
match
=
r"Numba will use object mode to run AdvancedSubtensor's perform method"
,
):
compare_numba_and_py
([
x
],
[
out
],
[
np
.
random
.
normal
(
size
=
(
4
,
4
))])
out
=
x
[
None
,
[
0
,
1
,
2
],
[
0
,
1
,
2
]]
.
inc
(
5
)
with
pytest
.
warns
(
UserWarning
,
match
=
r"Numba will use object mode to run AdvancedIncSubtensor's perform method"
,
):
compare_numba_and_py
([
x
],
[
out
],
[
np
.
random
.
normal
(
size
=
(
4
,
4
))])
tests/tensor/rewriting/test_elemwise.py
浏览文件 @
cc6bed1a
...
...
@@ -1642,15 +1642,9 @@ def test_InplaceElemwiseOptimizer_bug():
# with config.change_flags(tensor__insert_inplace_optimizer_validate_nb=10):
rewrite_graph
(
fgraph
,
include
=
(
"inplace"
,))
# Save original value to restore later
original_value
=
pytensor
.
config
.
tensor__insert_inplace_optimizer_validate_nb
try
:
pytensor
.
config
.
tensor__insert_inplace_optimizer_validate_nb
=
1
with
pytest
.
warns
(
FutureWarning
,
match
=
"tensor__insert_inplace_optimizer_validate_nb config is deprecated"
,
):
rewrite_graph
(
fgraph
,
include
=
(
"inplace"
,))
finally
:
# Restore original value to avoid affecting other tests
pytensor
.
config
.
tensor__insert_inplace_optimizer_validate_nb
=
original_value
pytensor
.
config
.
tensor__insert_inplace_optimizer_validate_nb
=
1
with
pytest
.
warns
(
FutureWarning
,
match
=
"tensor__insert_inplace_optimizer_validate_nb config is deprecated"
,
):
rewrite_graph
(
fgraph
,
include
=
(
"inplace"
,))
tests/tensor/rewriting/test_subtensor.py
浏览文件 @
cc6bed1a
...
...
@@ -52,6 +52,7 @@ from pytensor.tensor.type import (
tensor4
,
vector
,
)
from
pytensor.tensor.type_other
import
make_slice
from
tests
import
unittest_tools
as
utt
from
tests.unittest_tools
import
create_pytensor_param
...
...
@@ -1700,11 +1701,11 @@ def test_local_uint_constant_indices():
assert
isinstance
(
new_index
,
Constant
)
assert
new_index
.
type
.
dtype
==
"uint8"
# `AdvancedSubtensor`, two indices, one slice, convert
# `AdvancedSubtensor`, two indices, one s
ymbolic s
lice, convert
x
=
pt
.
matrix
(
"x"
)
indices
=
(
pt
.
as_tensor_variable
(
np
.
array
(
[
1
]
,
np
.
int64
)),
slice
(
None
,
10
),
pt
.
as_tensor_variable
(
np
.
array
(
1
,
np
.
int64
)),
make_slice
(
slice
(
None
,
10
)
),
)
z
=
x
[
indices
]
...
...
@@ -1791,7 +1792,7 @@ def test_local_uint_constant_indices():
z_fn
=
pytensor
.
function
([
x
],
z
,
mode
=
mode
)
subtensor_node
=
z_fn
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
assert
isinstance
(
subtensor_node
.
op
,
(
AdvancedSubtensor
,
AdvancedSubtensor1
)
)
assert
isinstance
(
subtensor_node
.
op
,
AdvancedSubtensor
)
new_index
=
subtensor_node
.
inputs
[
1
]
assert
isinstance
(
new_index
,
Constant
)
assert
new_index
.
type
.
dtype
==
"uint8"
...
...
@@ -1842,6 +1843,7 @@ class TestBlockwiseIncSubtensor:
out
=
vectorize_graph
(
core_graph
,
replace
=
{
core_x
:
x
,
core_y
:
y
})
fn
,
ref_fn
=
self
.
compile_fn_and_ref
([
x
,
y
],
out
)
assert
self
.
has_blockwise
(
ref_fn
)
assert
not
self
.
has_blockwise
(
fn
)
test_x
=
np
.
ones
(
x
.
type
.
shape
,
dtype
=
x
.
type
.
dtype
)
test_y
=
rng
.
integers
(
1
,
10
,
size
=
y
.
type
.
shape
,
dtype
=
y
.
type
.
dtype
)
np
.
testing
.
assert_allclose
(
fn
(
test_x
,
test_y
),
ref_fn
(
test_x
,
test_y
))
...
...
@@ -1946,7 +1948,15 @@ class TestBlockwiseIncSubtensor:
@pytest.mark.parametrize
(
"basic_idx"
,
[
True
,
False
],
[
True
,
pytest
.
param
(
False
,
marks
=
pytest
.
mark
.
xfail
(
reason
=
"AdvancedIncSubtensor with slices can't be blockwise"
),
),
],
ids
=
[
"basic_idx"
,
"adv_idx"
],
)
@pytest.mark.parametrize
(
...
...
@@ -1963,7 +1973,7 @@ class TestBlockwiseIncSubtensor:
core_idx
=
pt
.
tensor
(
"idx"
,
dtype
=
int
,
shape
=
()
if
basic_idx
else
(
2
,))
# The empty slice before core_idx, will lead to a transposition of the advanced view
# once it is paired with a new arange slice on the batched dimensions.
# once it is paired with a
n
new arange slice on the batched dimensions.
# That's why core_v is (2, 3), and not (3, 2), in the case of advanced indexing
core_out
=
core_a
[
0
,
:,
core_idx
]
.
set
(
core_v
)
...
...
tests/tensor/rewriting/test_subtensor_lift.py
浏览文件 @
cc6bed1a
...
...
@@ -32,6 +32,7 @@ from pytensor.tensor import (
lscalars
,
matrix
,
shape
,
slicetype
,
specify_shape
,
tensor
,
tensor3
,
...
...
@@ -556,7 +557,7 @@ class TestLocalSubtensorSpecifyShapeLift:
(
matrix
(),
(
iscalar
(),
iscalar
()),
(
slice
(
iscalar
(),
iscalar
(),
iscalar
()
),),
(
slice
type
(
),),
),
(
matrix
(),
...
...
@@ -788,12 +789,12 @@ def test_local_subtensor_shape_constant():
(
lambda
x
:
x
[:,
[
0
,
1
]][
0
],
True
),
(
lambda
x
:
x
[:,
[
0
,
1
],
[
0
,
0
]][
1
:],
True
),
(
lambda
x
:
x
[:,
[[
0
,
1
],
[
0
,
0
]]][
1
:],
True
),
(
lambda
x
:
x
[:,
None
,
[
0
,
1
]][
0
],
True
),
# Not supported, basic indexing on advanced indexing dim
(
lambda
x
:
x
[[
0
,
1
]][
0
],
False
),
# Not
suppor
ted, basic indexing on the right of advanced indexing
# Not
implemen
ted, basic indexing on the right of advanced indexing
(
lambda
x
:
x
[[
0
,
1
]][:,
0
],
False
),
# Not implemented, complex flavors of advanced indexing
(
lambda
x
:
x
[:,
None
,
[
0
,
1
]][
0
],
False
),
(
lambda
x
:
x
[:,
5
:,
[
0
,
1
]][
0
],
False
),
(
lambda
x
:
x
[:,
:,
np
.
array
([
True
,
False
,
False
])][
0
],
False
),
(
lambda
x
:
x
[[
0
,
1
],
:,
[
0
,
1
]][:,
0
],
False
),
...
...
tests/tensor/test_blockwise.py
浏览文件 @
cc6bed1a
...
...
@@ -31,8 +31,6 @@ from pytensor.tensor.blockwise import (
vectorize_node_fallback
,
)
from
pytensor.tensor.nlinalg
import
MatrixInverse
,
eig
from
pytensor.tensor.random
import
normal
from
pytensor.tensor.random.op
import
default_rng
from
pytensor.tensor.rewriting.blas
import
specialize_matmul_to_batched_dot
from
pytensor.tensor.signal
import
convolve1d
from
pytensor.tensor.slinalg
import
(
...
...
@@ -116,18 +114,16 @@ def test_vectorize_blockwise():
def
test_vectorize_node_fallback_unsupported_type
():
rng
=
default_rng
(
)
node
=
normal
(
rng
=
rng
)
.
owner
x
=
tensor
(
"x"
,
shape
=
(
2
,
6
)
)
node
=
x
[:,
[
0
,
2
,
4
]]
.
owner
with
pytest
.
raises
(
NotImplementedError
,
match
=
re
.
escape
(
'Cannot vectorize node normal_rv{"(),()->()"}('
"DefaultGeneratorMakerOp.0, NoneConst{None}, 0.0, 1.0)"
" with input DefaultGeneratorMakerOp.0 of type RandomGeneratorType"
"Cannot vectorize node AdvancedSubtensor(x, MakeSlice.0, [0 2 4]) with input MakeSlice.0 of type slice"
),
):
vectorize_node_fallback
(
node
.
op
,
node
,
*
node
.
inputs
)
vectorize_node_fallback
(
node
.
op
,
node
,
node
.
inputs
)
def
check_blockwise_runtime_broadcasting
(
mode
):
...
...
tests/tensor/test_subtensor.py
浏览文件 @
cc6bed1a
...
...
@@ -11,19 +11,20 @@ from numpy.testing import assert_array_equal
import
pytensor
import
pytensor.scalar
as
scal
import
pytensor.tensor.basic
as
ptb
from
pytensor
import
function
,
shared
from
pytensor.compile
import
DeepCopyOp
from
pytensor
import
function
from
pytensor.compile
import
DeepCopyOp
,
shared
from
pytensor.compile.io
import
In
from
pytensor.compile.mode
import
Mode
,
get_default_mode
from
pytensor.configdefaults
import
config
from
pytensor.gradient
import
grad
from
pytensor.graph
import
Constant
from
pytensor.graph.basic
import
equal_computations
from
pytensor.graph.op
import
get_test_value
from
pytensor.graph.rewriting.utils
import
is_same_graph
from
pytensor.link.numba
import
NumbaLinker
from
pytensor.printing
import
pprint
from
pytensor.scalar.basic
import
as_scalar
,
int16
from
pytensor.tensor
import
as_tensor
,
constant
,
get_vector_length
,
ivector
,
vectorize
from
pytensor.tensor
import
as_tensor
,
constant
,
get_vector_length
,
vectorize
from
pytensor.tensor.blockwise
import
Blockwise
,
BlockwiseWithCoreShape
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.math
import
exp
,
isinf
,
lt
,
switch
...
...
@@ -32,6 +33,7 @@ from pytensor.tensor.shape import specify_broadcastable, specify_shape
from
pytensor.tensor.subtensor
import
(
AdvancedIncSubtensor
,
AdvancedIncSubtensor1
,
AdvancedIndexingError
,
AdvancedSubtensor
,
AdvancedSubtensor1
,
IncSubtensor
,
...
...
@@ -47,6 +49,7 @@ from pytensor.tensor.subtensor import (
flip
,
get_canonical_form_slice
,
inc_subtensor
,
index_vars_to_types
,
indexed_result_shape
,
set_subtensor
,
slice_at_axis
,
...
...
@@ -77,7 +80,13 @@ from pytensor.tensor.type import (
tensor5
,
vector
,
)
from
pytensor.tensor.type_other
import
NoneConst
from
pytensor.tensor.type_other
import
(
NoneConst
,
SliceConstant
,
as_symbolic_slice
,
make_slice
,
slicetype
,
)
from
tests
import
unittest_tools
as
utt
from
tests.tensor.utils
import
inplace_func
,
integers_ranged
,
random
...
...
@@ -97,12 +106,20 @@ def test_as_index_literal():
assert
res
==
slice
(
1
,
None
)
res
=
as_index_literal
(
slice
(
None
,
None
,
ptb
.
as_tensor
(
2
)))
assert
res
==
slice
(
None
,
None
,
2
)
res
=
as_index_literal
(
SliceConstant
(
slicetype
,
slice
(
None
)))
assert
res
==
slice
(
None
)
res
=
as_index_literal
(
make_slice
(
None
,
ptb
.
as_tensor
(
1
)))
assert
res
==
slice
(
None
,
1
)
res
=
as_index_literal
(
ptb
.
as_tensor
(
2
))
assert
res
==
2
res
=
as_index_literal
(
np
.
newaxis
)
assert
res
is
np
.
newaxis
res
=
as_index_literal
(
NoneConst
)
assert
res
is
np
.
newaxis
res
=
as_index_literal
(
NoneConst
.
clone
())
assert
res
is
np
.
newaxis
class
TestGetCanonicalFormSlice
:
...
...
@@ -111,6 +128,8 @@ class TestGetCanonicalFormSlice:
[
NoneConst
,
None
,
as_symbolic_slice
(
slice
(
3
,
7
,
2
)),
as_symbolic_slice
(
slice
(
3
,
int16
(),
2
)),
vector
(),
],
)
...
...
@@ -118,19 +137,6 @@ class TestGetCanonicalFormSlice:
with
pytest
.
raises
(
ValueError
,
match
=
"not a supported slice"
):
get_canonical_form_slice
(
idx
,
5
)
@pytest.mark.parametrize
(
"idx,expected_direction"
,
[
(
slice
(
3
,
7
,
2
),
1
),
(
slice
(
None
,
None
),
1
),
(
slice
(
None
,
None
,
-
1
),
-
1
),
],
)
def
test_python_slice_support
(
self
,
idx
,
expected_direction
):
result
,
direction
=
get_canonical_form_slice
(
idx
,
10
)
assert
isinstance
(
result
,
slice
)
assert
direction
==
expected_direction
def
test_scalar_constant
(
self
):
a
=
as_scalar
(
0
)
length
=
lscalar
()
...
...
@@ -402,7 +408,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
f
=
inplace_func
([],
t
,
mode
=
mode
)
topo
=
f
.
maker
.
fgraph
.
toposort
()
topo_
=
[
node
for
node
in
topo
if
not
isinstance
(
node
.
op
,
DeepCopyOp
)]
assert
len
(
topo_
)
==
length
,
f
.
dprint
()
assert
len
(
topo_
)
==
length
if
length
==
1
:
assert
isinstance
(
topo_
[
0
]
.
op
,
op_type
)
tval
=
f
()
...
...
@@ -617,7 +623,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
(
3
,
DimShuffle
,
np
.
index_exp
[
...
,
[
0
,
2
,
3
]]),
(
1
,
DimShuffle
,
np
.
index_exp
[
np
.
newaxis
,
...
]),
(
4
if
config
.
mode
==
"FAST_COMPILE"
else
3
,
1
,
AdvancedSubtensor
,
np
.
index_exp
[
...
,
np
.
newaxis
,
[
1
,
2
]],
),
...
...
@@ -1961,7 +1967,7 @@ class TestAdvancedSubtensor:
x
=
self
.
shared
(
x_val
,
name
=
"x"
)
y
=
tensor
(
dtype
=
"float32"
,
shape
=
(
None
,)
*
len
(
y_val
.
shape
),
name
=
"y"
)
sym_idx
=
[
ptb
.
as_tensor_variable
(
ix
)
for
ix
in
idx
]
expr
=
advanced_inc_subtensor
(
x
,
y
,
*
sym_idx
,
inplace
=
inplace
)
expr
=
AdvancedIncSubtensor
(
inplace
=
inplace
)(
x
,
y
,
*
sym_idx
)
f
=
pytensor
.
function
(
[
y
],
expr
,
mode
=
self
.
mode
.
excluding
(
"inplace"
),
accept_inplace
=
inplace
)
...
...
@@ -2297,29 +2303,20 @@ class TestAdvancedSubtensor:
def
test_adv_sub_slice
(
self
):
# Reported in https://github.com/Theano/Theano/issues/5898
var
=
self
.
shared
(
np
.
zeros
([
3
,
3
],
dtype
=
config
.
floatX
))
slc
=
slicetype
()
f
=
pytensor
.
function
([
slc
],
var
[
slc
],
mode
=
self
.
mode
)
s
=
slice
(
1
,
3
)
assert
f
(
s
)
.
shape
==
(
2
,
3
)
# Test with scalar variables for slice boundaries
start
=
lscalar
(
"start"
)
stop
=
lscalar
(
"stop"
)
# Create sliced output
f
=
pytensor
.
function
([
start
,
stop
],
var
[
start
:
stop
],
mode
=
self
.
mode
)
result
=
f
(
1
,
3
)
assert
result
.
shape
==
(
2
,
3
)
f_shape0
=
pytensor
.
function
([
slc
],
var
[
slc
]
.
shape
[
0
],
mode
=
self
.
mode
)
assert
f_shape0
(
s
)
==
2
f_shape0
=
pytensor
.
function
(
[
start
,
stop
],
var
[
start
:
stop
]
.
shape
[
0
],
mode
=
self
.
mode
)
assert
f_shape0
(
1
,
3
)
==
2
f_shape1
=
pytensor
.
function
(
[
start
,
stop
],
var
[
start
:
stop
]
.
shape
[
1
],
mode
=
self
.
mode
)
f_shape1
=
pytensor
.
function
([
slc
],
var
[
slc
]
.
shape
[
1
],
mode
=
self
.
mode
)
assert
not
any
(
isinstance
(
node
.
op
,
AdvancedSubtensor
)
for
node
in
f_shape1
.
maker
.
fgraph
.
toposort
()
)
assert
f_shape1
(
1
,
3
)
==
3
assert
f_shape1
(
s
)
==
3
def
test_adv_grouped
(
self
):
# Reported in https://github.com/Theano/Theano/issues/6152
...
...
@@ -2801,8 +2798,8 @@ class TestInferShape(utt.InferShapeTester):
def
test_advanced_subtensor_constant_slice
(
self
):
x
=
dmatrix
(
"x"
)
# Use Python slice directly instead of as_symbolic(slice(
))
constant_slice
=
slice
(
1
,
None
,
None
)
constant_slice
=
pytensor
.
as_symbolic
(
slice
(
1
,
None
,
None
))
assert
isinstance
(
constant_slice
,
Constant
)
adv_indices
=
ptb
.
constant
(
np
.
zeros
((
2
,
3
)),
dtype
=
"int"
)
y
=
advanced_subtensor
(
x
,
constant_slice
,
adv_indices
)
assert
tuple
(
y
.
shape
.
eval
({
x
:
np
.
zeros
((
10
,
10
))}))
==
(
9
,
2
,
3
)
...
...
@@ -2811,7 +2808,7 @@ class TestInferShape(utt.InferShapeTester):
@config.change_flags
(
compute_test_value
=
"raise"
)
def
test_basic_shape
():
test_shape
=
(
5
,
4
)
test_indices
=
(
slice
(
1
,
3
,
None
),)
# Python slice instead of make_slice(
)
test_indices
=
(
make_slice
(
1
,
3
,
None
),
)
res
=
basic_shape
(
test_shape
,
test_indices
)
assert
get_test_value
(
res
)
==
(
2
,)
...
...
@@ -2849,6 +2846,18 @@ test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True
np
.
arange
(
np
.
prod
((
5
,
6
,
7
,
8
)))
.
reshape
((
5
,
6
,
7
,
8
)),
(
slice
(
None
,
None
),
*
test_idx
[:
1
]),
),
(
np
.
arange
(
np
.
prod
((
5
,
6
,
7
,
8
)))
.
reshape
((
5
,
6
,
7
,
8
)),
(
slice
(
None
,
None
),
None
,
*
test_idx
[
1
:
2
]),
),
(
np
.
arange
(
np
.
prod
((
5
,
6
,
7
,
8
)))
.
reshape
((
5
,
6
,
7
,
8
)),
(
np
.
array
(
1
),
slice
(
None
,
None
),
None
),
),
(
np
.
arange
(
np
.
prod
((
5
,
6
,
7
,
8
)))
.
reshape
((
5
,
6
,
7
,
8
)),
(
slice
(
None
,
None
),
None
,
np
.
array
(
1
)),
),
(
np
.
arange
(
np
.
prod
((
5
,
6
,
7
,
8
)))
.
reshape
((
5
,
6
,
7
,
8
)),
(
*
test_idx
[:
1
],
slice
(
None
,
None
),
*
test_idx
[
1
:
2
]),
...
...
@@ -2857,6 +2866,10 @@ test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True
np
.
arange
(
np
.
prod
((
5
,
6
,
7
,
8
)))
.
reshape
((
5
,
6
,
7
,
8
)),
(
*
test_idx
[:
1
],
slice
(
None
,
None
),
*
test_idx
[
1
:
2
],
slice
(
None
,
None
)),
),
(
np
.
arange
(
np
.
prod
((
5
,
6
,
7
,
8
)))
.
reshape
((
5
,
6
,
7
,
8
)),
(
*
test_idx
[:
1
],
None
,
*
test_idx
[
1
:
2
]),
),
(
np
.
arange
(
np
.
prod
((
5
,
4
)))
.
reshape
((
5
,
4
)),
([
1
,
3
,
2
],
slice
(
1
,
3
))),
(
np
.
arange
(
np
.
prod
((
5
,
4
)))
.
reshape
((
5
,
4
)),
(
slice
(
1
,
3
),
[
1
,
3
,
2
])),
(
...
...
@@ -2916,11 +2929,12 @@ def test_get_vector_length():
"indices, exp_res"
,
[
((
0
,),
"x[0]"
),
((
slice
(
None
,
2
),),
"x[:2]"
),
((
slice
(
0
,
None
),),
"x[0:]"
),
((
slice
(
0
,
2
),),
"x[0:2]"
),
((
slice
(
0
,
2
,
2
),),
"x[0:2:2]"
),
((
slice
(
0
,
2
),
0
,
slice
(
0
,
2
)),
"x[0:2, 0, 0:2]"
),
# TODO: The numbers should be printed
((
slice
(
None
,
2
),),
"x[:int64]"
),
((
slice
(
0
,
None
),),
"x[int64:]"
),
((
slice
(
0
,
2
),),
"x[int64:int64]"
),
((
slice
(
0
,
2
,
2
),),
"x[int64:int64:int64]"
),
((
slice
(
0
,
2
),
0
,
slice
(
0
,
2
)),
"x[int64:int64, 2, int64:int64]"
),
],
)
def
test_pprint_Subtensor
(
indices
,
exp_res
):
...
...
@@ -2934,7 +2948,7 @@ def test_pprint_Subtensor(indices, exp_res):
[
((
0
,),
False
,
"inc_subtensor(x[0], z)"
),
((
0
,),
True
,
"set_subtensor(x[0], z)"
),
((
slice
(
0
,
2
),),
True
,
"set_subtensor(x[
0:2
], z)"
),
((
slice
(
0
,
2
),),
True
,
"set_subtensor(x[
int64:int64
], z)"
),
],
)
def
test_pprint_IncSubtensor
(
indices
,
set_instead_of_inc
,
exp_res
):
...
...
@@ -2944,38 +2958,22 @@ def test_pprint_IncSubtensor(indices, set_instead_of_inc, exp_res):
assert
pprint
(
y
)
==
exp_res
@pytest.mark.parametrize
(
"indices, exp_res"
,
[
# Vector index
((
ivector
(
"idx"
),),
"x[idx]"
),
# Two vector indices
((
ivector
(
"idx"
),
ivector
(
"idx2"
)),
"x[idx, idx2]"
),
# Vector index with scalar (triggers advanced indexing)
((
ivector
(
"idx"
),
0
),
"x[idx, 0]"
),
# Vector index with constant slice
((
ivector
(
"idx"
),
slice
(
0
,
5
)),
"x[idx, 0:5]"
),
],
)
def
test_pprint_AdvancedSubtensor
(
indices
,
exp_res
):
x
=
tensor4
(
"x"
)
y
=
advanced_subtensor
(
x
,
*
indices
)
assert
pprint
(
y
)
==
exp_res
def
test_index_vars_to_types
():
x
=
ptb
.
as_tensor_variable
(
np
.
array
([
True
,
False
]))
with
pytest
.
raises
(
AdvancedIndexingError
):
index_vars_to_types
(
x
)
@pytest.mark.parametrize
(
"indices, set_instead_of_inc, exp_res"
,
[
((
ivector
(
"idx"
),),
False
,
"inc_subtensor(x[idx], z)"
),
((
ivector
(
"idx"
),),
True
,
"set_subtensor(x[idx], z)"
),
((
ivector
(
"idx"
),
slice
(
None
,
5
)),
True
,
"set_subtensor(x[idx, :5], z)"
),
],
)
def
test_pprint_AdvancedIncSubtensor
(
indices
,
set_instead_of_inc
,
exp_res
):
x
=
tensor4
(
"x"
)
z
=
tensor3
(
"z"
)
y
=
advanced_inc_subtensor
(
x
,
z
,
*
indices
,
set_instead_of_inc
=
set_instead_of_inc
)
assert
pprint
(
y
)
==
exp_res
with
pytest
.
raises
(
TypeError
):
index_vars_to_types
(
1
)
res
=
index_vars_to_types
(
iscalar
)
assert
isinstance
(
res
,
scal
.
ScalarType
)
x
=
scal
.
constant
(
1
,
dtype
=
np
.
uint8
)
assert
isinstance
(
x
.
type
,
scal
.
ScalarType
)
res
=
index_vars_to_types
(
x
)
assert
res
==
x
.
type
@pytest.mark.parametrize
(
...
...
@@ -3068,12 +3066,15 @@ def test_vectorize_subtensor_without_batch_indices():
(
2
,),
False
,
),
# (this is currently failing because PyTensor tries to vectorize the slice(None) operation,
# due to the exact same None constant being used there and in the np.newaxis)
pytest
.
param
(
(
lambda
x
,
idx
:
x
[:,
idx
,
None
]),
"(7,5,3),(2)->(7,2,1,3)"
,
(
11
,
7
,
5
,
3
),
(
2
,),
False
,
marks
=
pytest
.
mark
.
xfail
(
raises
=
NotImplementedError
),
),
(
(
lambda
x
,
idx
:
x
[:,
idx
,
idx
,
:]),
...
...
@@ -3082,23 +3083,27 @@ def test_vectorize_subtensor_without_batch_indices():
(
2
,),
False
,
),
# (not supported, because fallback Blocwise can't handle slices)
pytest
.
param
(
(
lambda
x
,
idx
:
x
[:,
idx
,
:,
idx
]),
"(7,5,3,5),(2)->(2,7,3)"
,
(
11
,
7
,
5
,
3
,
5
),
(
2
,),
True
,
marks
=
pytest
.
mark
.
xfail
(
raises
=
NotImplementedError
),
),
# Core x, batched idx
((
lambda
x
,
idx
:
x
[
idx
]),
"(t1),(idx)->(tx)"
,
(
7
,),
(
11
,
2
),
True
),
# Batched x, batched idx
((
lambda
x
,
idx
:
x
[
idx
]),
"(t1),(idx)->(tx)"
,
(
11
,
7
),
(
11
,
2
),
True
),
# (not supported, because fallback Blocwise can't handle slices)
pytest
.
param
(
(
lambda
x
,
idx
:
x
[:,
idx
,
:]),
"(t1,t2,t3),(idx)->(t1,tx,t3)"
,
(
11
,
7
,
5
,
3
),
(
11
,
2
),
True
,
marks
=
pytest
.
mark
.
xfail
(
raises
=
NotImplementedError
),
),
],
)
...
...
@@ -3233,37 +3238,3 @@ class TestBenchmarks:
)
fn
.
vm
.
allow_gc
=
gc
benchmark
(
fn
,
x_values
)
def
test_subtensor_hash_and_eq
():
s1
=
Subtensor
(
idx_list
=
[
slice
(
None
,
None
,
None
),
0
])
s2
=
Subtensor
(
idx_list
=
[
slice
(
None
,
None
,
None
),
0
])
assert
s1
==
s2
assert
hash
(
s1
)
==
hash
(
s2
)
s3
=
AdvancedSubtensor
(
idx_list
=
[
slice
(
None
,
None
,
None
),
0
])
s4
=
AdvancedIncSubtensor
(
idx_list
=
[
slice
(
0
,
1
,
None
),
2
])
assert
s3
!=
s4
assert
hash
(
s3
)
!=
hash
(
s4
)
assert
s1
!=
s3
inc1
=
IncSubtensor
(
idx_list
=
[
slice
(
None
)],
inplace
=
True
,
destroyhandler_tolerate_aliased
=
[(
0
,
1
)]
)
inc2
=
IncSubtensor
(
idx_list
=
[
slice
(
None
)],
inplace
=
True
,
destroyhandler_tolerate_aliased
=
[(
0
,
1
)]
)
inc3
=
IncSubtensor
(
idx_list
=
[
slice
(
None
)],
inplace
=
True
,
destroyhandler_tolerate_aliased
=
[(
0
,
2
)]
)
assert
inc1
==
inc2
assert
hash
(
inc1
)
==
hash
(
inc2
)
assert
inc1
!=
inc3
if
hash
(
inc1
)
==
hash
(
inc3
):
assert
inc1
==
inc3
s_mix1
=
Subtensor
(
idx_list
=
[
0
,
slice
(
None
),
slice
(
None
,
1
)])
s_mix2
=
Subtensor
(
idx_list
=
[
0
,
slice
(
None
),
slice
(
None
,
1
)])
assert
s_mix1
==
s_mix2
assert
hash
(
s_mix1
)
==
hash
(
s_mix2
)
tests/tensor/test_type_other.py
浏览文件 @
cc6bed1a
...
...
@@ -4,8 +4,30 @@ import pytensor
from
pytensor
import
as_symbolic
from
pytensor.graph.basic
import
Constant
from
pytensor.tensor.math
import
argmax
from
pytensor.tensor.type
import
vector
from
pytensor.tensor.type_other
import
NoneConst
,
NoneTypeT
from
pytensor.tensor.type
import
iscalar
,
vector
from
pytensor.tensor.type_other
import
(
MakeSlice
,
NoneConst
,
NoneTypeT
,
SliceConstant
,
SliceType
,
make_slice
,
)
def test_SliceType():
    """A SliceType instance must equal its own clone."""
    slice_type = SliceType()
    assert slice_type == slice_type.clone()
def test_make_slice_merge():
    """Identical ``make_slice`` applications are merged into a single node.

    In the past this graph crashed during compilation.
    """
    stop = iscalar()
    first = make_slice(0, stop)
    second = make_slice(0, stop)
    fn = pytensor.function([stop], [first, second])
    apply_nodes = fn.maker.fgraph.apply_nodes
    # After the merge optimization only one MakeSlice node should remain.
    make_slice_nodes = [node for node in apply_nodes if isinstance(node.op, MakeSlice)]
    assert len(make_slice_nodes) == 1
def
test_none_Constant
():
...
...
@@ -25,6 +47,8 @@ def test_none_Constant():
# This trigger equals that returned the wrong answer in the past.
import
pickle
import
pytensor
x
=
vector
(
"x"
)
y
=
argmax
(
x
)
kwargs
=
{}
...
...
@@ -36,18 +60,11 @@ def test_none_Constant():
def test_as_symbolic():
    # Remove this when xtensor is not using symbolic slices
    from pytensor.tensor.type import iscalar
    from pytensor.tensor.type_other import SliceConstant, slicetype

    # None maps onto the shared NoneConst singleton.
    converted = as_symbolic(None)
    assert converted is NoneConst

    # A slice with a symbolic bound becomes a make_slice graph.
    converted = as_symbolic(slice(iscalar()))
    assert converted.owner.op == make_slice

    # A slice of plain ints becomes a SliceConstant of slicetype.
    converted = as_symbolic(slice(1, 2))
    assert isinstance(converted, SliceConstant)
    assert converted.type == slicetype
    assert converted.data == slice(1, 2)

    bound = iscalar()
    converted = as_symbolic(slice(bound))
    assert converted.owner is not None
tests/tensor/test_variable.py
浏览文件 @
cc6bed1a
...
...
@@ -35,7 +35,7 @@ from pytensor.tensor.type import (
scalar
,
tensor3
,
)
from
pytensor.tensor.type_other
import
NoneConst
from
pytensor.tensor.type_other
import
MakeSlice
,
NoneConst
from
pytensor.tensor.variable
import
(
DenseTensorConstant
,
DenseTensorVariable
,
...
...
@@ -232,11 +232,11 @@ def test__getitem__AdvancedSubtensor():
z
=
x
[:,
i
]
op_types
=
[
type
(
node
.
op
)
for
node
in
io_toposort
([
x
,
i
],
[
z
])]
assert
op_types
==
[
AdvancedSubtensor
]
assert
op_types
==
[
MakeSlice
,
AdvancedSubtensor
]
z
=
x
[
...
,
i
,
None
]
op_types
=
[
type
(
node
.
op
)
for
node
in
io_toposort
([
x
,
i
],
[
z
])]
assert
op_types
==
[
DimShuffl
e
,
AdvancedSubtensor
]
assert
op_types
==
[
MakeSlic
e
,
AdvancedSubtensor
]
z
=
x
[
i
,
None
]
op_types
=
[
type
(
node
.
op
)
for
node
in
io_toposort
([
x
,
i
],
[
z
])]
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论