Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
cc6bed1a
提交
cc6bed1a
authored
2月 28, 2026
作者:
ricardoV94
提交者:
Ricardo Vieira
3月 01, 2026
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Revert "Refactor AdvancedSubtensor"
This reverts commit
db7fa079
.
上级
03afa5bb
全部展开
显示空白字符变更
内嵌
并排
正在显示
26 个修改的文件
包含
435 行增加
和
231 行删除
+435
-231
destroyhandler.py
pytensor/graph/destroyhandler.py
+2
-2
subtensor.py
pytensor/link/jax/dispatch/subtensor.py
+33
-3
subtensor.py
pytensor/link/mlx/dispatch/subtensor.py
+18
-5
subtensor.py
pytensor/link/numba/dispatch/subtensor.py
+0
-0
subtensor.py
pytensor/link/pytorch/dispatch/subtensor.py
+19
-10
rewriting.py
pytensor/scan/rewriting.py
+24
-4
basic.py
pytensor/tensor/basic.py
+4
-4
basic.py
pytensor/tensor/random/rewriting/basic.py
+14
-6
shape.py
pytensor/tensor/rewriting/shape.py
+3
-5
subtensor.py
pytensor/tensor/rewriting/subtensor.py
+0
-0
subtensor_lift.py
pytensor/tensor/rewriting/subtensor_lift.py
+27
-19
uncanonicalize.py
pytensor/tensor/rewriting/uncanonicalize.py
+51
-33
subtensor.py
pytensor/tensor/subtensor.py
+0
-0
variable.py
pytensor/tensor/variable.py
+72
-47
indexing.py
pytensor/xtensor/rewriting/indexing.py
+3
-4
test_basic.py
tests/graph/rewriting/test_basic.py
+26
-0
test_subtensor.py
tests/link/jax/test_subtensor.py
+0
-31
test_subtensor.py
tests/link/mlx/test_subtensor.py
+21
-0
test_subtensor.py
tests/link/numba/test_subtensor.py
+62
-20
test_elemwise.py
tests/tensor/rewriting/test_elemwise.py
+0
-6
test_subtensor.py
tests/tensor/rewriting/test_subtensor.py
+16
-6
test_subtensor_lift.py
tests/tensor/rewriting/test_subtensor_lift.py
+4
-3
test_blockwise.py
tests/tensor/test_blockwise.py
+4
-8
test_subtensor.py
tests/tensor/test_subtensor.py
+0
-0
test_type_other.py
tests/tensor/test_type_other.py
+29
-12
test_variable.py
tests/tensor/test_variable.py
+3
-3
没有找到文件。
pytensor/graph/destroyhandler.py
浏览文件 @
cc6bed1a
...
...
@@ -771,9 +771,9 @@ class DestroyHandler(Bookkeeper):
}
tolerated
.
add
(
destroyed_idx
)
tolerate_aliased
=
getattr
(
app
.
op
,
"destroyhandler_tolerate_aliased"
,
()
app
.
op
,
"destroyhandler_tolerate_aliased"
,
[]
)
assert
isinstance
(
tolerate_aliased
,
tuple
|
list
)
assert
isinstance
(
tolerate_aliased
,
list
)
ignored
=
{
idx1
for
idx0
,
idx1
in
tolerate_aliased
if
idx0
==
destroyed_idx
}
...
...
pytensor/link/jax/dispatch/subtensor.py
浏览文件 @
cc6bed1a
...
...
@@ -8,6 +8,7 @@ from pytensor.tensor.subtensor import (
Subtensor
,
indices_from_subtensor
,
)
from
pytensor.tensor.type_other
import
MakeSlice
BOOLEAN_MASK_ERROR
=
"""JAX does not support resizing arrays with boolean
...
...
@@ -34,8 +35,10 @@ slice length.
@jax_funcify.register
(
AdvancedSubtensor
)
@jax_funcify.register
(
AdvancedSubtensor1
)
def
jax_funcify_Subtensor
(
op
,
node
,
**
kwargs
):
idx_list
=
getattr
(
op
,
"idx_list"
,
None
)
def
subtensor
(
x
,
*
ilists
):
indices
=
indices_from_subtensor
(
ilists
,
op
.
idx_list
)
indices
=
indices_from_subtensor
(
ilists
,
idx_list
)
if
len
(
indices
)
==
1
:
indices
=
indices
[
0
]
...
...
@@ -45,9 +48,10 @@ def jax_funcify_Subtensor(op, node, **kwargs):
@jax_funcify.register
(
IncSubtensor
)
@jax_funcify.register
(
AdvancedIncSubtensor
)
@jax_funcify.register
(
AdvancedIncSubtensor1
)
def
jax_funcify_IncSubtensor
(
op
,
node
,
**
kwargs
):
idx_list
=
getattr
(
op
,
"idx_list"
,
None
)
if
getattr
(
op
,
"set_instead_of_inc"
,
False
):
def
jax_fn
(
x
,
indices
,
y
):
...
...
@@ -58,7 +62,7 @@ def jax_funcify_IncSubtensor(op, node, **kwargs):
def
jax_fn
(
x
,
indices
,
y
):
return
x
.
at
[
indices
]
.
add
(
y
)
def
incsubtensor
(
x
,
y
,
*
ilist
,
jax_fn
=
jax_fn
,
idx_list
=
op
.
idx_list
):
def
incsubtensor
(
x
,
y
,
*
ilist
,
jax_fn
=
jax_fn
,
idx_list
=
idx_list
):
indices
=
indices_from_subtensor
(
ilist
,
idx_list
)
if
len
(
indices
)
==
1
:
indices
=
indices
[
0
]
...
...
@@ -69,3 +73,29 @@ def jax_funcify_IncSubtensor(op, node, **kwargs):
return
jax_fn
(
x
,
indices
,
y
)
return
incsubtensor
@jax_funcify.register
(
AdvancedIncSubtensor
)
def
jax_funcify_AdvancedIncSubtensor
(
op
,
node
,
**
kwargs
):
if
getattr
(
op
,
"set_instead_of_inc"
,
False
):
def
jax_fn
(
x
,
indices
,
y
):
return
x
.
at
[
indices
]
.
set
(
y
)
else
:
def
jax_fn
(
x
,
indices
,
y
):
return
x
.
at
[
indices
]
.
add
(
y
)
def
advancedincsubtensor
(
x
,
y
,
*
ilist
,
jax_fn
=
jax_fn
):
return
jax_fn
(
x
,
ilist
,
y
)
return
advancedincsubtensor
@jax_funcify.register
(
MakeSlice
)
def
jax_funcify_MakeSlice
(
op
,
**
kwargs
):
def
makeslice
(
*
x
):
return
slice
(
*
x
)
return
makeslice
pytensor/link/mlx/dispatch/subtensor.py
浏览文件 @
cc6bed1a
...
...
@@ -10,14 +10,15 @@ from pytensor.tensor.subtensor import (
Subtensor
,
indices_from_subtensor
,
)
from
pytensor.tensor.type_other
import
MakeSlice
@mlx_funcify.register
(
Subtensor
)
def
mlx_funcify_Subtensor
(
op
,
node
,
**
kwargs
):
idx_list
=
getattr
(
op
,
"idx_list"
,
None
)
def
subtensor
(
x
,
*
ilists
):
indices
=
indices_from_subtensor
(
[
int
(
element
)
for
element
in
ilists
],
op
.
idx_list
)
indices
=
indices_from_subtensor
([
int
(
element
)
for
element
in
ilists
],
idx_list
)
if
len
(
indices
)
==
1
:
indices
=
indices
[
0
]
...
...
@@ -29,8 +30,10 @@ def mlx_funcify_Subtensor(op, node, **kwargs):
@mlx_funcify.register
(
AdvancedSubtensor
)
@mlx_funcify.register
(
AdvancedSubtensor1
)
def
mlx_funcify_AdvancedSubtensor
(
op
,
node
,
**
kwargs
):
idx_list
=
getattr
(
op
,
"idx_list"
,
None
)
def
advanced_subtensor
(
x
,
*
ilists
):
indices
=
indices_from_subtensor
(
ilists
,
op
.
idx_list
)
indices
=
indices_from_subtensor
(
ilists
,
idx_list
)
if
len
(
indices
)
==
1
:
indices
=
indices
[
0
]
...
...
@@ -42,6 +45,8 @@ def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
@mlx_funcify.register
(
IncSubtensor
)
@mlx_funcify.register
(
AdvancedIncSubtensor1
)
def
mlx_funcify_IncSubtensor
(
op
,
node
,
**
kwargs
):
idx_list
=
getattr
(
op
,
"idx_list"
,
None
)
if
getattr
(
op
,
"set_instead_of_inc"
,
False
):
def
mlx_fn
(
x
,
indices
,
y
):
...
...
@@ -58,7 +63,7 @@ def mlx_funcify_IncSubtensor(op, node, **kwargs):
x
[
indices
]
+=
y
return
x
def
incsubtensor
(
x
,
y
,
*
ilist
,
mlx_fn
=
mlx_fn
,
idx_list
=
op
.
idx_list
):
def
incsubtensor
(
x
,
y
,
*
ilist
,
mlx_fn
=
mlx_fn
,
idx_list
=
idx_list
):
indices
=
indices_from_subtensor
(
ilist
,
idx_list
)
if
len
(
indices
)
==
1
:
indices
=
indices
[
0
]
...
...
@@ -90,3 +95,11 @@ def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs):
return
mlx_fn
(
x
,
ilist
,
y
)
return
advancedincsubtensor
@mlx_funcify.register
(
MakeSlice
)
def
mlx_funcify_MakeSlice
(
op
,
**
kwargs
):
def
makeslice
(
*
x
):
return
slice
(
*
x
)
return
makeslice
pytensor/link/numba/dispatch/subtensor.py
浏览文件 @
cc6bed1a
差异被折叠。
点击展开。
pytensor/link/pytorch/dispatch/subtensor.py
浏览文件 @
cc6bed1a
...
...
@@ -9,6 +9,7 @@ from pytensor.tensor.subtensor import (
Subtensor
,
indices_from_subtensor
,
)
from
pytensor.tensor.type_other
import
MakeSlice
,
SliceType
def
check_negative_steps
(
indices
):
...
...
@@ -46,11 +47,23 @@ def pytorch_funcify_Subtensor(op, node, **kwargs):
return
subtensor
@pytorch_funcify.register
(
MakeSlice
)
def
pytorch_funcify_makeslice
(
op
,
**
kwargs
):
def
makeslice
(
start
,
stop
,
step
):
# Torch does not like numpy integers in indexing slices
return
slice
(
None
if
start
is
None
else
int
(
start
),
None
if
stop
is
None
else
int
(
stop
),
None
if
step
is
None
else
int
(
step
),
)
return
makeslice
@pytorch_funcify.register
(
AdvancedSubtensor1
)
@pytorch_funcify.register
(
AdvancedSubtensor
)
def
pytorch_funcify_AdvSubtensor
(
op
,
node
,
**
kwargs
):
def
advsubtensor
(
x
,
*
indices
):
indices
=
indices_from_subtensor
(
indices
,
op
.
idx_list
)
check_negative_steps
(
indices
)
return
x
[
indices
]
...
...
@@ -89,14 +102,12 @@ def pytorch_funcify_IncSubtensor(op, node, **kwargs):
@pytorch_funcify.register
(
AdvancedIncSubtensor
)
@pytorch_funcify.register
(
AdvancedIncSubtensor1
)
def
pytorch_funcify_AdvancedIncSubtensor
(
op
,
node
,
**
kwargs
):
idx_list
=
op
.
idx_list
inplace
=
op
.
inplace
ignore_duplicates
=
getattr
(
op
,
"ignore_duplicates"
,
False
)
if
op
.
set_instead_of_inc
:
def
adv_set_subtensor
(
x
,
y
,
*
flattened_indices
):
indices
=
indices_from_subtensor
(
flattened_indices
,
idx_list
)
def
adv_set_subtensor
(
x
,
y
,
*
indices
):
check_negative_steps
(
indices
)
if
isinstance
(
op
,
AdvancedIncSubtensor1
):
op
.
_check_runtime_broadcasting
(
node
,
x
,
y
,
indices
)
...
...
@@ -109,8 +120,7 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
elif
ignore_duplicates
:
def
adv_inc_subtensor_no_duplicates
(
x
,
y
,
*
flattened_indices
):
indices
=
indices_from_subtensor
(
flattened_indices
,
idx_list
)
def
adv_inc_subtensor_no_duplicates
(
x
,
y
,
*
indices
):
check_negative_steps
(
indices
)
if
isinstance
(
op
,
AdvancedIncSubtensor1
):
op
.
_check_runtime_broadcasting
(
node
,
x
,
y
,
indices
)
...
...
@@ -122,14 +132,13 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
return
adv_inc_subtensor_no_duplicates
else
:
if
any
(
isinstance
(
entry
,
slice
)
for
entry
in
idx_list
):
if
any
(
isinstance
(
idx
.
type
,
SliceType
)
for
idx
in
node
.
inputs
[
2
:]
):
raise
NotImplementedError
(
"IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"
)
def
adv_inc_subtensor
(
x
,
y
,
*
flattened_indices
):
indices
=
indices_from_subtensor
(
flattened_indices
,
idx_list
)
# Not needed because slices aren't supported in this path
def
adv_inc_subtensor
(
x
,
y
,
*
indices
):
# Not needed because slices aren't supported
# check_negative_steps(indices)
if
not
inplace
:
x
=
x
.
clone
()
...
...
pytensor/scan/rewriting.py
浏览文件 @
cc6bed1a
...
...
@@ -72,9 +72,9 @@ from pytensor.tensor.shape import shape
from
pytensor.tensor.subtensor
import
(
IncSubtensor
,
Subtensor
,
basic_subtensor
,
get_canonical_form_slice
,
get_idx_list
,
get_slice_elements
,
set_subtensor
,
)
from
pytensor.tensor.variable
import
TensorConstant
,
TensorVariable
...
...
@@ -1211,7 +1211,7 @@ def _is_default_scan_buffer(final_buffer: TensorVariable, taps: int) -> bool:
if
not
(
isinstance
(
op
,
IncSubtensor
)
and
op
.
set_instead_of_inc
and
op
.
idx_list
==
(
slice
(
None
,
0
),)
and
op
.
idx_list
==
[
slice
(
None
,
ps
.
int64
)]
):
return
False
...
...
@@ -1389,6 +1389,12 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
else
:
# 2.3.1 extract idx list of subtensor
this_slice
=
get_idx_list
(
cl
.
inputs
,
cl
.
op
.
idx_list
)
if
this_slice
is
None
:
# if unable to extract idx_list
# => outputs needs all its intermediate values
global_nsteps
=
None
slices
[
i
]
=
None
break
# 2.3.2 extract the begin/end of the first dimension
if
i
>=
op_info
.
n_mit_mot
:
...
...
@@ -1481,6 +1487,9 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
break
else
:
this_slice
=
get_idx_list
(
cl
.
inputs
,
cl
.
op
.
idx_list
)
if
this_slice
is
None
:
store_steps
[
i
]
=
0
break
if
isinstance
(
this_slice
[
0
],
slice
):
start
=
this_slice
[
0
]
.
start
...
...
@@ -1702,9 +1711,16 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
)
else
:
fslice
=
sanitize
(
cnf_slice
[
0
])
nw_slice
=
(
fslice
,
*
old_slices
[
1
:])
nw_pos
=
inv_compress_map
[
idx
]
new_o
=
basic_subtensor
(
new_outs
[
nw_pos
],
fslice
,
*
old_slices
[
1
:])
subtens
=
Subtensor
(
nw_slice
)
# slice inputs
sl_ins
=
get_slice_elements
(
nw_slice
,
lambda
entry
:
isinstance
(
entry
,
Variable
)
)
new_o
=
cast
(
TensorVariable
,
subtens
(
new_outs
[
nw_pos
],
*
sl_ins
))
if
new_o
.
ndim
>
0
:
new_o
=
new_o
[::
cnf_slice
[
1
]]
replaced_outs
.
append
(
idx
)
...
...
@@ -1755,7 +1771,11 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
)
nw_slice
=
(
sanitize
(
position
),
*
old_slices
[
1
:])
new_o
=
basic_subtensor
(
new_outs
[
nw_pos
],
*
nw_slice
)
subtens
=
Subtensor
(
nw_slice
)
sl_ins
=
get_slice_elements
(
nw_slice
,
lambda
entry
:
isinstance
(
entry
,
Variable
)
)
new_o
=
cast
(
TensorVariable
,
subtens
(
new_outs
[
nw_pos
],
*
sl_ins
))
if
new_o
.
ndim
>
0
:
new_o
=
new_o
[::
cnf_slice
[
1
]]
old_new
+=
[(
old
,
new_o
)]
...
...
pytensor/tensor/basic.py
浏览文件 @
cc6bed1a
...
...
@@ -29,7 +29,7 @@ from pytensor.graph.fg import FunctionGraph, Output
from
pytensor.graph.op
import
Op
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.graph.rewriting.db
import
EquilibriumDB
from
pytensor.graph.type
import
HasShape
from
pytensor.graph.type
import
HasShape
,
Type
from
pytensor.link.c.op
import
COp
from
pytensor.link.c.params_type
import
ParamsType
from
pytensor.printing
import
Printer
,
min_informative_str
,
pprint
,
set_precedence
...
...
@@ -433,7 +433,7 @@ def _get_underlying_scalar_constant_value(
var
.
ndim
==
1
for
var
in
v
.
owner
.
inputs
[
0
]
.
owner
.
inputs
[
1
:]
):
idx
=
v
.
owner
.
op
.
idx_list
[
0
]
if
isinstance
(
idx
,
int
):
if
isinstance
(
idx
,
Type
):
idx
=
_get_underlying_scalar_constant_value
(
v
.
owner
.
inputs
[
1
],
max_recur
=
max_recur
)
...
...
@@ -467,7 +467,7 @@ def _get_underlying_scalar_constant_value(
and
len
(
v
.
owner
.
op
.
idx_list
)
==
1
):
idx
=
v
.
owner
.
op
.
idx_list
[
0
]
if
isinstance
(
idx
,
int
):
if
isinstance
(
idx
,
Type
):
idx
=
_get_underlying_scalar_constant_value
(
v
.
owner
.
inputs
[
1
],
max_recur
=
max_recur
)
...
...
@@ -488,7 +488,7 @@ def _get_underlying_scalar_constant_value(
op
=
owner
.
op
idx_list
=
op
.
idx_list
idx
=
idx_list
[
0
]
if
isinstance
(
idx
,
int
):
if
isinstance
(
idx
,
Type
):
idx
=
_get_underlying_scalar_constant_value
(
owner
.
inputs
[
1
],
max_recur
=
max_recur
)
...
...
pytensor/tensor/random/rewriting/basic.py
浏览文件 @
cc6bed1a
...
...
@@ -23,7 +23,7 @@ from pytensor.tensor.subtensor import (
indices_from_subtensor
,
)
from
pytensor.tensor.type
import
integer_dtypes
from
pytensor.tensor.type_other
import
NoneTypeT
from
pytensor.tensor.type_other
import
NoneTypeT
,
SliceType
def
is_rv_used_in_graph
(
base_rv
,
node
,
fgraph
):
...
...
@@ -237,14 +237,19 @@ def local_subtensor_rv_lift(fgraph, node):
return
False
# Parse indices
if
isinstance
(
subtensor_op
,
Subtensor
|
AdvancedSubtensor
):
if
isinstance
(
subtensor_op
,
Subtensor
):
indices
=
indices_from_subtensor
(
node
.
inputs
[
1
:],
subtensor_op
.
idx_list
)
else
:
indices
=
node
.
inputs
[
1
:]
# The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
# Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
# If we wanted to support that we could rewrite it as subtensor + dimshuffle
# and make use of the dimshuffle lift rewrite
# TODO: This rewrite is aborting with dummy indexing dimensions which aren't a problem
# (e.g., x[[0],] is equivalent to x[0] - can only index one entry, won't lead to duplicates)
if
any
(
is_nd_advanced_idx
(
idx
,
integer_dtypes
)
for
idx
in
indices
):
if
any
(
is_nd_advanced_idx
(
idx
,
integer_dtypes
)
or
isinstance
(
idx
.
type
,
NoneTypeT
)
for
idx
in
indices
):
return
False
# Check that indexing does not act on support dims
...
...
@@ -263,7 +268,10 @@ def local_subtensor_rv_lift(fgraph, node):
non_bool_indices
[
batch_ndims
:],
)
for
idx
in
supp_indices
:
if
idx
!=
slice
(
None
):
if
not
(
isinstance
(
idx
.
type
,
SliceType
)
and
all
(
isinstance
(
i
.
type
,
NoneTypeT
)
for
i
in
idx
.
owner
.
inputs
)
):
return
False
n_discarded_idxs
=
len
(
supp_indices
)
indices
=
indices
[:
-
n_discarded_idxs
]
...
...
@@ -323,7 +331,7 @@ def local_subtensor_rv_lift(fgraph, node):
# Broadcasted dim
if
curr_dim
in
bcast_param_dims
:
# Slice indexing, keep degenerate dim by none-slicing
if
isinstance
(
idx
,
slice
):
if
isinstance
(
idx
,
slice
)
or
isinstance
(
idx
.
type
,
SliceType
)
:
batch_indices
.
append
(
slice
(
None
))
# Integer indexing, drop degenerate dim by 0-indexing
else
:
...
...
pytensor/tensor/rewriting/shape.py
浏览文件 @
cc6bed1a
...
...
@@ -17,6 +17,7 @@ from pytensor.graph.rewriting.basic import (
)
from
pytensor.graph.traversal
import
ancestors
from
pytensor.graph.utils
import
InconsistencyError
,
get_variable_trace_string
from
pytensor.scalar
import
ScalarType
from
pytensor.tensor.basic
import
(
MakeVector
,
as_tensor_variable
,
...
...
@@ -841,16 +842,13 @@ def _is_shape_i_of_x(
if
isinstance
(
var
.
owner
.
op
,
Shape_i
):
return
(
var
.
owner
.
op
.
i
==
i
)
and
(
var
.
owner
.
inputs
[
0
]
==
x
)
# type: ignore
# Match Subtensor((
int,))(Shape(input), i) - single integer index into shape
# Match Subtensor((
ScalarType,))(Shape(input), i)
if
isinstance
(
var
.
owner
.
op
,
Subtensor
):
idx_entry
=
(
var
.
owner
.
op
.
idx_list
[
0
]
if
len
(
var
.
owner
.
op
.
idx_list
)
==
1
else
None
)
return
(
# Check we have integer indexing operation
# (and not slice or multiple indexing)
len
(
var
.
owner
.
op
.
idx_list
)
==
1
and
isinstance
(
idx_entry
,
int
)
and
isinstance
(
var
.
owner
.
op
.
idx_list
[
0
],
ScalarType
)
# Check we are indexing on the shape of x
and
var
.
owner
.
inputs
[
0
]
.
owner
is
not
None
and
isinstance
(
var
.
owner
.
inputs
[
0
]
.
owner
.
op
,
Shape
)
...
...
pytensor/tensor/rewriting/subtensor.py
浏览文件 @
cc6bed1a
差异被折叠。
点击展开。
pytensor/tensor/rewriting/subtensor_lift.py
浏览文件 @
cc6bed1a
...
...
@@ -8,6 +8,7 @@ from pytensor import Variable
from
pytensor.compile
import
optdb
from
pytensor.graph
import
Constant
,
FunctionGraph
,
node_rewriter
,
vectorize_graph
from
pytensor.graph.rewriting.basic
import
NodeRewriter
,
copy_stack_trace
from
pytensor.scalar
import
basic
as
ps
from
pytensor.tensor.basic
import
(
Alloc
,
Join
,
...
...
@@ -30,7 +31,7 @@ from pytensor.tensor.rewriting.basic import (
register_stabilize
,
)
from
pytensor.tensor.rewriting.elemwise
import
local_dimshuffle_lift
from
pytensor.tensor.rewriting.subtensor
import
register_useless
from
pytensor.tensor.rewriting.subtensor
import
is_full_slice
,
register_useless
from
pytensor.tensor.shape
import
(
Shape
,
SpecifyShape
,
...
...
@@ -49,6 +50,7 @@ from pytensor.tensor.subtensor import (
indices_from_subtensor
,
)
from
pytensor.tensor.type
import
TensorType
from
pytensor.tensor.type_other
import
NoneTypeT
,
SliceType
from
pytensor.tensor.variable
import
TensorVariable
...
...
@@ -69,7 +71,7 @@ def _axis_is_indexed_by_basic_index(
)
->
bool
:
if
isinstance
(
axis
,
int
):
axis
=
(
axis
,)
return
any
(
ax
<
len
(
idxs
)
and
not
i
dxs
[
ax
]
==
slice
(
None
)
for
ax
in
axis
)
return
any
(
ax
<
len
(
idxs
)
and
not
i
s_full_slice
(
idxs
[
ax
]
)
for
ax
in
axis
)
def
_lift_subtensor_non_axis
(
...
...
@@ -81,7 +83,7 @@ def _lift_subtensor_non_axis(
old_subtensor_variable
:
TensorVariable
,
)
->
None
|
list
[
TensorVariable
]:
# Apply generic subtensor lift rewrite along "non-axis" dimensions
real_indices
=
[
idx
for
idx
in
idx_tuple
if
not
i
dx
==
slice
(
None
)]
real_indices
=
[
idx
for
idx
in
idx_tuple
if
not
i
s_full_slice
(
idx
)]
if
len
(
real_indices
)
>
1
and
variable
.
type
.
ndim
>
1
:
# Split the subtensor
idx_to_keep
=
idx_tuple
[
axis
]
...
...
@@ -204,7 +206,7 @@ def local_subtensor_of_batch_dims(fgraph, node):
if
len
(
idx_tuple
)
>
batch_ndim
:
# Indexing on core dimensions of Blockwise. We split the indices and lift the batch ones only
batch_indices
,
core_indices
=
idx_tuple
[:
batch_ndim
],
idx_tuple
[
batch_ndim
:]
if
all
(
i
dx
==
slice
(
None
)
for
idx
in
batch_indices
):
if
all
(
i
s_full_slice
(
idx
)
for
idx
in
batch_indices
):
# No batch indices, nothing to do
return
None
elem_with_batch_indices
=
elem
[
batch_indices
]
...
...
@@ -238,7 +240,7 @@ def local_subtensor_of_batch_dims(fgraph, node):
strict
=
False
,
)
):
if
dim_idx
==
slice
(
None
):
if
is_full_slice
(
dim_idx
):
# Full slice can be safely applied to all inputs
continue
...
...
@@ -427,7 +429,7 @@ def local_subtensor_of_expand_dims(fgraph, node):
if
i
in
expanded_axes
:
if
isinstance
(
idx_item
,
slice
):
# Slice could be keeping or dropping this dimension
if
i
dx_item
==
slice
(
None
):
if
i
s_full_slice
(
idx_item
):
# A None slice, always keeps the dimension.
# We skip the index, and later introduce the needed expand_dim
continue
...
...
@@ -646,7 +648,10 @@ def local_subtensor_SpecifyShape_lift(fgraph, node):
indices
=
get_idx_list
(
node
.
inputs
,
node
.
op
.
idx_list
)
if
any
(
isinstance
(
index
,
slice
)
for
index
in
indices
):
if
any
(
isinstance
(
index
,
slice
)
or
isinstance
(
getattr
(
index
,
"type"
,
None
),
SliceType
)
for
index
in
indices
):
return
False
new_obj_arg
=
obj_arg
[
indices
]
...
...
@@ -697,12 +702,15 @@ def local_subtensor_make_vector(fgraph, node):
(
idx
,)
=
idxs
if
isinstance
(
idx
,
int
):
idx
=
node
.
inputs
[
1
]
if
isinstance
(
idx
,
ps
.
ScalarType
|
TensorType
):
old_idx
,
idx
=
idx
,
node
.
inputs
[
1
]
assert
idx
.
type
.
is_super
(
old_idx
)
elif
isinstance
(
node
.
op
,
AdvancedSubtensor1
):
idx
=
node
.
inputs
[
1
]
if
isinstance
(
idx
,
Variable
):
if
isinstance
(
idx
,
int
|
np
.
integer
):
return
[
x
.
owner
.
inputs
[
idx
]]
elif
isinstance
(
idx
,
Variable
):
if
idx
.
ndim
==
0
:
try
:
v
=
get_underlying_scalar_constant_value
(
...
...
@@ -825,6 +833,8 @@ def local_subtensor_shape_constant(fgraph, node):
except
NotScalarConstantError
:
return
False
assert
idx_val
!=
np
.
newaxis
if
not
isinstance
(
shape_arg
.
type
,
TensorType
):
return
False
...
...
@@ -861,24 +871,22 @@ def local_subtensor_of_adv_subtensor(fgraph, node):
# AdvancedSubtensor involves a full_copy, so we don't want to do it twice
return
None
x
,
*
adv_index_vars
=
adv_subtensor
.
owner
.
inputs
adv_idxs
=
indices_from_subtensor
(
adv_index_vars
,
adv_subtensor
.
owner
.
op
.
idx_list
)
x
,
*
adv_idxs
=
adv_subtensor
.
owner
.
inputs
# Advanced indexing is a minefield, avoid all cases except for consecutive integer indices
if
(
not
all
(
if
any
(
(
(
isinstance
(
adv_idx
,
TensorVariable
)
and
adv_idx
.
type
.
dtype
!=
"bool"
)
or
(
isinstance
(
adv_idx
,
slice
)
and
adv_idx
==
slice
(
None
))
isinstance
(
adv_idx
.
type
,
NoneTypeT
)
or
(
isinstance
(
adv_idx
.
type
,
TensorType
)
and
adv_idx
.
type
.
dtype
==
"bool"
)
or
(
isinstance
(
adv_idx
.
type
,
SliceType
)
and
not
is_full_slice
(
adv_idx
))
)
for
adv_idx
in
adv_idxs
)
)
or
_non_consecutive_adv_indexing
(
adv_idxs
):
return
None
for
first_adv_idx_dim
,
adv_idx
in
enumerate
(
adv_idxs
):
# We already made sure there were only None slices besides integer indexes
if
isinstance
(
adv_idx
,
TensorVariabl
e
):
if
isinstance
(
adv_idx
.
type
,
TensorTyp
e
):
break
else
:
# no-break
# Not sure if this should ever happen, but better safe than sorry
...
...
@@ -901,7 +909,7 @@ def local_subtensor_of_adv_subtensor(fgraph, node):
copy_stack_trace
([
basic_subtensor
,
adv_subtensor
],
x_indexed
)
x_after_index_lift
=
expand_dims
(
x_indexed
,
dropped_dims
)
x_after_adv_idx
=
adv_subtensor
.
owner
.
op
(
x_after_index_lift
,
*
adv_i
ndex_var
s
)
x_after_adv_idx
=
adv_subtensor
.
owner
.
op
(
x_after_index_lift
,
*
adv_i
dx
s
)
copy_stack_trace
([
basic_subtensor
,
adv_subtensor
],
x_after_adv_idx
)
new_out
=
squeeze
(
x_after_adv_idx
[
basic_idxs_kept
],
dropped_dims
)
...
...
pytensor/tensor/rewriting/uncanonicalize.py
浏览文件 @
cc6bed1a
...
...
@@ -38,7 +38,7 @@ from pytensor.tensor.elemwise import CAReduce, DimShuffle
from
pytensor.tensor.math
import
Min
,
neg
from
pytensor.tensor.rewriting.basic
import
register_uncanonicalize
from
pytensor.tensor.shape
import
Reshape
,
reshape
from
pytensor.tensor.subtensor
import
Subtensor
,
indices_from_subtensor
from
pytensor.tensor.subtensor
import
Subtensor
@register_uncanonicalize
...
...
@@ -193,42 +193,60 @@ def local_dimshuffle_subtensor(fgraph, node):
if
not
all
(
broadcastable
[
i
]
for
i
in
missing_dims
):
return
False
# create a new index tuple for a new Subtensor
# Reconstruct the full indices from the subtensor node, then replace
# dimensions that are being dropped by dimshuffle with scalar index 0
x
=
input_
.
owner
.
inputs
[
0
]
indices
=
list
(
indices_from_subtensor
(
input_
.
owner
.
inputs
[
1
:],
input_
.
owner
.
op
.
idx_list
)
)
# create a new idx_list for a new Subtensor object
# have to loop on idx_list and inputs
# inputs has the length of sum of non None elements of idx_list
# (check in slice!).
# len(missing_dims) can be < len(idx_list), this happens if
# tensor was indexed such as x[scalar, :, :], check that as well
new_idx_list
=
list
(
input_
.
owner
.
op
.
idx_list
)
new_inputs
=
[
input_
.
owner
.
inputs
[
0
]]
zero
=
constant
(
0
)
# Track which output dimension each index corresponds to
# Scalar indices remove dimensions, slices keep them
output_dim
=
0
for
i
,
idx
in
enumerate
(
indices
):
j
=
0
slice_i
=
-
1
subtensor_removed_dims
=
0
for
i
,
idx
in
enumerate
(
input_
.
owner
.
op
.
idx_list
):
if
isinstance
(
idx
,
slice
):
# This slice produces an output dimension
if
output_dim
in
missing_dims
:
#
This output dimension is being dropped, so replace slice with scalar
slice_i
+=
1
if
slice_i
in
missing_dims
:
#
Missing dim is a slice(None), remove by indexing by 0
if
idx
==
slice
(
None
):
indices
[
i
]
=
zero
new_idx_list
[
i
]
=
zero
new_inputs
+=
[
zero
]
# Missing dim is an ordinary slice with known output dim length of 1
# Remove by indexing by start
else
:
# Use the start of the slice (or 0 if None)
indices
[
i
]
=
idx
.
start
if
idx
.
start
is
not
None
else
zero
output_dim
+=
1
# Scalar indices don't contribute to output dimensions
# Handle trailing dimensions that weren't explicitly indexed
for
input_dim
in
range
(
len
(
indices
),
x
.
ndim
):
if
output_dim
in
missing_dims
:
# This unindexed dimension is being dropped, index with 0
indices
.
append
(
zero
)
if
idx
.
start
is
None
:
start
=
zero
else
:
# This unindexed dimension is kept, index with slice(None)
indices
.
append
(
slice
(
None
))
output_dim
+=
1
start
=
input_
.
owner
.
inputs
[
1
+
j
]
j
+=
1
new_idx_list
[
i
]
=
start
new_inputs
+=
[
start
]
return
[
x
[
tuple
(
indices
)]]
# Ignore useless stop and step input if there is one
for
slice_attr
in
(
"stop"
,
"step"
):
if
getattr
(
idx
,
slice_attr
)
is
not
None
:
j
+=
1
# Keep non-dropped slice inputs
else
:
for
slice_attr
in
(
"start"
,
"stop"
,
"step"
):
if
getattr
(
idx
,
slice_attr
)
is
not
None
:
new_inputs
+=
[
input_
.
owner
.
inputs
[
1
+
j
]]
j
+=
1
# Keep non-dropped non-slice inputs
else
:
new_inputs
+=
[
input_
.
owner
.
inputs
[
1
+
j
]]
j
+=
1
subtensor_removed_dims
+=
1
# Verify the trailing dimensions the subtensor didn't look at.
for
idx
in
range
(
len
(
input_
.
owner
.
op
.
idx_list
),
new_inputs
[
0
]
.
ndim
):
if
(
idx
-
subtensor_removed_dims
)
in
missing_dims
:
while
len
(
new_idx_list
)
<
idx
:
new_idx_list
.
append
(
slice
(
None
))
new_idx_list
.
append
(
zero
)
new_inputs
.
append
(
zero
)
return
[
Subtensor
(
new_idx_list
)(
*
new_inputs
)]
return
False
pytensor/tensor/subtensor.py
浏览文件 @
cc6bed1a
差异被折叠。
点击展开。
pytensor/tensor/variable.py
浏览文件 @
cc6bed1a
...
...
@@ -15,8 +15,9 @@ from pytensor.scalar import (
ComplexError
,
)
from
pytensor.tensor
import
_get_vector_length
from
pytensor.tensor.exceptions
import
AdvancedIndexingError
from
pytensor.tensor.type
import
TensorType
from
pytensor.tensor.type_other
import
None
TypeT
from
pytensor.tensor.type_other
import
None
Const
from
pytensor.tensor.utils
import
hash_from_ndarray
...
...
@@ -454,14 +455,15 @@ class _tensor_py_operators:
elif
not
isinstance
(
args
,
tuple
):
args
=
(
args
,)
# Count the dimensions, check for bools and find ellipses.
ellipses
=
[]
index_dim_count
=
0
for
i
,
arg
in
enumerate
(
args
):
if
arg
is
None
or
(
isinstance
(
arg
,
Variable
)
and
isinstance
(
arg
.
type
,
NoneTypeT
)
):
if
arg
is
np
.
newaxis
or
arg
is
NoneConst
:
# no increase in index_dim_count
pass
elif
arg
is
Ellipsis
:
# no increase in index_dim_count
ellipses
.
append
(
i
)
elif
(
isinstance
(
arg
,
np
.
ndarray
|
Variable
)
...
...
@@ -503,41 +505,6 @@ class _tensor_py_operators:
self
.
ndim
-
index_dim_count
)
if
any
(
arg
is
None
or
(
isinstance
(
arg
,
Variable
)
and
isinstance
(
arg
.
type
,
NoneTypeT
))
for
arg
in
args
):
expansion_axes
=
[]
new_args
=
[]
# Track dims consumed by args and inserted `None`s after ellipsis
counter
=
0
nones
=
0
for
arg
in
args
:
if
arg
is
None
or
(
isinstance
(
arg
,
Variable
)
and
isinstance
(
arg
.
type
,
NoneTypeT
)
):
expansion_axes
.
append
(
counter
+
nones
)
# Expand here
nones
+=
1
new_args
.
append
(
slice
(
None
))
else
:
new_args
.
append
(
arg
)
consumed
=
1
if
hasattr
(
arg
,
"dtype"
)
and
arg
.
dtype
==
"bool"
:
consumed
=
arg
.
ndim
counter
+=
consumed
expanded
=
pt
.
expand_dims
(
self
,
expansion_axes
)
if
all
(
isinstance
(
arg
,
slice
)
and
arg
.
start
is
None
and
arg
.
stop
is
None
and
arg
.
step
is
None
for
arg
in
new_args
):
return
expanded
return
expanded
[
tuple
(
new_args
)]
def
is_empty_array
(
val
):
return
(
isinstance
(
val
,
tuple
|
list
)
and
len
(
val
)
==
0
)
or
(
isinstance
(
val
,
np
.
ndarray
)
and
val
.
size
==
0
...
...
@@ -553,16 +520,74 @@ class _tensor_py_operators:
for
inp
in
args
)
if
all
(
(
isinstance
(
arg
,
slice
|
int
|
float
|
np
.
number
)
or
(
hasattr
(
arg
,
"ndim"
)
and
arg
.
ndim
==
0
and
arg
.
dtype
!=
"bool"
)
)
for
arg
in
args
):
return
pt
.
subtensor
.
basic_subtensor
(
self
,
*
args
)
# Determine if advanced indexing is needed or not. The logic is
# already in `index_vars_to_types`: if it succeeds, standard indexing is
# used; if it fails with `AdvancedIndexingError`, advanced indexing is
# used
advanced
=
False
for
i
,
arg
in
enumerate
(
args
):
if
includes_bool
(
arg
):
advanced
=
True
break
if
arg
is
not
np
.
newaxis
and
arg
is
not
NoneConst
:
try
:
pt
.
subtensor
.
index_vars_to_types
(
arg
)
except
AdvancedIndexingError
:
if
advanced
:
break
else
:
advanced
=
True
if
advanced
:
return
pt
.
subtensor
.
advanced_subtensor
(
self
,
*
args
)
else
:
if
np
.
newaxis
in
args
or
NoneConst
in
args
:
# `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new
# broadcastable dimension at this location". Since PyTensor adds
# new broadcastable dimensions via the `DimShuffle` `Op`, the
# following code uses said `Op` to add one of the new axes and
# then uses recursion to apply any other indices and add any
# remaining new axes.
counter
=
0
pattern
=
[]
new_args
=
[]
for
arg
in
args
:
if
arg
is
np
.
newaxis
or
arg
is
NoneConst
:
pattern
.
append
(
"x"
)
new_args
.
append
(
slice
(
None
,
None
,
None
))
else
:
pattern
.
append
(
counter
)
counter
+=
1
new_args
.
append
(
arg
)
pattern
.
extend
(
list
(
range
(
counter
,
self
.
ndim
)))
view
=
self
.
dimshuffle
(
pattern
)
full_slices
=
True
for
arg
in
new_args
:
# We can't do arg == slice(None, None, None) as in
# Python 2.7, this call __lt__ if we have a slice
# with some symbolic variable.
if
not
(
isinstance
(
arg
,
slice
)
and
(
arg
.
start
is
None
or
arg
.
start
is
NoneConst
)
and
(
arg
.
stop
is
None
or
arg
.
stop
is
NoneConst
)
and
(
arg
.
step
is
None
or
arg
.
step
is
NoneConst
)
):
full_slices
=
False
if
full_slices
:
return
view
else
:
return
view
.
__getitem__
(
tuple
(
new_args
))
else
:
return
pt
.
subtensor
.
Subtensor
(
args
)(
self
,
*
pt
.
subtensor
.
get_slice_elements
(
args
,
lambda
entry
:
isinstance
(
entry
,
Variable
)
),
)
def
__setitem__
(
self
,
key
,
value
):
raise
TypeError
(
...
...
pytensor/xtensor/rewriting/indexing.py
浏览文件 @
cc6bed1a
...
...
@@ -2,10 +2,9 @@ from itertools import zip_longest
from
pytensor
import
as_symbolic
from
pytensor.graph
import
Constant
,
node_rewriter
from
pytensor.tensor
import
arange
,
specify_shape
from
pytensor.tensor
import
TensorType
,
arange
,
specify_shape
from
pytensor.tensor.subtensor
import
_non_consecutive_adv_indexing
,
inc_subtensor
from
pytensor.tensor.type_other
import
NoneTypeT
,
SliceType
from
pytensor.tensor.variable
import
TensorVariable
from
pytensor.xtensor.basic
import
tensor_from_xtensor
,
xtensor_from_tensor
from
pytensor.xtensor.indexing
import
Index
,
IndexUpdate
,
index
from
pytensor.xtensor.rewriting.utils
import
register_lower_xtensor
...
...
@@ -107,7 +106,7 @@ def _lower_index(node):
# We can use basic indexing directly if no other index acts on this dimension
# This is an optimization that avoids creating an unnecessary arange tensor
# and facilitates the use of the specialized AdvancedSubtensor1 when possible
aligned_idxs
.
append
(
to_basic_idx
(
idx
)
)
aligned_idxs
.
append
(
idx
)
basic_idx_axis
.
append
(
out_dims
.
index
(
x_dim
))
else
:
# Otherwise we need to convert the basic index into an equivalent advanced indexing
...
...
@@ -132,7 +131,7 @@ def _lower_index(node):
if
basic_idx_axis
:
aligned_idxs
=
[
idx
.
squeeze
(
axis
=
basic_idx_axis
)
if
(
isinstance
(
idx
,
TensorVariabl
e
)
and
idx
.
type
.
ndim
>
0
)
if
(
isinstance
(
idx
.
type
,
TensorTyp
e
)
and
idx
.
type
.
ndim
>
0
)
else
idx
for
idx
in
aligned_idxs
]
...
...
tests/graph/rewriting/test_basic.py
浏览文件 @
cc6bed1a
...
...
@@ -26,7 +26,9 @@ from pytensor.graph.rewriting.unify import LiteralString, OpPattern
from
pytensor.raise_op
import
assert_op
from
pytensor.tensor.math
import
Dot
,
add
,
dot
,
exp
from
pytensor.tensor.rewriting.basic
import
constant_folding
from
pytensor.tensor.subtensor
import
AdvancedSubtensor
from
pytensor.tensor.type
import
matrix
,
values_eq_approx_always_true
,
vector
from
pytensor.tensor.type_other
import
MakeSlice
,
SliceConstant
,
slicetype
from
tests.graph.utils
import
(
MyOp
,
MyType
,
...
...
@@ -627,6 +629,21 @@ def test_pre_constant_merge():
assert
res
==
[
o2
]
assert
o2
.
owner
.
inputs
[
2
]
is
c2
# What is this supposed to test?
ms
=
MakeSlice
()(
1
)
res
=
pre_constant_merge
(
empty_fgraph
,
[
ms
])
assert
res
==
[
ms
]
const_slice
=
SliceConstant
(
type
=
slicetype
,
data
=
slice
(
1
,
None
,
2
))
assert
isinstance
(
const_slice
,
Constant
)
adv
=
AdvancedSubtensor
()(
matrix
(),
[
2
,
3
],
const_slice
)
res
=
pre_constant_merge
(
empty_fgraph
,
adv
)
assert
res
==
[
adv
]
def
test_pre_greedy_node_rewriter
():
empty_fgraph
=
FunctionGraph
([],
[])
...
...
@@ -662,6 +679,15 @@ def test_pre_greedy_node_rewriter():
assert
cst
.
owner
.
inputs
[
0
]
is
o1
assert
cst
.
owner
.
inputs
[
4
]
is
cst
.
owner
.
inputs
[
0
]
# What exactly is this supposed to test?
ms
=
MakeSlice
()(
1
)
cst
=
pre_greedy_node_rewriter
(
empty_fgraph
,
[
constant_folding
],
ms
)
assert
isinstance
(
cst
,
SliceConstant
)
# Make sure constant of slice signature is hashable.
assert
isinstance
(
hash
(
cst
.
signature
()),
int
)
@pytest.mark.parametrize
(
"tracks"
,
[
True
,
False
])
@pytest.mark.parametrize
(
"out_pattern"
,
[(
op2
,
"x"
),
"x"
,
1.0
])
...
...
tests/link/jax/test_subtensor.py
浏览文件 @
cc6bed1a
...
...
@@ -225,37 +225,6 @@ def test_jax_IncSubtensor():
compare_jax_and_py
([],
[
out_pt
],
[])
@pytest.mark.parametrize(
    "func", (pt_subtensor.advanced_inc_subtensor1, pt_subtensor.advanced_set_subtensor1)
)
def test_jax_AdvancedIncSubtensor1_runtime_broadcast(func):
    """Test that JAX backend checks for runtime broadcasting in AdvancedIncSubtensor1.

    JAX silently broadcasts when using .at[].set() or .at[].add(), but PyTensor
    requires explicit broadcastable dimensions. This test ensures we raise the same
    error as the Python/C backend when runtime broadcasting would occur.
    """
    from pytensor import function

    y = pt.matrix("y", dtype="float64", shape=(None, None))
    x = pt.zeros((10, 5))
    idxs = np.repeat(np.arange(10), 2)  # 20 indices
    out = func(x, y, idxs)
    f = function([y], out, mode="JAX")

    # Should work with correctly sized y
    f(np.ones((20, 5)))

    # Should raise for runtime broadcasting on first dimension
    with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
        f(np.ones((1, 5)))

    # Should raise for runtime broadcasting on second dimension
    with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
        f(np.ones((20, 1)))
def
test_jax_IncSubtensor_boolean_indexing_reexpressible
():
"""Setting or incrementing values with boolean indexing.
...
...
tests/link/mlx/test_subtensor.py
浏览文件 @
cc6bed1a
...
...
@@ -187,6 +187,27 @@ def test_mlx_inplace_variants():
compare_mlx_and_py
([],
[
out_pt
],
[])
@pytest.mark.xfail(
    reason="MLX slice indices must be integers or None, dynamic slices not supported"
)
def test_mlx_MakeSlice():
    """Index a constant vector with a symbolically-built slice (MakeSlice)."""
    # Symbolic scalar bounds for the slice, in positional order.
    bounds = [pt.iscalar(name) for name in ("start", "stop", "step")]

    # Build the symbolic slice node and use it to index a constant array.
    symbolic_slice = pt_subtensor.MakeSlice()(*bounds)
    data = pt.constant(np.arange(10, dtype=np.float32))
    sliced = data[symbolic_slice]

    compare_mlx_and_py(bounds, [sliced], [1, 8, 2])
def
test_mlx_subtensor_edge_cases
():
"""Test edge cases and boundary conditions."""
# Empty slices - use constant array
...
...
tests/link/numba/test_subtensor.py
浏览文件 @
cc6bed1a
...
...
@@ -3,7 +3,9 @@ import contextlib
import
numpy
as
np
import
pytest
import
pytensor.scalar
as
ps
import
pytensor.tensor
as
pt
from
pytensor
import
Mode
,
as_symbolic
from
pytensor.tensor
import
as_tensor
from
pytensor.tensor.subtensor
import
(
AdvancedIncSubtensor
,
...
...
@@ -18,16 +20,51 @@ from pytensor.tensor.subtensor import (
inc_subtensor
,
set_subtensor
,
)
from
tests.link.numba.test_basic
import
(
compare_numba_and_py
,
numba_inplace_mode
,
numba_mode
,
)
from
tests.link.numba.test_basic
import
compare_numba_and_py
,
numba_mode
rng
=
np
.
random
.
default_rng
(
sum
(
map
(
ord
,
"Numba subtensors"
)))
@pytest.mark.parametrize("step", [None, 1, 2, -2, "x"], ids=lambda x: f"step={x}")
@pytest.mark.parametrize("stop", [None, 10, "x"], ids=lambda x: f"stop={x}")
@pytest.mark.parametrize("start", [None, 0, 3, "x"], ids=lambda x: f"start={x}")
def test_slice(start, stop, step):
    """Evaluate a symbolic slice under numba; "x" stands in for a symbolic component."""
    x = ps.int64("x")
    sym_slice = as_symbolic(
        slice(
            x if start == "x" else start,
            x if stop == "x" else stop,
            x if step == "x" else step,
        )
    )
    # Disable rewrites so the slice graph is evaluated as-is by the numba linker.
    no_opt_mode = Mode(linker="numba", optimizer=None)
    evaled_slice = sym_slice.eval({x: -5}, on_unused_input="ignore", mode=no_opt_mode)
    assert isinstance(evaled_slice, slice)
    # Each component either took the symbolic value (-5) or round-tripped,
    # modulo numba's normalization of None components noted below.
    if start == "x":
        assert evaled_slice.start == -5
    elif start is None and (evaled_slice.step is None or evaled_slice.step > 0):
        # Numba can convert to 0 (and sometimes does) in this case
        assert evaled_slice.start in (None, 0)
    else:
        assert evaled_slice.start == start
    if stop == "x":
        assert evaled_slice.stop == -5
    else:
        assert evaled_slice.stop == stop
    if step == "x":
        assert evaled_slice.step == -5
    elif step is None:
        # Numba can convert to 1 (and sometimes does) in this case
        assert evaled_slice.step in (None, 1)
    else:
        assert evaled_slice.step == step
@pytest.mark.parametrize
(
"x, indices"
,
[
...
...
@@ -145,11 +182,6 @@ def test_AdvancedSubtensor1_out_of_bounds():
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))),
([[
1
,
2
],
[
2
,
1
]],
slice
(
1
,
None
),
[[
0
,
0
],
[
0
,
0
]]),
),
# Newaxis with vector indexing
(
as_tensor
(
np
.
arange
(
4
*
4
)
.
reshape
((
4
,
4
))),
(
None
,
[
0
,
1
,
2
],
[
0
,
1
,
2
]),
),
],
)
@pytest.mark.filterwarnings
(
"error"
)
# Raise if we did not expect objmode to be needed
...
...
@@ -415,13 +447,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
False
,
False
,
),
(
np
.
arange
(
4
*
4
)
.
reshape
((
4
,
4
)),
np
.
array
(
5
),
# Broadcasted scalar value
(
None
,
[
0
,
1
,
2
],
[
0
,
1
,
2
]),
# Newaxis with vector indexing
False
,
False
,
),
],
)
@pytest.mark.parametrize
(
"inplace"
,
(
False
,
True
))
...
...
@@ -435,9 +460,7 @@ def test_AdvancedIncSubtensor(
inplace
,
):
# Need rewrite to support certain forms of advanced indexing without object mode
# Use inplace_mode when testing inplace operations to preserve inplace flag
base_mode
=
numba_inplace_mode
if
inplace
else
numba_mode
mode
=
base_mode
.
including
(
"specialize"
)
mode
=
numba_mode
.
including
(
"specialize"
)
x_pt
=
pt
.
as_tensor
(
x
)
.
type
(
"x"
)
y_pt
=
pt
.
as_tensor
(
y
)
.
type
(
"y"
)
...
...
@@ -491,3 +514,22 @@ def test_AdvancedIncSubtensor(
x_orig
=
x
.
copy
()
fn
(
x
,
y
)
assert
not
np
.
all
(
x
==
x_orig
)
def test_advanced_indexing_with_newaxis_fallback_obj_mode():
    """Advanced indexing mixed with ``None`` (newaxis) falls back to object mode.

    The numba backend warns that the Op's ``perform`` method runs in object
    mode for this indexing pattern, while still matching the Python backend's
    result.
    """
    # This should be automatically solved with https://github.com/pymc-devs/pytensor/issues/1564
    # After which we can add these parametrizations to the relevant tests above
    x = pt.matrix("x")

    out = x[None, [0, 1, 2], [0, 1, 2]]
    with pytest.warns(
        UserWarning,
        match=r"Numba will use object mode to run AdvancedSubtensor's perform method",
    ):
        # Use the module-level seeded `rng` instead of the global numpy RNG so
        # the test data is deterministic across runs, consistent with the rest
        # of this file.
        compare_numba_and_py([x], [out], [rng.normal(size=(4, 4))])

    out = x[None, [0, 1, 2], [0, 1, 2]].inc(5)
    with pytest.warns(
        UserWarning,
        match=r"Numba will use object mode to run AdvancedIncSubtensor's perform method",
    ):
        compare_numba_and_py([x], [out], [rng.normal(size=(4, 4))])
tests/tensor/rewriting/test_elemwise.py
浏览文件 @
cc6bed1a
...
...
@@ -1642,15 +1642,9 @@ def test_InplaceElemwiseOptimizer_bug():
# with config.change_flags(tensor__insert_inplace_optimizer_validate_nb=10):
rewrite_graph
(
fgraph
,
include
=
(
"inplace"
,))
# Save original value to restore later
original_value
=
pytensor
.
config
.
tensor__insert_inplace_optimizer_validate_nb
try
:
pytensor
.
config
.
tensor__insert_inplace_optimizer_validate_nb
=
1
with
pytest
.
warns
(
FutureWarning
,
match
=
"tensor__insert_inplace_optimizer_validate_nb config is deprecated"
,
):
rewrite_graph
(
fgraph
,
include
=
(
"inplace"
,))
finally
:
# Restore original value to avoid affecting other tests
pytensor
.
config
.
tensor__insert_inplace_optimizer_validate_nb
=
original_value
tests/tensor/rewriting/test_subtensor.py
浏览文件 @
cc6bed1a
...
...
@@ -52,6 +52,7 @@ from pytensor.tensor.type import (
tensor4
,
vector
,
)
from
pytensor.tensor.type_other
import
make_slice
from
tests
import
unittest_tools
as
utt
from
tests.unittest_tools
import
create_pytensor_param
...
...
@@ -1700,11 +1701,11 @@ def test_local_uint_constant_indices():
assert
isinstance
(
new_index
,
Constant
)
assert
new_index
.
type
.
dtype
==
"uint8"
# `AdvancedSubtensor`, two indices, one slice, convert
# `AdvancedSubtensor`, two indices, one s
ymbolic s
lice, convert
x
=
pt
.
matrix
(
"x"
)
indices
=
(
pt
.
as_tensor_variable
(
np
.
array
(
[
1
]
,
np
.
int64
)),
slice
(
None
,
10
),
pt
.
as_tensor_variable
(
np
.
array
(
1
,
np
.
int64
)),
make_slice
(
slice
(
None
,
10
)
),
)
z
=
x
[
indices
]
...
...
@@ -1791,7 +1792,7 @@ def test_local_uint_constant_indices():
z_fn
=
pytensor
.
function
([
x
],
z
,
mode
=
mode
)
subtensor_node
=
z_fn
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
assert
isinstance
(
subtensor_node
.
op
,
(
AdvancedSubtensor
,
AdvancedSubtensor1
)
)
assert
isinstance
(
subtensor_node
.
op
,
AdvancedSubtensor
)
new_index
=
subtensor_node
.
inputs
[
1
]
assert
isinstance
(
new_index
,
Constant
)
assert
new_index
.
type
.
dtype
==
"uint8"
...
...
@@ -1842,6 +1843,7 @@ class TestBlockwiseIncSubtensor:
out
=
vectorize_graph
(
core_graph
,
replace
=
{
core_x
:
x
,
core_y
:
y
})
fn
,
ref_fn
=
self
.
compile_fn_and_ref
([
x
,
y
],
out
)
assert
self
.
has_blockwise
(
ref_fn
)
assert
not
self
.
has_blockwise
(
fn
)
test_x
=
np
.
ones
(
x
.
type
.
shape
,
dtype
=
x
.
type
.
dtype
)
test_y
=
rng
.
integers
(
1
,
10
,
size
=
y
.
type
.
shape
,
dtype
=
y
.
type
.
dtype
)
np
.
testing
.
assert_allclose
(
fn
(
test_x
,
test_y
),
ref_fn
(
test_x
,
test_y
))
...
...
@@ -1946,7 +1948,15 @@ class TestBlockwiseIncSubtensor:
@pytest.mark.parametrize
(
"basic_idx"
,
[
True
,
False
],
[
True
,
pytest
.
param
(
False
,
marks
=
pytest
.
mark
.
xfail
(
reason
=
"AdvancedIncSubtensor with slices can't be blockwise"
),
),
],
ids
=
[
"basic_idx"
,
"adv_idx"
],
)
@pytest.mark.parametrize
(
...
...
@@ -1963,7 +1973,7 @@ class TestBlockwiseIncSubtensor:
core_idx
=
pt
.
tensor
(
"idx"
,
dtype
=
int
,
shape
=
()
if
basic_idx
else
(
2
,))
# The empty slice before core_idx, will lead to a transposition of the advanced view
# once it is paired with a new arange slice on the batched dimensions.
# once it is paired with a
n
new arange slice on the batched dimensions.
# That's why core_v is (2, 3), and not (3, 2), in the case of advanced indexing
core_out
=
core_a
[
0
,
:,
core_idx
]
.
set
(
core_v
)
...
...
tests/tensor/rewriting/test_subtensor_lift.py
浏览文件 @
cc6bed1a
...
...
@@ -32,6 +32,7 @@ from pytensor.tensor import (
lscalars
,
matrix
,
shape
,
slicetype
,
specify_shape
,
tensor
,
tensor3
,
...
...
@@ -556,7 +557,7 @@ class TestLocalSubtensorSpecifyShapeLift:
(
matrix
(),
(
iscalar
(),
iscalar
()),
(
slice
(
iscalar
(),
iscalar
(),
iscalar
()
),),
(
slice
type
(
),),
),
(
matrix
(),
...
...
@@ -788,12 +789,12 @@ def test_local_subtensor_shape_constant():
(
lambda
x
:
x
[:,
[
0
,
1
]][
0
],
True
),
(
lambda
x
:
x
[:,
[
0
,
1
],
[
0
,
0
]][
1
:],
True
),
(
lambda
x
:
x
[:,
[[
0
,
1
],
[
0
,
0
]]][
1
:],
True
),
(
lambda
x
:
x
[:,
None
,
[
0
,
1
]][
0
],
True
),
# Not supported, basic indexing on advanced indexing dim
(
lambda
x
:
x
[[
0
,
1
]][
0
],
False
),
# Not
suppor
ted, basic indexing on the right of advanced indexing
# Not
implemen
ted, basic indexing on the right of advanced indexing
(
lambda
x
:
x
[[
0
,
1
]][:,
0
],
False
),
# Not implemented, complex flavors of advanced indexing
(
lambda
x
:
x
[:,
None
,
[
0
,
1
]][
0
],
False
),
(
lambda
x
:
x
[:,
5
:,
[
0
,
1
]][
0
],
False
),
(
lambda
x
:
x
[:,
:,
np
.
array
([
True
,
False
,
False
])][
0
],
False
),
(
lambda
x
:
x
[[
0
,
1
],
:,
[
0
,
1
]][:,
0
],
False
),
...
...
tests/tensor/test_blockwise.py
浏览文件 @
cc6bed1a
...
...
@@ -31,8 +31,6 @@ from pytensor.tensor.blockwise import (
vectorize_node_fallback
,
)
from
pytensor.tensor.nlinalg
import
MatrixInverse
,
eig
from
pytensor.tensor.random
import
normal
from
pytensor.tensor.random.op
import
default_rng
from
pytensor.tensor.rewriting.blas
import
specialize_matmul_to_batched_dot
from
pytensor.tensor.signal
import
convolve1d
from
pytensor.tensor.slinalg
import
(
...
...
@@ -116,18 +114,16 @@ def test_vectorize_blockwise():
def test_vectorize_node_fallback_unsupported_type():
    """The generic vectorization fallback rejects non-tensor (slice) inputs."""
    x = tensor("x", shape=(2, 6))
    # Mixing a basic slice with advanced indices produces an AdvancedSubtensor
    # node whose second input is a symbolic slice (MakeSlice.0).
    node = x[:, [0, 2, 4]].owner
    with pytest.raises(
        NotImplementedError,
        match=re.escape(
            "Cannot vectorize node AdvancedSubtensor(x, MakeSlice.0, [0 2 4]) with input MakeSlice.0 of type slice"
        ),
    ):
        vectorize_node_fallback(node.op, node, node.inputs)
def
check_blockwise_runtime_broadcasting
(
mode
):
...
...
tests/tensor/test_subtensor.py
浏览文件 @
cc6bed1a
差异被折叠。
点击展开。
tests/tensor/test_type_other.py
浏览文件 @
cc6bed1a
...
...
@@ -4,8 +4,30 @@ import pytensor
from
pytensor
import
as_symbolic
from
pytensor.graph.basic
import
Constant
from
pytensor.tensor.math
import
argmax
from
pytensor.tensor.type
import
vector
from
pytensor.tensor.type_other
import
NoneConst
,
NoneTypeT
from
pytensor.tensor.type
import
iscalar
,
vector
from
pytensor.tensor.type_other
import
(
MakeSlice
,
NoneConst
,
NoneTypeT
,
SliceConstant
,
SliceType
,
make_slice
,
)
def test_SliceType():
    """``SliceType`` instances compare equal to their clones."""
    slice_type = SliceType()
    cloned = slice_type.clone()
    assert slice_type == cloned
def test_make_slice_merge():
    # In the past, this was crashing during compilation.
    # Two identical MakeSlice applications should be merged into a single node.
    i = iscalar()
    s1 = make_slice(0, i)
    s2 = make_slice(0, i)
    f = pytensor.function([i], [s1, s2])
    nodes = f.maker.fgraph.apply_nodes
    assert len([n for n in nodes if isinstance(n.op, MakeSlice)]) == 1
def
test_none_Constant
():
...
...
@@ -25,6 +47,8 @@ def test_none_Constant():
# This trigger equals that returned the wrong answer in the past.
import
pickle
import
pytensor
x
=
vector
(
"x"
)
y
=
argmax
(
x
)
kwargs
=
{}
...
...
@@ -36,18 +60,11 @@ def test_none_Constant():
def
test_as_symbolic
():
# Remove this when xtensor is not using symbolic slices
from
pytensor.tensor.type
import
iscalar
from
pytensor.tensor.type_other
import
SliceConstant
,
slicetype
res
=
as_symbolic
(
None
)
assert
res
is
NoneConst
res
=
as_symbolic
(
slice
(
iscalar
()))
assert
res
.
owner
.
op
==
make_slice
res
=
as_symbolic
(
slice
(
1
,
2
))
assert
isinstance
(
res
,
SliceConstant
)
assert
res
.
type
==
slicetype
assert
res
.
data
==
slice
(
1
,
2
)
i
=
iscalar
()
res
=
as_symbolic
(
slice
(
i
))
assert
res
.
owner
is
not
None
tests/tensor/test_variable.py
浏览文件 @
cc6bed1a
...
...
@@ -35,7 +35,7 @@ from pytensor.tensor.type import (
scalar
,
tensor3
,
)
from
pytensor.tensor.type_other
import
NoneConst
from
pytensor.tensor.type_other
import
MakeSlice
,
NoneConst
from
pytensor.tensor.variable
import
(
DenseTensorConstant
,
DenseTensorVariable
,
...
...
@@ -232,11 +232,11 @@ def test__getitem__AdvancedSubtensor():
z
=
x
[:,
i
]
op_types
=
[
type
(
node
.
op
)
for
node
in
io_toposort
([
x
,
i
],
[
z
])]
assert
op_types
==
[
AdvancedSubtensor
]
assert
op_types
==
[
MakeSlice
,
AdvancedSubtensor
]
z
=
x
[
...
,
i
,
None
]
op_types
=
[
type
(
node
.
op
)
for
node
in
io_toposort
([
x
,
i
],
[
z
])]
assert
op_types
==
[
DimShuffl
e
,
AdvancedSubtensor
]
assert
op_types
==
[
MakeSlic
e
,
AdvancedSubtensor
]
z
=
x
[
i
,
None
]
op_types
=
[
type
(
node
.
op
)
for
node
in
io_toposort
([
x
,
i
],
[
z
])]
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论