Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
cc6bed1a
提交
cc6bed1a
authored
2月 28, 2026
作者:
ricardoV94
提交者:
Ricardo Vieira
3月 01, 2026
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Revert "Refactor AdvancedSubtensor"
This reverts commit
db7fa079
.
上级
03afa5bb
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
26 个修改的文件
包含
448 行增加
和
244 行删除
+448
-244
destroyhandler.py
pytensor/graph/destroyhandler.py
+2
-2
subtensor.py
pytensor/link/jax/dispatch/subtensor.py
+33
-3
subtensor.py
pytensor/link/mlx/dispatch/subtensor.py
+18
-5
subtensor.py
pytensor/link/numba/dispatch/subtensor.py
+0
-0
subtensor.py
pytensor/link/pytorch/dispatch/subtensor.py
+19
-10
rewriting.py
pytensor/scan/rewriting.py
+24
-4
basic.py
pytensor/tensor/basic.py
+4
-4
basic.py
pytensor/tensor/random/rewriting/basic.py
+17
-9
shape.py
pytensor/tensor/rewriting/shape.py
+3
-5
subtensor.py
pytensor/tensor/rewriting/subtensor.py
+0
-0
subtensor_lift.py
pytensor/tensor/rewriting/subtensor_lift.py
+29
-21
uncanonicalize.py
pytensor/tensor/rewriting/uncanonicalize.py
+52
-34
subtensor.py
pytensor/tensor/subtensor.py
+0
-0
variable.py
pytensor/tensor/variable.py
+73
-48
indexing.py
pytensor/xtensor/rewriting/indexing.py
+3
-4
test_basic.py
tests/graph/rewriting/test_basic.py
+26
-0
test_subtensor.py
tests/link/jax/test_subtensor.py
+0
-31
test_subtensor.py
tests/link/mlx/test_subtensor.py
+21
-0
test_subtensor.py
tests/link/numba/test_subtensor.py
+62
-20
test_elemwise.py
tests/tensor/rewriting/test_elemwise.py
+6
-12
test_subtensor.py
tests/tensor/rewriting/test_subtensor.py
+16
-6
test_subtensor_lift.py
tests/tensor/rewriting/test_subtensor_lift.py
+4
-3
test_blockwise.py
tests/tensor/test_blockwise.py
+4
-8
test_subtensor.py
tests/tensor/test_subtensor.py
+0
-0
test_type_other.py
tests/tensor/test_type_other.py
+29
-12
test_variable.py
tests/tensor/test_variable.py
+3
-3
没有找到文件。
pytensor/graph/destroyhandler.py
浏览文件 @
cc6bed1a
...
@@ -771,9 +771,9 @@ class DestroyHandler(Bookkeeper):
...
@@ -771,9 +771,9 @@ class DestroyHandler(Bookkeeper):
}
}
tolerated
.
add
(
destroyed_idx
)
tolerated
.
add
(
destroyed_idx
)
tolerate_aliased
=
getattr
(
tolerate_aliased
=
getattr
(
app
.
op
,
"destroyhandler_tolerate_aliased"
,
()
app
.
op
,
"destroyhandler_tolerate_aliased"
,
[]
)
)
assert
isinstance
(
tolerate_aliased
,
tuple
|
list
)
assert
isinstance
(
tolerate_aliased
,
list
)
ignored
=
{
ignored
=
{
idx1
for
idx0
,
idx1
in
tolerate_aliased
if
idx0
==
destroyed_idx
idx1
for
idx0
,
idx1
in
tolerate_aliased
if
idx0
==
destroyed_idx
}
}
...
...
pytensor/link/jax/dispatch/subtensor.py
浏览文件 @
cc6bed1a
...
@@ -8,6 +8,7 @@ from pytensor.tensor.subtensor import (
...
@@ -8,6 +8,7 @@ from pytensor.tensor.subtensor import (
Subtensor
,
Subtensor
,
indices_from_subtensor
,
indices_from_subtensor
,
)
)
from
pytensor.tensor.type_other
import
MakeSlice
BOOLEAN_MASK_ERROR
=
"""JAX does not support resizing arrays with boolean
BOOLEAN_MASK_ERROR
=
"""JAX does not support resizing arrays with boolean
...
@@ -34,8 +35,10 @@ slice length.
...
@@ -34,8 +35,10 @@ slice length.
@jax_funcify.register
(
AdvancedSubtensor
)
@jax_funcify.register
(
AdvancedSubtensor
)
@jax_funcify.register
(
AdvancedSubtensor1
)
@jax_funcify.register
(
AdvancedSubtensor1
)
def
jax_funcify_Subtensor
(
op
,
node
,
**
kwargs
):
def
jax_funcify_Subtensor
(
op
,
node
,
**
kwargs
):
idx_list
=
getattr
(
op
,
"idx_list"
,
None
)
def
subtensor
(
x
,
*
ilists
):
def
subtensor
(
x
,
*
ilists
):
indices
=
indices_from_subtensor
(
ilists
,
op
.
idx_list
)
indices
=
indices_from_subtensor
(
ilists
,
idx_list
)
if
len
(
indices
)
==
1
:
if
len
(
indices
)
==
1
:
indices
=
indices
[
0
]
indices
=
indices
[
0
]
...
@@ -45,9 +48,10 @@ def jax_funcify_Subtensor(op, node, **kwargs):
...
@@ -45,9 +48,10 @@ def jax_funcify_Subtensor(op, node, **kwargs):
@jax_funcify.register
(
IncSubtensor
)
@jax_funcify.register
(
IncSubtensor
)
@jax_funcify.register
(
AdvancedIncSubtensor
)
@jax_funcify.register
(
AdvancedIncSubtensor1
)
@jax_funcify.register
(
AdvancedIncSubtensor1
)
def
jax_funcify_IncSubtensor
(
op
,
node
,
**
kwargs
):
def
jax_funcify_IncSubtensor
(
op
,
node
,
**
kwargs
):
idx_list
=
getattr
(
op
,
"idx_list"
,
None
)
if
getattr
(
op
,
"set_instead_of_inc"
,
False
):
if
getattr
(
op
,
"set_instead_of_inc"
,
False
):
def
jax_fn
(
x
,
indices
,
y
):
def
jax_fn
(
x
,
indices
,
y
):
...
@@ -58,7 +62,7 @@ def jax_funcify_IncSubtensor(op, node, **kwargs):
...
@@ -58,7 +62,7 @@ def jax_funcify_IncSubtensor(op, node, **kwargs):
def
jax_fn
(
x
,
indices
,
y
):
def
jax_fn
(
x
,
indices
,
y
):
return
x
.
at
[
indices
]
.
add
(
y
)
return
x
.
at
[
indices
]
.
add
(
y
)
def
incsubtensor
(
x
,
y
,
*
ilist
,
jax_fn
=
jax_fn
,
idx_list
=
op
.
idx_list
):
def
incsubtensor
(
x
,
y
,
*
ilist
,
jax_fn
=
jax_fn
,
idx_list
=
idx_list
):
indices
=
indices_from_subtensor
(
ilist
,
idx_list
)
indices
=
indices_from_subtensor
(
ilist
,
idx_list
)
if
len
(
indices
)
==
1
:
if
len
(
indices
)
==
1
:
indices
=
indices
[
0
]
indices
=
indices
[
0
]
...
@@ -69,3 +73,29 @@ def jax_funcify_IncSubtensor(op, node, **kwargs):
...
@@ -69,3 +73,29 @@ def jax_funcify_IncSubtensor(op, node, **kwargs):
return
jax_fn
(
x
,
indices
,
y
)
return
jax_fn
(
x
,
indices
,
y
)
return
incsubtensor
return
incsubtensor
@jax_funcify.register
(
AdvancedIncSubtensor
)
def
jax_funcify_AdvancedIncSubtensor
(
op
,
node
,
**
kwargs
):
if
getattr
(
op
,
"set_instead_of_inc"
,
False
):
def
jax_fn
(
x
,
indices
,
y
):
return
x
.
at
[
indices
]
.
set
(
y
)
else
:
def
jax_fn
(
x
,
indices
,
y
):
return
x
.
at
[
indices
]
.
add
(
y
)
def
advancedincsubtensor
(
x
,
y
,
*
ilist
,
jax_fn
=
jax_fn
):
return
jax_fn
(
x
,
ilist
,
y
)
return
advancedincsubtensor
@jax_funcify.register
(
MakeSlice
)
def
jax_funcify_MakeSlice
(
op
,
**
kwargs
):
def
makeslice
(
*
x
):
return
slice
(
*
x
)
return
makeslice
pytensor/link/mlx/dispatch/subtensor.py
浏览文件 @
cc6bed1a
...
@@ -10,14 +10,15 @@ from pytensor.tensor.subtensor import (
...
@@ -10,14 +10,15 @@ from pytensor.tensor.subtensor import (
Subtensor
,
Subtensor
,
indices_from_subtensor
,
indices_from_subtensor
,
)
)
from
pytensor.tensor.type_other
import
MakeSlice
@mlx_funcify.register
(
Subtensor
)
@mlx_funcify.register
(
Subtensor
)
def
mlx_funcify_Subtensor
(
op
,
node
,
**
kwargs
):
def
mlx_funcify_Subtensor
(
op
,
node
,
**
kwargs
):
idx_list
=
getattr
(
op
,
"idx_list"
,
None
)
def
subtensor
(
x
,
*
ilists
):
def
subtensor
(
x
,
*
ilists
):
indices
=
indices_from_subtensor
(
indices
=
indices_from_subtensor
([
int
(
element
)
for
element
in
ilists
],
idx_list
)
[
int
(
element
)
for
element
in
ilists
],
op
.
idx_list
)
if
len
(
indices
)
==
1
:
if
len
(
indices
)
==
1
:
indices
=
indices
[
0
]
indices
=
indices
[
0
]
...
@@ -29,8 +30,10 @@ def mlx_funcify_Subtensor(op, node, **kwargs):
...
@@ -29,8 +30,10 @@ def mlx_funcify_Subtensor(op, node, **kwargs):
@mlx_funcify.register
(
AdvancedSubtensor
)
@mlx_funcify.register
(
AdvancedSubtensor
)
@mlx_funcify.register
(
AdvancedSubtensor1
)
@mlx_funcify.register
(
AdvancedSubtensor1
)
def
mlx_funcify_AdvancedSubtensor
(
op
,
node
,
**
kwargs
):
def
mlx_funcify_AdvancedSubtensor
(
op
,
node
,
**
kwargs
):
idx_list
=
getattr
(
op
,
"idx_list"
,
None
)
def
advanced_subtensor
(
x
,
*
ilists
):
def
advanced_subtensor
(
x
,
*
ilists
):
indices
=
indices_from_subtensor
(
ilists
,
op
.
idx_list
)
indices
=
indices_from_subtensor
(
ilists
,
idx_list
)
if
len
(
indices
)
==
1
:
if
len
(
indices
)
==
1
:
indices
=
indices
[
0
]
indices
=
indices
[
0
]
...
@@ -42,6 +45,8 @@ def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
...
@@ -42,6 +45,8 @@ def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
@mlx_funcify.register
(
IncSubtensor
)
@mlx_funcify.register
(
IncSubtensor
)
@mlx_funcify.register
(
AdvancedIncSubtensor1
)
@mlx_funcify.register
(
AdvancedIncSubtensor1
)
def
mlx_funcify_IncSubtensor
(
op
,
node
,
**
kwargs
):
def
mlx_funcify_IncSubtensor
(
op
,
node
,
**
kwargs
):
idx_list
=
getattr
(
op
,
"idx_list"
,
None
)
if
getattr
(
op
,
"set_instead_of_inc"
,
False
):
if
getattr
(
op
,
"set_instead_of_inc"
,
False
):
def
mlx_fn
(
x
,
indices
,
y
):
def
mlx_fn
(
x
,
indices
,
y
):
...
@@ -58,7 +63,7 @@ def mlx_funcify_IncSubtensor(op, node, **kwargs):
...
@@ -58,7 +63,7 @@ def mlx_funcify_IncSubtensor(op, node, **kwargs):
x
[
indices
]
+=
y
x
[
indices
]
+=
y
return
x
return
x
def
incsubtensor
(
x
,
y
,
*
ilist
,
mlx_fn
=
mlx_fn
,
idx_list
=
op
.
idx_list
):
def
incsubtensor
(
x
,
y
,
*
ilist
,
mlx_fn
=
mlx_fn
,
idx_list
=
idx_list
):
indices
=
indices_from_subtensor
(
ilist
,
idx_list
)
indices
=
indices_from_subtensor
(
ilist
,
idx_list
)
if
len
(
indices
)
==
1
:
if
len
(
indices
)
==
1
:
indices
=
indices
[
0
]
indices
=
indices
[
0
]
...
@@ -90,3 +95,11 @@ def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs):
...
@@ -90,3 +95,11 @@ def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs):
return
mlx_fn
(
x
,
ilist
,
y
)
return
mlx_fn
(
x
,
ilist
,
y
)
return
advancedincsubtensor
return
advancedincsubtensor
@mlx_funcify.register
(
MakeSlice
)
def
mlx_funcify_MakeSlice
(
op
,
**
kwargs
):
def
makeslice
(
*
x
):
return
slice
(
*
x
)
return
makeslice
pytensor/link/numba/dispatch/subtensor.py
浏览文件 @
cc6bed1a
差异被折叠。
点击展开。
pytensor/link/pytorch/dispatch/subtensor.py
浏览文件 @
cc6bed1a
...
@@ -9,6 +9,7 @@ from pytensor.tensor.subtensor import (
...
@@ -9,6 +9,7 @@ from pytensor.tensor.subtensor import (
Subtensor
,
Subtensor
,
indices_from_subtensor
,
indices_from_subtensor
,
)
)
from
pytensor.tensor.type_other
import
MakeSlice
,
SliceType
def
check_negative_steps
(
indices
):
def
check_negative_steps
(
indices
):
...
@@ -46,11 +47,23 @@ def pytorch_funcify_Subtensor(op, node, **kwargs):
...
@@ -46,11 +47,23 @@ def pytorch_funcify_Subtensor(op, node, **kwargs):
return
subtensor
return
subtensor
@pytorch_funcify.register
(
MakeSlice
)
def
pytorch_funcify_makeslice
(
op
,
**
kwargs
):
def
makeslice
(
start
,
stop
,
step
):
# Torch does not like numpy integers in indexing slices
return
slice
(
None
if
start
is
None
else
int
(
start
),
None
if
stop
is
None
else
int
(
stop
),
None
if
step
is
None
else
int
(
step
),
)
return
makeslice
@pytorch_funcify.register
(
AdvancedSubtensor1
)
@pytorch_funcify.register
(
AdvancedSubtensor1
)
@pytorch_funcify.register
(
AdvancedSubtensor
)
@pytorch_funcify.register
(
AdvancedSubtensor
)
def
pytorch_funcify_AdvSubtensor
(
op
,
node
,
**
kwargs
):
def
pytorch_funcify_AdvSubtensor
(
op
,
node
,
**
kwargs
):
def
advsubtensor
(
x
,
*
indices
):
def
advsubtensor
(
x
,
*
indices
):
indices
=
indices_from_subtensor
(
indices
,
op
.
idx_list
)
check_negative_steps
(
indices
)
check_negative_steps
(
indices
)
return
x
[
indices
]
return
x
[
indices
]
...
@@ -89,14 +102,12 @@ def pytorch_funcify_IncSubtensor(op, node, **kwargs):
...
@@ -89,14 +102,12 @@ def pytorch_funcify_IncSubtensor(op, node, **kwargs):
@pytorch_funcify.register
(
AdvancedIncSubtensor
)
@pytorch_funcify.register
(
AdvancedIncSubtensor
)
@pytorch_funcify.register
(
AdvancedIncSubtensor1
)
@pytorch_funcify.register
(
AdvancedIncSubtensor1
)
def
pytorch_funcify_AdvancedIncSubtensor
(
op
,
node
,
**
kwargs
):
def
pytorch_funcify_AdvancedIncSubtensor
(
op
,
node
,
**
kwargs
):
idx_list
=
op
.
idx_list
inplace
=
op
.
inplace
inplace
=
op
.
inplace
ignore_duplicates
=
getattr
(
op
,
"ignore_duplicates"
,
False
)
ignore_duplicates
=
getattr
(
op
,
"ignore_duplicates"
,
False
)
if
op
.
set_instead_of_inc
:
if
op
.
set_instead_of_inc
:
def
adv_set_subtensor
(
x
,
y
,
*
flattened_indices
):
def
adv_set_subtensor
(
x
,
y
,
*
indices
):
indices
=
indices_from_subtensor
(
flattened_indices
,
idx_list
)
check_negative_steps
(
indices
)
check_negative_steps
(
indices
)
if
isinstance
(
op
,
AdvancedIncSubtensor1
):
if
isinstance
(
op
,
AdvancedIncSubtensor1
):
op
.
_check_runtime_broadcasting
(
node
,
x
,
y
,
indices
)
op
.
_check_runtime_broadcasting
(
node
,
x
,
y
,
indices
)
...
@@ -109,8 +120,7 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
...
@@ -109,8 +120,7 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
elif
ignore_duplicates
:
elif
ignore_duplicates
:
def
adv_inc_subtensor_no_duplicates
(
x
,
y
,
*
flattened_indices
):
def
adv_inc_subtensor_no_duplicates
(
x
,
y
,
*
indices
):
indices
=
indices_from_subtensor
(
flattened_indices
,
idx_list
)
check_negative_steps
(
indices
)
check_negative_steps
(
indices
)
if
isinstance
(
op
,
AdvancedIncSubtensor1
):
if
isinstance
(
op
,
AdvancedIncSubtensor1
):
op
.
_check_runtime_broadcasting
(
node
,
x
,
y
,
indices
)
op
.
_check_runtime_broadcasting
(
node
,
x
,
y
,
indices
)
...
@@ -122,14 +132,13 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
...
@@ -122,14 +132,13 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
return
adv_inc_subtensor_no_duplicates
return
adv_inc_subtensor_no_duplicates
else
:
else
:
if
any
(
isinstance
(
entry
,
slice
)
for
entry
in
idx_list
):
if
any
(
isinstance
(
idx
.
type
,
SliceType
)
for
idx
in
node
.
inputs
[
2
:]
):
raise
NotImplementedError
(
raise
NotImplementedError
(
"IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"
"IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"
)
)
def
adv_inc_subtensor
(
x
,
y
,
*
flattened_indices
):
def
adv_inc_subtensor
(
x
,
y
,
*
indices
):
indices
=
indices_from_subtensor
(
flattened_indices
,
idx_list
)
# Not needed because slices aren't supported
# Not needed because slices aren't supported in this path
# check_negative_steps(indices)
# check_negative_steps(indices)
if
not
inplace
:
if
not
inplace
:
x
=
x
.
clone
()
x
=
x
.
clone
()
...
...
pytensor/scan/rewriting.py
浏览文件 @
cc6bed1a
...
@@ -72,9 +72,9 @@ from pytensor.tensor.shape import shape
...
@@ -72,9 +72,9 @@ from pytensor.tensor.shape import shape
from
pytensor.tensor.subtensor
import
(
from
pytensor.tensor.subtensor
import
(
IncSubtensor
,
IncSubtensor
,
Subtensor
,
Subtensor
,
basic_subtensor
,
get_canonical_form_slice
,
get_canonical_form_slice
,
get_idx_list
,
get_idx_list
,
get_slice_elements
,
set_subtensor
,
set_subtensor
,
)
)
from
pytensor.tensor.variable
import
TensorConstant
,
TensorVariable
from
pytensor.tensor.variable
import
TensorConstant
,
TensorVariable
...
@@ -1211,7 +1211,7 @@ def _is_default_scan_buffer(final_buffer: TensorVariable, taps: int) -> bool:
...
@@ -1211,7 +1211,7 @@ def _is_default_scan_buffer(final_buffer: TensorVariable, taps: int) -> bool:
if
not
(
if
not
(
isinstance
(
op
,
IncSubtensor
)
isinstance
(
op
,
IncSubtensor
)
and
op
.
set_instead_of_inc
and
op
.
set_instead_of_inc
and
op
.
idx_list
==
(
slice
(
None
,
0
),)
and
op
.
idx_list
==
[
slice
(
None
,
ps
.
int64
)]
):
):
return
False
return
False
...
@@ -1389,6 +1389,12 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
...
@@ -1389,6 +1389,12 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
else
:
else
:
# 2.3.1 extract idx list of subtensor
# 2.3.1 extract idx list of subtensor
this_slice
=
get_idx_list
(
cl
.
inputs
,
cl
.
op
.
idx_list
)
this_slice
=
get_idx_list
(
cl
.
inputs
,
cl
.
op
.
idx_list
)
if
this_slice
is
None
:
# if unable to extract idx_list
# => outputs needs all its intermediate values
global_nsteps
=
None
slices
[
i
]
=
None
break
# 2.3.2 extract the begin/end of the first dimension
# 2.3.2 extract the begin/end of the first dimension
if
i
>=
op_info
.
n_mit_mot
:
if
i
>=
op_info
.
n_mit_mot
:
...
@@ -1481,6 +1487,9 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
...
@@ -1481,6 +1487,9 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
break
break
else
:
else
:
this_slice
=
get_idx_list
(
cl
.
inputs
,
cl
.
op
.
idx_list
)
this_slice
=
get_idx_list
(
cl
.
inputs
,
cl
.
op
.
idx_list
)
if
this_slice
is
None
:
store_steps
[
i
]
=
0
break
if
isinstance
(
this_slice
[
0
],
slice
):
if
isinstance
(
this_slice
[
0
],
slice
):
start
=
this_slice
[
0
]
.
start
start
=
this_slice
[
0
]
.
start
...
@@ -1702,9 +1711,16 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
...
@@ -1702,9 +1711,16 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
)
)
else
:
else
:
fslice
=
sanitize
(
cnf_slice
[
0
])
fslice
=
sanitize
(
cnf_slice
[
0
])
nw_slice
=
(
fslice
,
*
old_slices
[
1
:])
nw_pos
=
inv_compress_map
[
idx
]
nw_pos
=
inv_compress_map
[
idx
]
new_o
=
basic_subtensor
(
new_outs
[
nw_pos
],
fslice
,
*
old_slices
[
1
:])
subtens
=
Subtensor
(
nw_slice
)
# slice inputs
sl_ins
=
get_slice_elements
(
nw_slice
,
lambda
entry
:
isinstance
(
entry
,
Variable
)
)
new_o
=
cast
(
TensorVariable
,
subtens
(
new_outs
[
nw_pos
],
*
sl_ins
))
if
new_o
.
ndim
>
0
:
if
new_o
.
ndim
>
0
:
new_o
=
new_o
[::
cnf_slice
[
1
]]
new_o
=
new_o
[::
cnf_slice
[
1
]]
replaced_outs
.
append
(
idx
)
replaced_outs
.
append
(
idx
)
...
@@ -1755,7 +1771,11 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
...
@@ -1755,7 +1771,11 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
)
)
nw_slice
=
(
sanitize
(
position
),
*
old_slices
[
1
:])
nw_slice
=
(
sanitize
(
position
),
*
old_slices
[
1
:])
new_o
=
basic_subtensor
(
new_outs
[
nw_pos
],
*
nw_slice
)
subtens
=
Subtensor
(
nw_slice
)
sl_ins
=
get_slice_elements
(
nw_slice
,
lambda
entry
:
isinstance
(
entry
,
Variable
)
)
new_o
=
cast
(
TensorVariable
,
subtens
(
new_outs
[
nw_pos
],
*
sl_ins
))
if
new_o
.
ndim
>
0
:
if
new_o
.
ndim
>
0
:
new_o
=
new_o
[::
cnf_slice
[
1
]]
new_o
=
new_o
[::
cnf_slice
[
1
]]
old_new
+=
[(
old
,
new_o
)]
old_new
+=
[(
old
,
new_o
)]
...
...
pytensor/tensor/basic.py
浏览文件 @
cc6bed1a
...
@@ -29,7 +29,7 @@ from pytensor.graph.fg import FunctionGraph, Output
...
@@ -29,7 +29,7 @@ from pytensor.graph.fg import FunctionGraph, Output
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.graph.rewriting.db
import
EquilibriumDB
from
pytensor.graph.rewriting.db
import
EquilibriumDB
from
pytensor.graph.type
import
HasShape
from
pytensor.graph.type
import
HasShape
,
Type
from
pytensor.link.c.op
import
COp
from
pytensor.link.c.op
import
COp
from
pytensor.link.c.params_type
import
ParamsType
from
pytensor.link.c.params_type
import
ParamsType
from
pytensor.printing
import
Printer
,
min_informative_str
,
pprint
,
set_precedence
from
pytensor.printing
import
Printer
,
min_informative_str
,
pprint
,
set_precedence
...
@@ -433,7 +433,7 @@ def _get_underlying_scalar_constant_value(
...
@@ -433,7 +433,7 @@ def _get_underlying_scalar_constant_value(
var
.
ndim
==
1
for
var
in
v
.
owner
.
inputs
[
0
]
.
owner
.
inputs
[
1
:]
var
.
ndim
==
1
for
var
in
v
.
owner
.
inputs
[
0
]
.
owner
.
inputs
[
1
:]
):
):
idx
=
v
.
owner
.
op
.
idx_list
[
0
]
idx
=
v
.
owner
.
op
.
idx_list
[
0
]
if
isinstance
(
idx
,
int
):
if
isinstance
(
idx
,
Type
):
idx
=
_get_underlying_scalar_constant_value
(
idx
=
_get_underlying_scalar_constant_value
(
v
.
owner
.
inputs
[
1
],
max_recur
=
max_recur
v
.
owner
.
inputs
[
1
],
max_recur
=
max_recur
)
)
...
@@ -467,7 +467,7 @@ def _get_underlying_scalar_constant_value(
...
@@ -467,7 +467,7 @@ def _get_underlying_scalar_constant_value(
and
len
(
v
.
owner
.
op
.
idx_list
)
==
1
and
len
(
v
.
owner
.
op
.
idx_list
)
==
1
):
):
idx
=
v
.
owner
.
op
.
idx_list
[
0
]
idx
=
v
.
owner
.
op
.
idx_list
[
0
]
if
isinstance
(
idx
,
int
):
if
isinstance
(
idx
,
Type
):
idx
=
_get_underlying_scalar_constant_value
(
idx
=
_get_underlying_scalar_constant_value
(
v
.
owner
.
inputs
[
1
],
max_recur
=
max_recur
v
.
owner
.
inputs
[
1
],
max_recur
=
max_recur
)
)
...
@@ -488,7 +488,7 @@ def _get_underlying_scalar_constant_value(
...
@@ -488,7 +488,7 @@ def _get_underlying_scalar_constant_value(
op
=
owner
.
op
op
=
owner
.
op
idx_list
=
op
.
idx_list
idx_list
=
op
.
idx_list
idx
=
idx_list
[
0
]
idx
=
idx_list
[
0
]
if
isinstance
(
idx
,
int
):
if
isinstance
(
idx
,
Type
):
idx
=
_get_underlying_scalar_constant_value
(
idx
=
_get_underlying_scalar_constant_value
(
owner
.
inputs
[
1
],
max_recur
=
max_recur
owner
.
inputs
[
1
],
max_recur
=
max_recur
)
)
...
...
pytensor/tensor/random/rewriting/basic.py
浏览文件 @
cc6bed1a
...
@@ -23,7 +23,7 @@ from pytensor.tensor.subtensor import (
...
@@ -23,7 +23,7 @@ from pytensor.tensor.subtensor import (
indices_from_subtensor
,
indices_from_subtensor
,
)
)
from
pytensor.tensor.type
import
integer_dtypes
from
pytensor.tensor.type
import
integer_dtypes
from
pytensor.tensor.type_other
import
NoneTypeT
from
pytensor.tensor.type_other
import
NoneTypeT
,
SliceType
def
is_rv_used_in_graph
(
base_rv
,
node
,
fgraph
):
def
is_rv_used_in_graph
(
base_rv
,
node
,
fgraph
):
...
@@ -237,15 +237,20 @@ def local_subtensor_rv_lift(fgraph, node):
...
@@ -237,15 +237,20 @@ def local_subtensor_rv_lift(fgraph, node):
return
False
return
False
# Parse indices
# Parse indices
if
isinstance
(
subtensor_op
,
Subtensor
|
AdvancedSubtensor
):
if
isinstance
(
subtensor_op
,
Subtensor
):
indices
=
indices_from_subtensor
(
node
.
inputs
[
1
:],
subtensor_op
.
idx_list
)
indices
=
indices_from_subtensor
(
node
.
inputs
[
1
:],
subtensor_op
.
idx_list
)
else
:
else
:
indices
=
node
.
inputs
[
1
:]
indices
=
node
.
inputs
[
1
:]
# The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
# The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
# TODO: This rewrite is aborting with dummy indexing dimensions which aren't a problem
# Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
# (e.g., x[[0],] is equivalent to x[0] - can only index one entry, won't lead to duplicates)
# If we wanted to support that we could rewrite it as subtensor + dimshuffle
if
any
(
is_nd_advanced_idx
(
idx
,
integer_dtypes
)
for
idx
in
indices
):
# and make use of the dimshuffle lift rewrite
return
False
# TODO: This rewrite is aborting with dummy indexing dimensions which aren't a problem
if
any
(
is_nd_advanced_idx
(
idx
,
integer_dtypes
)
or
isinstance
(
idx
.
type
,
NoneTypeT
)
for
idx
in
indices
):
return
False
# Check that indexing does not act on support dims
# Check that indexing does not act on support dims
batch_ndims
=
rv_op
.
batch_ndim
(
rv_node
)
batch_ndims
=
rv_op
.
batch_ndim
(
rv_node
)
...
@@ -263,7 +268,10 @@ def local_subtensor_rv_lift(fgraph, node):
...
@@ -263,7 +268,10 @@ def local_subtensor_rv_lift(fgraph, node):
non_bool_indices
[
batch_ndims
:],
non_bool_indices
[
batch_ndims
:],
)
)
for
idx
in
supp_indices
:
for
idx
in
supp_indices
:
if
idx
!=
slice
(
None
):
if
not
(
isinstance
(
idx
.
type
,
SliceType
)
and
all
(
isinstance
(
i
.
type
,
NoneTypeT
)
for
i
in
idx
.
owner
.
inputs
)
):
return
False
return
False
n_discarded_idxs
=
len
(
supp_indices
)
n_discarded_idxs
=
len
(
supp_indices
)
indices
=
indices
[:
-
n_discarded_idxs
]
indices
=
indices
[:
-
n_discarded_idxs
]
...
@@ -323,7 +331,7 @@ def local_subtensor_rv_lift(fgraph, node):
...
@@ -323,7 +331,7 @@ def local_subtensor_rv_lift(fgraph, node):
# Broadcasted dim
# Broadcasted dim
if
curr_dim
in
bcast_param_dims
:
if
curr_dim
in
bcast_param_dims
:
# Slice indexing, keep degenerate dim by none-slicing
# Slice indexing, keep degenerate dim by none-slicing
if
isinstance
(
idx
,
slice
):
if
isinstance
(
idx
,
slice
)
or
isinstance
(
idx
.
type
,
SliceType
)
:
batch_indices
.
append
(
slice
(
None
))
batch_indices
.
append
(
slice
(
None
))
# Integer indexing, drop degenerate dim by 0-indexing
# Integer indexing, drop degenerate dim by 0-indexing
else
:
else
:
...
...
pytensor/tensor/rewriting/shape.py
浏览文件 @
cc6bed1a
...
@@ -17,6 +17,7 @@ from pytensor.graph.rewriting.basic import (
...
@@ -17,6 +17,7 @@ from pytensor.graph.rewriting.basic import (
)
)
from
pytensor.graph.traversal
import
ancestors
from
pytensor.graph.traversal
import
ancestors
from
pytensor.graph.utils
import
InconsistencyError
,
get_variable_trace_string
from
pytensor.graph.utils
import
InconsistencyError
,
get_variable_trace_string
from
pytensor.scalar
import
ScalarType
from
pytensor.tensor.basic
import
(
from
pytensor.tensor.basic
import
(
MakeVector
,
MakeVector
,
as_tensor_variable
,
as_tensor_variable
,
...
@@ -841,16 +842,13 @@ def _is_shape_i_of_x(
...
@@ -841,16 +842,13 @@ def _is_shape_i_of_x(
if
isinstance
(
var
.
owner
.
op
,
Shape_i
):
if
isinstance
(
var
.
owner
.
op
,
Shape_i
):
return
(
var
.
owner
.
op
.
i
==
i
)
and
(
var
.
owner
.
inputs
[
0
]
==
x
)
# type: ignore
return
(
var
.
owner
.
op
.
i
==
i
)
and
(
var
.
owner
.
inputs
[
0
]
==
x
)
# type: ignore
# Match Subtensor((
int,))(Shape(input), i) - single integer index into shape
# Match Subtensor((
ScalarType,))(Shape(input), i)
if
isinstance
(
var
.
owner
.
op
,
Subtensor
):
if
isinstance
(
var
.
owner
.
op
,
Subtensor
):
idx_entry
=
(
var
.
owner
.
op
.
idx_list
[
0
]
if
len
(
var
.
owner
.
op
.
idx_list
)
==
1
else
None
)
return
(
return
(
# Check we have integer indexing operation
# Check we have integer indexing operation
# (and not slice or multiple indexing)
# (and not slice or multiple indexing)
len
(
var
.
owner
.
op
.
idx_list
)
==
1
len
(
var
.
owner
.
op
.
idx_list
)
==
1
and
isinstance
(
idx_entry
,
int
)
and
isinstance
(
var
.
owner
.
op
.
idx_list
[
0
],
ScalarType
)
# Check we are indexing on the shape of x
# Check we are indexing on the shape of x
and
var
.
owner
.
inputs
[
0
]
.
owner
is
not
None
and
var
.
owner
.
inputs
[
0
]
.
owner
is
not
None
and
isinstance
(
var
.
owner
.
inputs
[
0
]
.
owner
.
op
,
Shape
)
and
isinstance
(
var
.
owner
.
inputs
[
0
]
.
owner
.
op
,
Shape
)
...
...
pytensor/tensor/rewriting/subtensor.py
浏览文件 @
cc6bed1a
差异被折叠。
点击展开。
pytensor/tensor/rewriting/subtensor_lift.py
浏览文件 @
cc6bed1a
...
@@ -8,6 +8,7 @@ from pytensor import Variable
...
@@ -8,6 +8,7 @@ from pytensor import Variable
from
pytensor.compile
import
optdb
from
pytensor.compile
import
optdb
from
pytensor.graph
import
Constant
,
FunctionGraph
,
node_rewriter
,
vectorize_graph
from
pytensor.graph
import
Constant
,
FunctionGraph
,
node_rewriter
,
vectorize_graph
from
pytensor.graph.rewriting.basic
import
NodeRewriter
,
copy_stack_trace
from
pytensor.graph.rewriting.basic
import
NodeRewriter
,
copy_stack_trace
from
pytensor.scalar
import
basic
as
ps
from
pytensor.tensor.basic
import
(
from
pytensor.tensor.basic
import
(
Alloc
,
Alloc
,
Join
,
Join
,
...
@@ -30,7 +31,7 @@ from pytensor.tensor.rewriting.basic import (
...
@@ -30,7 +31,7 @@ from pytensor.tensor.rewriting.basic import (
register_stabilize
,
register_stabilize
,
)
)
from
pytensor.tensor.rewriting.elemwise
import
local_dimshuffle_lift
from
pytensor.tensor.rewriting.elemwise
import
local_dimshuffle_lift
from
pytensor.tensor.rewriting.subtensor
import
register_useless
from
pytensor.tensor.rewriting.subtensor
import
is_full_slice
,
register_useless
from
pytensor.tensor.shape
import
(
from
pytensor.tensor.shape
import
(
Shape
,
Shape
,
SpecifyShape
,
SpecifyShape
,
...
@@ -49,6 +50,7 @@ from pytensor.tensor.subtensor import (
...
@@ -49,6 +50,7 @@ from pytensor.tensor.subtensor import (
indices_from_subtensor
,
indices_from_subtensor
,
)
)
from
pytensor.tensor.type
import
TensorType
from
pytensor.tensor.type
import
TensorType
from
pytensor.tensor.type_other
import
NoneTypeT
,
SliceType
from
pytensor.tensor.variable
import
TensorVariable
from
pytensor.tensor.variable
import
TensorVariable
...
@@ -69,7 +71,7 @@ def _axis_is_indexed_by_basic_index(
...
@@ -69,7 +71,7 @@ def _axis_is_indexed_by_basic_index(
)
->
bool
:
)
->
bool
:
if
isinstance
(
axis
,
int
):
if
isinstance
(
axis
,
int
):
axis
=
(
axis
,)
axis
=
(
axis
,)
return
any
(
ax
<
len
(
idxs
)
and
not
i
dxs
[
ax
]
==
slice
(
None
)
for
ax
in
axis
)
return
any
(
ax
<
len
(
idxs
)
and
not
i
s_full_slice
(
idxs
[
ax
]
)
for
ax
in
axis
)
def
_lift_subtensor_non_axis
(
def
_lift_subtensor_non_axis
(
...
@@ -81,7 +83,7 @@ def _lift_subtensor_non_axis(
...
@@ -81,7 +83,7 @@ def _lift_subtensor_non_axis(
old_subtensor_variable
:
TensorVariable
,
old_subtensor_variable
:
TensorVariable
,
)
->
None
|
list
[
TensorVariable
]:
)
->
None
|
list
[
TensorVariable
]:
# Apply generic subtensor lift rewrite along "non-axis" dimensions
# Apply generic subtensor lift rewrite along "non-axis" dimensions
real_indices
=
[
idx
for
idx
in
idx_tuple
if
not
i
dx
==
slice
(
None
)]
real_indices
=
[
idx
for
idx
in
idx_tuple
if
not
i
s_full_slice
(
idx
)]
if
len
(
real_indices
)
>
1
and
variable
.
type
.
ndim
>
1
:
if
len
(
real_indices
)
>
1
and
variable
.
type
.
ndim
>
1
:
# Split the subtensor
# Split the subtensor
idx_to_keep
=
idx_tuple
[
axis
]
idx_to_keep
=
idx_tuple
[
axis
]
...
@@ -204,7 +206,7 @@ def local_subtensor_of_batch_dims(fgraph, node):
...
@@ -204,7 +206,7 @@ def local_subtensor_of_batch_dims(fgraph, node):
if
len
(
idx_tuple
)
>
batch_ndim
:
if
len
(
idx_tuple
)
>
batch_ndim
:
# Indexing on core dimensions of Blockwise. We split the indices and lift the batch ones only
# Indexing on core dimensions of Blockwise. We split the indices and lift the batch ones only
batch_indices
,
core_indices
=
idx_tuple
[:
batch_ndim
],
idx_tuple
[
batch_ndim
:]
batch_indices
,
core_indices
=
idx_tuple
[:
batch_ndim
],
idx_tuple
[
batch_ndim
:]
if
all
(
i
dx
==
slice
(
None
)
for
idx
in
batch_indices
):
if
all
(
i
s_full_slice
(
idx
)
for
idx
in
batch_indices
):
# No batch indices, nothing to do
# No batch indices, nothing to do
return
None
return
None
elem_with_batch_indices
=
elem
[
batch_indices
]
elem_with_batch_indices
=
elem
[
batch_indices
]
...
@@ -238,7 +240,7 @@ def local_subtensor_of_batch_dims(fgraph, node):
...
@@ -238,7 +240,7 @@ def local_subtensor_of_batch_dims(fgraph, node):
strict
=
False
,
strict
=
False
,
)
)
):
):
if
dim_idx
==
slice
(
None
):
if
is_full_slice
(
dim_idx
):
# Full slice can be safely applied to all inputs
# Full slice can be safely applied to all inputs
continue
continue
...
@@ -427,7 +429,7 @@ def local_subtensor_of_expand_dims(fgraph, node):
...
@@ -427,7 +429,7 @@ def local_subtensor_of_expand_dims(fgraph, node):
if
i
in
expanded_axes
:
if
i
in
expanded_axes
:
if
isinstance
(
idx_item
,
slice
):
if
isinstance
(
idx_item
,
slice
):
# Slice could be keeping or dropping this dimension
# Slice could be keeping or dropping this dimension
if
i
dx_item
==
slice
(
None
):
if
i
s_full_slice
(
idx_item
):
# A None slice, always keeps the dimension.
# A None slice, always keeps the dimension.
# We skip the index, and later introduce the needed expand_dim
# We skip the index, and later introduce the needed expand_dim
continue
continue
...
@@ -646,7 +648,10 @@ def local_subtensor_SpecifyShape_lift(fgraph, node):
...
@@ -646,7 +648,10 @@ def local_subtensor_SpecifyShape_lift(fgraph, node):
indices
=
get_idx_list
(
node
.
inputs
,
node
.
op
.
idx_list
)
indices
=
get_idx_list
(
node
.
inputs
,
node
.
op
.
idx_list
)
if
any
(
isinstance
(
index
,
slice
)
for
index
in
indices
):
if
any
(
isinstance
(
index
,
slice
)
or
isinstance
(
getattr
(
index
,
"type"
,
None
),
SliceType
)
for
index
in
indices
):
return
False
return
False
new_obj_arg
=
obj_arg
[
indices
]
new_obj_arg
=
obj_arg
[
indices
]
...
@@ -697,12 +702,15 @@ def local_subtensor_make_vector(fgraph, node):
...
@@ -697,12 +702,15 @@ def local_subtensor_make_vector(fgraph, node):
(
idx
,)
=
idxs
(
idx
,)
=
idxs
if
isinstance
(
idx
,
int
):
if
isinstance
(
idx
,
ps
.
ScalarType
|
TensorType
):
idx
=
node
.
inputs
[
1
]
old_idx
,
idx
=
idx
,
node
.
inputs
[
1
]
assert
idx
.
type
.
is_super
(
old_idx
)
elif
isinstance
(
node
.
op
,
AdvancedSubtensor1
):
elif
isinstance
(
node
.
op
,
AdvancedSubtensor1
):
idx
=
node
.
inputs
[
1
]
idx
=
node
.
inputs
[
1
]
if
isinstance
(
idx
,
Variable
):
if
isinstance
(
idx
,
int
|
np
.
integer
):
return
[
x
.
owner
.
inputs
[
idx
]]
elif
isinstance
(
idx
,
Variable
):
if
idx
.
ndim
==
0
:
if
idx
.
ndim
==
0
:
try
:
try
:
v
=
get_underlying_scalar_constant_value
(
v
=
get_underlying_scalar_constant_value
(
...
@@ -825,6 +833,8 @@ def local_subtensor_shape_constant(fgraph, node):
...
@@ -825,6 +833,8 @@ def local_subtensor_shape_constant(fgraph, node):
except
NotScalarConstantError
:
except
NotScalarConstantError
:
return
False
return
False
assert
idx_val
!=
np
.
newaxis
if
not
isinstance
(
shape_arg
.
type
,
TensorType
):
if
not
isinstance
(
shape_arg
.
type
,
TensorType
):
return
False
return
False
...
@@ -861,24 +871,22 @@ def local_subtensor_of_adv_subtensor(fgraph, node):
...
@@ -861,24 +871,22 @@ def local_subtensor_of_adv_subtensor(fgraph, node):
# AdvancedSubtensor involves a full_copy, so we don't want to do it twice
# AdvancedSubtensor involves a full_copy, so we don't want to do it twice
return
None
return
None
x
,
*
adv_index_vars
=
adv_subtensor
.
owner
.
inputs
x
,
*
adv_idxs
=
adv_subtensor
.
owner
.
inputs
adv_idxs
=
indices_from_subtensor
(
adv_index_vars
,
adv_subtensor
.
owner
.
op
.
idx_list
)
# Advanced indexing is a minefield, avoid all cases except for consecutive integer indices
# Advanced indexing is a minefield, avoid all cases except for consecutive integer indices
if
(
if
any
(
not
all
(
(
(
isinstance
(
adv_idx
.
type
,
NoneTypeT
)
(
isinstance
(
adv_idx
,
TensorVariable
)
and
adv_idx
.
type
.
dtype
!=
"bool"
)
or
(
isinstance
(
adv_idx
.
type
,
TensorType
)
and
adv_idx
.
type
.
dtype
==
"bool"
)
or
(
isinstance
(
adv_idx
,
slice
)
and
adv_idx
==
slice
(
None
))
or
(
isinstance
(
adv_idx
.
type
,
SliceType
)
and
not
is_full_slice
(
adv_idx
))
)
for
adv_idx
in
adv_idxs
)
)
for
adv_idx
in
adv_idxs
)
or
_non_consecutive_adv_indexing
(
adv_idxs
):
)
or
_non_consecutive_adv_indexing
(
adv_idxs
):
return
None
return
None
for
first_adv_idx_dim
,
adv_idx
in
enumerate
(
adv_idxs
):
for
first_adv_idx_dim
,
adv_idx
in
enumerate
(
adv_idxs
):
# We already made sure there were only None slices besides integer indexes
# We already made sure there were only None slices besides integer indexes
if
isinstance
(
adv_idx
,
TensorVariabl
e
):
if
isinstance
(
adv_idx
.
type
,
TensorTyp
e
):
break
break
else
:
# no-break
else
:
# no-break
# Not sure if this should ever happen, but better safe than sorry
# Not sure if this should ever happen, but better safe than sorry
...
@@ -901,7 +909,7 @@ def local_subtensor_of_adv_subtensor(fgraph, node):
...
@@ -901,7 +909,7 @@ def local_subtensor_of_adv_subtensor(fgraph, node):
copy_stack_trace
([
basic_subtensor
,
adv_subtensor
],
x_indexed
)
copy_stack_trace
([
basic_subtensor
,
adv_subtensor
],
x_indexed
)
x_after_index_lift
=
expand_dims
(
x_indexed
,
dropped_dims
)
x_after_index_lift
=
expand_dims
(
x_indexed
,
dropped_dims
)
x_after_adv_idx
=
adv_subtensor
.
owner
.
op
(
x_after_index_lift
,
*
adv_i
ndex_var
s
)
x_after_adv_idx
=
adv_subtensor
.
owner
.
op
(
x_after_index_lift
,
*
adv_i
dx
s
)
copy_stack_trace
([
basic_subtensor
,
adv_subtensor
],
x_after_adv_idx
)
copy_stack_trace
([
basic_subtensor
,
adv_subtensor
],
x_after_adv_idx
)
new_out
=
squeeze
(
x_after_adv_idx
[
basic_idxs_kept
],
dropped_dims
)
new_out
=
squeeze
(
x_after_adv_idx
[
basic_idxs_kept
],
dropped_dims
)
...
...
pytensor/tensor/rewriting/uncanonicalize.py
浏览文件 @
cc6bed1a
...
@@ -38,7 +38,7 @@ from pytensor.tensor.elemwise import CAReduce, DimShuffle
...
@@ -38,7 +38,7 @@ from pytensor.tensor.elemwise import CAReduce, DimShuffle
from
pytensor.tensor.math
import
Min
,
neg
from
pytensor.tensor.math
import
Min
,
neg
from
pytensor.tensor.rewriting.basic
import
register_uncanonicalize
from
pytensor.tensor.rewriting.basic
import
register_uncanonicalize
from
pytensor.tensor.shape
import
Reshape
,
reshape
from
pytensor.tensor.shape
import
Reshape
,
reshape
from
pytensor.tensor.subtensor
import
Subtensor
,
indices_from_subtensor
from
pytensor.tensor.subtensor
import
Subtensor
@register_uncanonicalize
@register_uncanonicalize
...
@@ -193,42 +193,60 @@ def local_dimshuffle_subtensor(fgraph, node):
...
@@ -193,42 +193,60 @@ def local_dimshuffle_subtensor(fgraph, node):
if
not
all
(
broadcastable
[
i
]
for
i
in
missing_dims
):
if
not
all
(
broadcastable
[
i
]
for
i
in
missing_dims
):
return
False
return
False
# create a new index tuple for a new Subtensor
# create a new idx_list for a new Subtensor object
# Reconstruct the full indices from the subtensor node, then replace
# have to loop on idx_list and inputs
# dimensions that are being dropped by dimshuffle with scalar index 0
# inputs has the length of sum of non None elements of idx_list
x
=
input_
.
owner
.
inputs
[
0
]
# (check in slice!).
indices
=
list
(
# len(missing_dims) can be < len(idx_list), this happens if
indices_from_subtensor
(
# tensor was indexed such as x[scalar, :, :], check that as well
input_
.
owner
.
inputs
[
1
:],
input_
.
owner
.
op
.
idx_list
new_idx_list
=
list
(
input_
.
owner
.
op
.
idx_list
)
)
new_inputs
=
[
input_
.
owner
.
inputs
[
0
]]
)
zero
=
constant
(
0
)
zero
=
constant
(
0
)
j
=
0
# Track which output dimension each index corresponds to
slice_i
=
-
1
# Scalar indices remove dimensions, slices keep them
subtensor_removed_dims
=
0
output_dim
=
0
for
i
,
idx
in
enumerate
(
input_
.
owner
.
op
.
idx_list
):
for
i
,
idx
in
enumerate
(
indices
):
if
isinstance
(
idx
,
slice
):
if
isinstance
(
idx
,
slice
):
# This slice produces an output dimension
slice_i
+=
1
if
output_dim
in
missing_dims
:
if
slice_i
in
missing_dims
:
#
This output dimension is being dropped, so replace slice with scalar
#
Missing dim is a slice(None), remove by indexing by 0
if
idx
==
slice
(
None
):
if
idx
==
slice
(
None
):
indices
[
i
]
=
zero
new_idx_list
[
i
]
=
zero
new_inputs
+=
[
zero
]
# Missing dim is an ordinary slice with known output dim length of 1
# Remove by indexing by start
else
:
else
:
# Use the start of the slice (or 0 if None)
if
idx
.
start
is
None
:
indices
[
i
]
=
idx
.
start
if
idx
.
start
is
not
None
else
zero
start
=
zero
output_dim
+=
1
else
:
# Scalar indices don't contribute to output dimensions
start
=
input_
.
owner
.
inputs
[
1
+
j
]
j
+=
1
# Handle trailing dimensions that weren't explicitly indexed
new_idx_list
[
i
]
=
start
for
input_dim
in
range
(
len
(
indices
),
x
.
ndim
):
new_inputs
+=
[
start
]
if
output_dim
in
missing_dims
:
# This unindexed dimension is being dropped, index with 0
# Ignore useless stop and step input if there is one
indices
.
append
(
zero
)
for
slice_attr
in
(
"stop"
,
"step"
):
if
getattr
(
idx
,
slice_attr
)
is
not
None
:
j
+=
1
# Keep non-dropped slice inputs
else
:
for
slice_attr
in
(
"start"
,
"stop"
,
"step"
):
if
getattr
(
idx
,
slice_attr
)
is
not
None
:
new_inputs
+=
[
input_
.
owner
.
inputs
[
1
+
j
]]
j
+=
1
# Keep non-dropped non-slice inputs
else
:
else
:
# This unindexed dimension is kept, index with slice(None)
new_inputs
+=
[
input_
.
owner
.
inputs
[
1
+
j
]]
indices
.
append
(
slice
(
None
))
j
+=
1
output_dim
+=
1
subtensor_removed_dims
+=
1
# Verify the trailing dimensions the subtensor didn't look at.
return
[
x
[
tuple
(
indices
)]]
for
idx
in
range
(
len
(
input_
.
owner
.
op
.
idx_list
),
new_inputs
[
0
]
.
ndim
):
if
(
idx
-
subtensor_removed_dims
)
in
missing_dims
:
while
len
(
new_idx_list
)
<
idx
:
new_idx_list
.
append
(
slice
(
None
))
new_idx_list
.
append
(
zero
)
new_inputs
.
append
(
zero
)
return
[
Subtensor
(
new_idx_list
)(
*
new_inputs
)]
return
False
return
False
pytensor/tensor/subtensor.py
浏览文件 @
cc6bed1a
差异被折叠。
点击展开。
pytensor/tensor/variable.py
浏览文件 @
cc6bed1a
...
@@ -15,8 +15,9 @@ from pytensor.scalar import (
...
@@ -15,8 +15,9 @@ from pytensor.scalar import (
ComplexError
,
ComplexError
,
)
)
from
pytensor.tensor
import
_get_vector_length
from
pytensor.tensor
import
_get_vector_length
from
pytensor.tensor.exceptions
import
AdvancedIndexingError
from
pytensor.tensor.type
import
TensorType
from
pytensor.tensor.type
import
TensorType
from
pytensor.tensor.type_other
import
None
TypeT
from
pytensor.tensor.type_other
import
None
Const
from
pytensor.tensor.utils
import
hash_from_ndarray
from
pytensor.tensor.utils
import
hash_from_ndarray
...
@@ -454,14 +455,15 @@ class _tensor_py_operators:
...
@@ -454,14 +455,15 @@ class _tensor_py_operators:
elif
not
isinstance
(
args
,
tuple
):
elif
not
isinstance
(
args
,
tuple
):
args
=
(
args
,)
args
=
(
args
,)
# Count the dimensions, check for bools and find ellipses.
ellipses
=
[]
ellipses
=
[]
index_dim_count
=
0
index_dim_count
=
0
for
i
,
arg
in
enumerate
(
args
):
for
i
,
arg
in
enumerate
(
args
):
if
arg
is
None
or
(
if
arg
is
np
.
newaxis
or
arg
is
NoneConst
:
isinstance
(
arg
,
Variable
)
and
isinstance
(
arg
.
type
,
NoneTypeT
)
# no increase in index_dim_count
):
pass
pass
elif
arg
is
Ellipsis
:
elif
arg
is
Ellipsis
:
# no increase in index_dim_count
ellipses
.
append
(
i
)
ellipses
.
append
(
i
)
elif
(
elif
(
isinstance
(
arg
,
np
.
ndarray
|
Variable
)
isinstance
(
arg
,
np
.
ndarray
|
Variable
)
...
@@ -503,41 +505,6 @@ class _tensor_py_operators:
...
@@ -503,41 +505,6 @@ class _tensor_py_operators:
self
.
ndim
-
index_dim_count
self
.
ndim
-
index_dim_count
)
)
if
any
(
arg
is
None
or
(
isinstance
(
arg
,
Variable
)
and
isinstance
(
arg
.
type
,
NoneTypeT
))
for
arg
in
args
):
expansion_axes
=
[]
new_args
=
[]
# Track dims consumed by args and inserted `None`s after ellipsis
counter
=
0
nones
=
0
for
arg
in
args
:
if
arg
is
None
or
(
isinstance
(
arg
,
Variable
)
and
isinstance
(
arg
.
type
,
NoneTypeT
)
):
expansion_axes
.
append
(
counter
+
nones
)
# Expand here
nones
+=
1
new_args
.
append
(
slice
(
None
))
else
:
new_args
.
append
(
arg
)
consumed
=
1
if
hasattr
(
arg
,
"dtype"
)
and
arg
.
dtype
==
"bool"
:
consumed
=
arg
.
ndim
counter
+=
consumed
expanded
=
pt
.
expand_dims
(
self
,
expansion_axes
)
if
all
(
isinstance
(
arg
,
slice
)
and
arg
.
start
is
None
and
arg
.
stop
is
None
and
arg
.
step
is
None
for
arg
in
new_args
):
return
expanded
return
expanded
[
tuple
(
new_args
)]
def
is_empty_array
(
val
):
def
is_empty_array
(
val
):
return
(
isinstance
(
val
,
tuple
|
list
)
and
len
(
val
)
==
0
)
or
(
return
(
isinstance
(
val
,
tuple
|
list
)
and
len
(
val
)
==
0
)
or
(
isinstance
(
val
,
np
.
ndarray
)
and
val
.
size
==
0
isinstance
(
val
,
np
.
ndarray
)
and
val
.
size
==
0
...
@@ -553,16 +520,74 @@ class _tensor_py_operators:
...
@@ -553,16 +520,74 @@ class _tensor_py_operators:
for
inp
in
args
for
inp
in
args
)
)
if
all
(
# Determine if advanced indexing is needed or not. The logic is
(
# already in `index_vars_to_types`: if it succeeds, standard indexing is
isinstance
(
arg
,
slice
|
int
|
float
|
np
.
number
)
# used; if it fails with `AdvancedIndexingError`, advanced indexing is
or
(
hasattr
(
arg
,
"ndim"
)
and
arg
.
ndim
==
0
and
arg
.
dtype
!=
"bool"
)
# used
)
advanced
=
False
for
arg
in
args
for
i
,
arg
in
enumerate
(
args
):
):
if
includes_bool
(
arg
):
return
pt
.
subtensor
.
basic_subtensor
(
self
,
*
args
)
advanced
=
True
else
:
break
if
arg
is
not
np
.
newaxis
and
arg
is
not
NoneConst
:
try
:
pt
.
subtensor
.
index_vars_to_types
(
arg
)
except
AdvancedIndexingError
:
if
advanced
:
break
else
:
advanced
=
True
if
advanced
:
return
pt
.
subtensor
.
advanced_subtensor
(
self
,
*
args
)
return
pt
.
subtensor
.
advanced_subtensor
(
self
,
*
args
)
else
:
if
np
.
newaxis
in
args
or
NoneConst
in
args
:
# `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new
# broadcastable dimension at this location". Since PyTensor adds
# new broadcastable dimensions via the `DimShuffle` `Op`, the
# following code uses said `Op` to add one of the new axes and
# then uses recursion to apply any other indices and add any
# remaining new axes.
counter
=
0
pattern
=
[]
new_args
=
[]
for
arg
in
args
:
if
arg
is
np
.
newaxis
or
arg
is
NoneConst
:
pattern
.
append
(
"x"
)
new_args
.
append
(
slice
(
None
,
None
,
None
))
else
:
pattern
.
append
(
counter
)
counter
+=
1
new_args
.
append
(
arg
)
pattern
.
extend
(
list
(
range
(
counter
,
self
.
ndim
)))
view
=
self
.
dimshuffle
(
pattern
)
full_slices
=
True
for
arg
in
new_args
:
# We can't do arg == slice(None, None, None) as in
# Python 2.7, this call __lt__ if we have a slice
# with some symbolic variable.
if
not
(
isinstance
(
arg
,
slice
)
and
(
arg
.
start
is
None
or
arg
.
start
is
NoneConst
)
and
(
arg
.
stop
is
None
or
arg
.
stop
is
NoneConst
)
and
(
arg
.
step
is
None
or
arg
.
step
is
NoneConst
)
):
full_slices
=
False
if
full_slices
:
return
view
else
:
return
view
.
__getitem__
(
tuple
(
new_args
))
else
:
return
pt
.
subtensor
.
Subtensor
(
args
)(
self
,
*
pt
.
subtensor
.
get_slice_elements
(
args
,
lambda
entry
:
isinstance
(
entry
,
Variable
)
),
)
def
__setitem__
(
self
,
key
,
value
):
def
__setitem__
(
self
,
key
,
value
):
raise
TypeError
(
raise
TypeError
(
...
...
pytensor/xtensor/rewriting/indexing.py
浏览文件 @
cc6bed1a
...
@@ -2,10 +2,9 @@ from itertools import zip_longest
...
@@ -2,10 +2,9 @@ from itertools import zip_longest
from
pytensor
import
as_symbolic
from
pytensor
import
as_symbolic
from
pytensor.graph
import
Constant
,
node_rewriter
from
pytensor.graph
import
Constant
,
node_rewriter
from
pytensor.tensor
import
arange
,
specify_shape
from
pytensor.tensor
import
TensorType
,
arange
,
specify_shape
from
pytensor.tensor.subtensor
import
_non_consecutive_adv_indexing
,
inc_subtensor
from
pytensor.tensor.subtensor
import
_non_consecutive_adv_indexing
,
inc_subtensor
from
pytensor.tensor.type_other
import
NoneTypeT
,
SliceType
from
pytensor.tensor.type_other
import
NoneTypeT
,
SliceType
from
pytensor.tensor.variable
import
TensorVariable
from
pytensor.xtensor.basic
import
tensor_from_xtensor
,
xtensor_from_tensor
from
pytensor.xtensor.basic
import
tensor_from_xtensor
,
xtensor_from_tensor
from
pytensor.xtensor.indexing
import
Index
,
IndexUpdate
,
index
from
pytensor.xtensor.indexing
import
Index
,
IndexUpdate
,
index
from
pytensor.xtensor.rewriting.utils
import
register_lower_xtensor
from
pytensor.xtensor.rewriting.utils
import
register_lower_xtensor
...
@@ -107,7 +106,7 @@ def _lower_index(node):
...
@@ -107,7 +106,7 @@ def _lower_index(node):
# We can use basic indexing directly if no other index acts on this dimension
# We can use basic indexing directly if no other index acts on this dimension
# This is an optimization that avoids creating an unnecessary arange tensor
# This is an optimization that avoids creating an unnecessary arange tensor
# and facilitates the use of the specialized AdvancedSubtensor1 when possible
# and facilitates the use of the specialized AdvancedSubtensor1 when possible
aligned_idxs
.
append
(
to_basic_idx
(
idx
)
)
aligned_idxs
.
append
(
idx
)
basic_idx_axis
.
append
(
out_dims
.
index
(
x_dim
))
basic_idx_axis
.
append
(
out_dims
.
index
(
x_dim
))
else
:
else
:
# Otherwise we need to convert the basic index into an equivalent advanced indexing
# Otherwise we need to convert the basic index into an equivalent advanced indexing
...
@@ -132,7 +131,7 @@ def _lower_index(node):
...
@@ -132,7 +131,7 @@ def _lower_index(node):
if
basic_idx_axis
:
if
basic_idx_axis
:
aligned_idxs
=
[
aligned_idxs
=
[
idx
.
squeeze
(
axis
=
basic_idx_axis
)
idx
.
squeeze
(
axis
=
basic_idx_axis
)
if
(
isinstance
(
idx
,
TensorVariabl
e
)
and
idx
.
type
.
ndim
>
0
)
if
(
isinstance
(
idx
.
type
,
TensorTyp
e
)
and
idx
.
type
.
ndim
>
0
)
else
idx
else
idx
for
idx
in
aligned_idxs
for
idx
in
aligned_idxs
]
]
...
...
tests/graph/rewriting/test_basic.py
浏览文件 @
cc6bed1a
...
@@ -26,7 +26,9 @@ from pytensor.graph.rewriting.unify import LiteralString, OpPattern
...
@@ -26,7 +26,9 @@ from pytensor.graph.rewriting.unify import LiteralString, OpPattern
from
pytensor.raise_op
import
assert_op
from
pytensor.raise_op
import
assert_op
from
pytensor.tensor.math
import
Dot
,
add
,
dot
,
exp
from
pytensor.tensor.math
import
Dot
,
add
,
dot
,
exp
from
pytensor.tensor.rewriting.basic
import
constant_folding
from
pytensor.tensor.rewriting.basic
import
constant_folding
from
pytensor.tensor.subtensor
import
AdvancedSubtensor
from
pytensor.tensor.type
import
matrix
,
values_eq_approx_always_true
,
vector
from
pytensor.tensor.type
import
matrix
,
values_eq_approx_always_true
,
vector
from
pytensor.tensor.type_other
import
MakeSlice
,
SliceConstant
,
slicetype
from
tests.graph.utils
import
(
from
tests.graph.utils
import
(
MyOp
,
MyOp
,
MyType
,
MyType
,
...
@@ -627,6 +629,21 @@ def test_pre_constant_merge():
...
@@ -627,6 +629,21 @@ def test_pre_constant_merge():
assert
res
==
[
o2
]
assert
res
==
[
o2
]
assert
o2
.
owner
.
inputs
[
2
]
is
c2
assert
o2
.
owner
.
inputs
[
2
]
is
c2
# What is this supposed to test?
ms
=
MakeSlice
()(
1
)
res
=
pre_constant_merge
(
empty_fgraph
,
[
ms
])
assert
res
==
[
ms
]
const_slice
=
SliceConstant
(
type
=
slicetype
,
data
=
slice
(
1
,
None
,
2
))
assert
isinstance
(
const_slice
,
Constant
)
adv
=
AdvancedSubtensor
()(
matrix
(),
[
2
,
3
],
const_slice
)
res
=
pre_constant_merge
(
empty_fgraph
,
adv
)
assert
res
==
[
adv
]
def
test_pre_greedy_node_rewriter
():
def
test_pre_greedy_node_rewriter
():
empty_fgraph
=
FunctionGraph
([],
[])
empty_fgraph
=
FunctionGraph
([],
[])
...
@@ -662,6 +679,15 @@ def test_pre_greedy_node_rewriter():
...
@@ -662,6 +679,15 @@ def test_pre_greedy_node_rewriter():
assert
cst
.
owner
.
inputs
[
0
]
is
o1
assert
cst
.
owner
.
inputs
[
0
]
is
o1
assert
cst
.
owner
.
inputs
[
4
]
is
cst
.
owner
.
inputs
[
0
]
assert
cst
.
owner
.
inputs
[
4
]
is
cst
.
owner
.
inputs
[
0
]
# What exactly is this supposed to test?
ms
=
MakeSlice
()(
1
)
cst
=
pre_greedy_node_rewriter
(
empty_fgraph
,
[
constant_folding
],
ms
)
assert
isinstance
(
cst
,
SliceConstant
)
# Make sure constant of slice signature is hashable.
assert
isinstance
(
hash
(
cst
.
signature
()),
int
)
@pytest.mark.parametrize
(
"tracks"
,
[
True
,
False
])
@pytest.mark.parametrize
(
"tracks"
,
[
True
,
False
])
@pytest.mark.parametrize
(
"out_pattern"
,
[(
op2
,
"x"
),
"x"
,
1.0
])
@pytest.mark.parametrize
(
"out_pattern"
,
[(
op2
,
"x"
),
"x"
,
1.0
])
...
...
tests/link/jax/test_subtensor.py
浏览文件 @
cc6bed1a
...
@@ -225,37 +225,6 @@ def test_jax_IncSubtensor():
...
@@ -225,37 +225,6 @@ def test_jax_IncSubtensor():
compare_jax_and_py
([],
[
out_pt
],
[])
compare_jax_and_py
([],
[
out_pt
],
[])
@pytest.mark.parametrize
(
"func"
,
(
pt_subtensor
.
advanced_inc_subtensor1
,
pt_subtensor
.
advanced_set_subtensor1
)
)
def
test_jax_AdvancedIncSubtensor1_runtime_broadcast
(
func
):
"""Test that JAX backend checks for runtime broadcasting in AdvancedIncSubtensor1.
JAX silently broadcasts when using .at[].set() or .at[].add(), but PyTensor
requires explicit broadcastable dimensions. This test ensures we raise the same
error as the Python/C backend when runtime broadcasting would occur.
"""
from
pytensor
import
function
y
=
pt
.
matrix
(
"y"
,
dtype
=
"float64"
,
shape
=
(
None
,
None
))
x
=
pt
.
zeros
((
10
,
5
))
idxs
=
np
.
repeat
(
np
.
arange
(
10
),
2
)
# 20 indices
out
=
func
(
x
,
y
,
idxs
)
f
=
function
([
y
],
out
,
mode
=
"JAX"
)
# Should work with correctly sized y
f
(
np
.
ones
((
20
,
5
)))
# Should raise for runtime broadcasting on first dimension
with
pytest
.
raises
(
ValueError
,
match
=
"Runtime broadcasting not allowed"
):
f
(
np
.
ones
((
1
,
5
)))
# Should raise for runtime broadcasting on second dimension
with
pytest
.
raises
(
ValueError
,
match
=
"Runtime broadcasting not allowed"
):
f
(
np
.
ones
((
20
,
1
)))
def
test_jax_IncSubtensor_boolean_indexing_reexpressible
():
def
test_jax_IncSubtensor_boolean_indexing_reexpressible
():
"""Setting or incrementing values with boolean indexing.
"""Setting or incrementing values with boolean indexing.
...
...
tests/link/mlx/test_subtensor.py
浏览文件 @
cc6bed1a
...
@@ -187,6 +187,27 @@ def test_mlx_inplace_variants():
...
@@ -187,6 +187,27 @@ def test_mlx_inplace_variants():
compare_mlx_and_py
([],
[
out_pt
],
[])
compare_mlx_and_py
([],
[
out_pt
],
[])
@pytest.mark.xfail
(
reason
=
"MLX slice indices must be integers or None, dynamic slices not supported"
)
def
test_mlx_MakeSlice
():
"""Test MakeSlice operation."""
# Test slice creation
start
=
pt
.
iscalar
(
"start"
)
stop
=
pt
.
iscalar
(
"stop"
)
step
=
pt
.
iscalar
(
"step"
)
# Create a slice using MakeSlice
slice_op
=
pt_subtensor
.
MakeSlice
()
slice_pt
=
slice_op
(
start
,
stop
,
step
)
# Use simple constant array instead of arange
x_pt
=
pt
.
constant
(
np
.
arange
(
10
,
dtype
=
np
.
float32
))
out_pt
=
x_pt
[
slice_pt
]
compare_mlx_and_py
([
start
,
stop
,
step
],
[
out_pt
],
[
1
,
8
,
2
])
def
test_mlx_subtensor_edge_cases
():
def
test_mlx_subtensor_edge_cases
():
"""Test edge cases and boundary conditions."""
"""Test edge cases and boundary conditions."""
# Empty slices - use constant array
# Empty slices - use constant array
...
...
tests/link/numba/test_subtensor.py
浏览文件 @
cc6bed1a
...
@@ -3,7 +3,9 @@ import contextlib
...
@@ -3,7 +3,9 @@ import contextlib
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
import
pytensor.scalar
as
ps
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor
import
Mode
,
as_symbolic
from
pytensor.tensor
import
as_tensor
from
pytensor.tensor
import
as_tensor
from
pytensor.tensor.subtensor
import
(
from
pytensor.tensor.subtensor
import
(
AdvancedIncSubtensor
,
AdvancedIncSubtensor
,
...
@@ -18,16 +20,51 @@ from pytensor.tensor.subtensor import (
...
@@ -18,16 +20,51 @@ from pytensor.tensor.subtensor import (
inc_subtensor
,
inc_subtensor
,
set_subtensor
,
set_subtensor
,
)
)
from
tests.link.numba.test_basic
import
(
from
tests.link.numba.test_basic
import
compare_numba_and_py
,
numba_mode
compare_numba_and_py
,
numba_inplace_mode
,
numba_mode
,
)
rng
=
np
.
random
.
default_rng
(
sum
(
map
(
ord
,
"Numba subtensors"
)))
rng
=
np
.
random
.
default_rng
(
sum
(
map
(
ord
,
"Numba subtensors"
)))
@pytest.mark.parametrize
(
"step"
,
[
None
,
1
,
2
,
-
2
,
"x"
],
ids
=
lambda
x
:
f
"step={x}"
)
@pytest.mark.parametrize
(
"stop"
,
[
None
,
10
,
"x"
],
ids
=
lambda
x
:
f
"stop={x}"
)
@pytest.mark.parametrize
(
"start"
,
[
None
,
0
,
3
,
"x"
],
ids
=
lambda
x
:
f
"start={x}"
)
def
test_slice
(
start
,
stop
,
step
):
x
=
ps
.
int64
(
"x"
)
sym_slice
=
as_symbolic
(
slice
(
x
if
start
==
"x"
else
start
,
x
if
stop
==
"x"
else
stop
,
x
if
step
==
"x"
else
step
,
)
)
no_opt_mode
=
Mode
(
linker
=
"numba"
,
optimizer
=
None
)
evaled_slice
=
sym_slice
.
eval
({
x
:
-
5
},
on_unused_input
=
"ignore"
,
mode
=
no_opt_mode
)
assert
isinstance
(
evaled_slice
,
slice
)
if
start
==
"x"
:
assert
evaled_slice
.
start
==
-
5
elif
start
is
None
and
(
evaled_slice
.
step
is
None
or
evaled_slice
.
step
>
0
):
# Numba can convert to 0 (and sometimes does) in this case
assert
evaled_slice
.
start
in
(
None
,
0
)
else
:
assert
evaled_slice
.
start
==
start
if
stop
==
"x"
:
assert
evaled_slice
.
stop
==
-
5
else
:
assert
evaled_slice
.
stop
==
stop
if
step
==
"x"
:
assert
evaled_slice
.
step
==
-
5
elif
step
is
None
:
# Numba can convert to 1 (and sometimes does) in this case
assert
evaled_slice
.
step
in
(
None
,
1
)
else
:
assert
evaled_slice
.
step
==
step
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"x, indices"
,
"x, indices"
,
[
[
...
@@ -145,11 +182,6 @@ def test_AdvancedSubtensor1_out_of_bounds():
...
@@ -145,11 +182,6 @@ def test_AdvancedSubtensor1_out_of_bounds():
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))),
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))),
([[
1
,
2
],
[
2
,
1
]],
slice
(
1
,
None
),
[[
0
,
0
],
[
0
,
0
]]),
([[
1
,
2
],
[
2
,
1
]],
slice
(
1
,
None
),
[[
0
,
0
],
[
0
,
0
]]),
),
),
# Newaxis with vector indexing
(
as_tensor
(
np
.
arange
(
4
*
4
)
.
reshape
((
4
,
4
))),
(
None
,
[
0
,
1
,
2
],
[
0
,
1
,
2
]),
),
],
],
)
)
@pytest.mark.filterwarnings
(
"error"
)
# Raise if we did not expect objmode to be needed
@pytest.mark.filterwarnings
(
"error"
)
# Raise if we did not expect objmode to be needed
...
@@ -415,13 +447,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
...
@@ -415,13 +447,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
False
,
False
,
False
,
False
,
),
),
(
np
.
arange
(
4
*
4
)
.
reshape
((
4
,
4
)),
np
.
array
(
5
),
# Broadcasted scalar value
(
None
,
[
0
,
1
,
2
],
[
0
,
1
,
2
]),
# Newaxis with vector indexing
False
,
False
,
),
],
],
)
)
@pytest.mark.parametrize
(
"inplace"
,
(
False
,
True
))
@pytest.mark.parametrize
(
"inplace"
,
(
False
,
True
))
...
@@ -435,9 +460,7 @@ def test_AdvancedIncSubtensor(
...
@@ -435,9 +460,7 @@ def test_AdvancedIncSubtensor(
inplace
,
inplace
,
):
):
# Need rewrite to support certain forms of advanced indexing without object mode
# Need rewrite to support certain forms of advanced indexing without object mode
# Use inplace_mode when testing inplace operations to preserve inplace flag
mode
=
numba_mode
.
including
(
"specialize"
)
base_mode
=
numba_inplace_mode
if
inplace
else
numba_mode
mode
=
base_mode
.
including
(
"specialize"
)
x_pt
=
pt
.
as_tensor
(
x
)
.
type
(
"x"
)
x_pt
=
pt
.
as_tensor
(
x
)
.
type
(
"x"
)
y_pt
=
pt
.
as_tensor
(
y
)
.
type
(
"y"
)
y_pt
=
pt
.
as_tensor
(
y
)
.
type
(
"y"
)
...
@@ -491,3 +514,22 @@ def test_AdvancedIncSubtensor(
...
@@ -491,3 +514,22 @@ def test_AdvancedIncSubtensor(
x_orig
=
x
.
copy
()
x_orig
=
x
.
copy
()
fn
(
x
,
y
)
fn
(
x
,
y
)
assert
not
np
.
all
(
x
==
x_orig
)
assert
not
np
.
all
(
x
==
x_orig
)
def
test_advanced_indexing_with_newaxis_fallback_obj_mode
():
# This should be automatically solved with https://github.com/pymc-devs/pytensor/issues/1564
# After which we can add these parametrizations to the relevant tests above
x
=
pt
.
matrix
(
"x"
)
out
=
x
[
None
,
[
0
,
1
,
2
],
[
0
,
1
,
2
]]
with
pytest
.
warns
(
UserWarning
,
match
=
r"Numba will use object mode to run AdvancedSubtensor's perform method"
,
):
compare_numba_and_py
([
x
],
[
out
],
[
np
.
random
.
normal
(
size
=
(
4
,
4
))])
out
=
x
[
None
,
[
0
,
1
,
2
],
[
0
,
1
,
2
]]
.
inc
(
5
)
with
pytest
.
warns
(
UserWarning
,
match
=
r"Numba will use object mode to run AdvancedIncSubtensor's perform method"
,
):
compare_numba_and_py
([
x
],
[
out
],
[
np
.
random
.
normal
(
size
=
(
4
,
4
))])
tests/tensor/rewriting/test_elemwise.py
浏览文件 @
cc6bed1a
...
@@ -1642,15 +1642,9 @@ def test_InplaceElemwiseOptimizer_bug():
...
@@ -1642,15 +1642,9 @@ def test_InplaceElemwiseOptimizer_bug():
# with config.change_flags(tensor__insert_inplace_optimizer_validate_nb=10):
# with config.change_flags(tensor__insert_inplace_optimizer_validate_nb=10):
rewrite_graph
(
fgraph
,
include
=
(
"inplace"
,))
rewrite_graph
(
fgraph
,
include
=
(
"inplace"
,))
# Save original value to restore later
pytensor
.
config
.
tensor__insert_inplace_optimizer_validate_nb
=
1
original_value
=
pytensor
.
config
.
tensor__insert_inplace_optimizer_validate_nb
with
pytest
.
warns
(
try
:
FutureWarning
,
pytensor
.
config
.
tensor__insert_inplace_optimizer_validate_nb
=
1
match
=
"tensor__insert_inplace_optimizer_validate_nb config is deprecated"
,
with
pytest
.
warns
(
):
FutureWarning
,
rewrite_graph
(
fgraph
,
include
=
(
"inplace"
,))
match
=
"tensor__insert_inplace_optimizer_validate_nb config is deprecated"
,
):
rewrite_graph
(
fgraph
,
include
=
(
"inplace"
,))
finally
:
# Restore original value to avoid affecting other tests
pytensor
.
config
.
tensor__insert_inplace_optimizer_validate_nb
=
original_value
tests/tensor/rewriting/test_subtensor.py
浏览文件 @
cc6bed1a
...
@@ -52,6 +52,7 @@ from pytensor.tensor.type import (
...
@@ -52,6 +52,7 @@ from pytensor.tensor.type import (
tensor4
,
tensor4
,
vector
,
vector
,
)
)
from
pytensor.tensor.type_other
import
make_slice
from
tests
import
unittest_tools
as
utt
from
tests
import
unittest_tools
as
utt
from
tests.unittest_tools
import
create_pytensor_param
from
tests.unittest_tools
import
create_pytensor_param
...
@@ -1700,11 +1701,11 @@ def test_local_uint_constant_indices():
...
@@ -1700,11 +1701,11 @@ def test_local_uint_constant_indices():
assert
isinstance
(
new_index
,
Constant
)
assert
isinstance
(
new_index
,
Constant
)
assert
new_index
.
type
.
dtype
==
"uint8"
assert
new_index
.
type
.
dtype
==
"uint8"
# `AdvancedSubtensor`, two indices, one slice, convert
# `AdvancedSubtensor`, two indices, one s
ymbolic s
lice, convert
x
=
pt
.
matrix
(
"x"
)
x
=
pt
.
matrix
(
"x"
)
indices
=
(
indices
=
(
pt
.
as_tensor_variable
(
np
.
array
(
[
1
]
,
np
.
int64
)),
pt
.
as_tensor_variable
(
np
.
array
(
1
,
np
.
int64
)),
slice
(
None
,
10
),
make_slice
(
slice
(
None
,
10
)
),
)
)
z
=
x
[
indices
]
z
=
x
[
indices
]
...
@@ -1791,7 +1792,7 @@ def test_local_uint_constant_indices():
...
@@ -1791,7 +1792,7 @@ def test_local_uint_constant_indices():
z_fn
=
pytensor
.
function
([
x
],
z
,
mode
=
mode
)
z_fn
=
pytensor
.
function
([
x
],
z
,
mode
=
mode
)
subtensor_node
=
z_fn
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
subtensor_node
=
z_fn
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
assert
isinstance
(
subtensor_node
.
op
,
(
AdvancedSubtensor
,
AdvancedSubtensor1
)
)
assert
isinstance
(
subtensor_node
.
op
,
AdvancedSubtensor
)
new_index
=
subtensor_node
.
inputs
[
1
]
new_index
=
subtensor_node
.
inputs
[
1
]
assert
isinstance
(
new_index
,
Constant
)
assert
isinstance
(
new_index
,
Constant
)
assert
new_index
.
type
.
dtype
==
"uint8"
assert
new_index
.
type
.
dtype
==
"uint8"
...
@@ -1842,6 +1843,7 @@ class TestBlockwiseIncSubtensor:
...
@@ -1842,6 +1843,7 @@ class TestBlockwiseIncSubtensor:
out
=
vectorize_graph
(
core_graph
,
replace
=
{
core_x
:
x
,
core_y
:
y
})
out
=
vectorize_graph
(
core_graph
,
replace
=
{
core_x
:
x
,
core_y
:
y
})
fn
,
ref_fn
=
self
.
compile_fn_and_ref
([
x
,
y
],
out
)
fn
,
ref_fn
=
self
.
compile_fn_and_ref
([
x
,
y
],
out
)
assert
self
.
has_blockwise
(
ref_fn
)
assert
self
.
has_blockwise
(
ref_fn
)
assert
not
self
.
has_blockwise
(
fn
)
test_x
=
np
.
ones
(
x
.
type
.
shape
,
dtype
=
x
.
type
.
dtype
)
test_x
=
np
.
ones
(
x
.
type
.
shape
,
dtype
=
x
.
type
.
dtype
)
test_y
=
rng
.
integers
(
1
,
10
,
size
=
y
.
type
.
shape
,
dtype
=
y
.
type
.
dtype
)
test_y
=
rng
.
integers
(
1
,
10
,
size
=
y
.
type
.
shape
,
dtype
=
y
.
type
.
dtype
)
np
.
testing
.
assert_allclose
(
fn
(
test_x
,
test_y
),
ref_fn
(
test_x
,
test_y
))
np
.
testing
.
assert_allclose
(
fn
(
test_x
,
test_y
),
ref_fn
(
test_x
,
test_y
))
...
@@ -1946,7 +1948,15 @@ class TestBlockwiseIncSubtensor:
...
@@ -1946,7 +1948,15 @@ class TestBlockwiseIncSubtensor:
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"basic_idx"
,
"basic_idx"
,
[
True
,
False
],
[
True
,
pytest
.
param
(
False
,
marks
=
pytest
.
mark
.
xfail
(
reason
=
"AdvancedIncSubtensor with slices can't be blockwise"
),
),
],
ids
=
[
"basic_idx"
,
"adv_idx"
],
ids
=
[
"basic_idx"
,
"adv_idx"
],
)
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
...
@@ -1963,7 +1973,7 @@ class TestBlockwiseIncSubtensor:
...
@@ -1963,7 +1973,7 @@ class TestBlockwiseIncSubtensor:
core_idx
=
pt
.
tensor
(
"idx"
,
dtype
=
int
,
shape
=
()
if
basic_idx
else
(
2
,))
core_idx
=
pt
.
tensor
(
"idx"
,
dtype
=
int
,
shape
=
()
if
basic_idx
else
(
2
,))
# The empty slice before core_idx, will lead to a transposition of the advanced view
# The empty slice before core_idx, will lead to a transposition of the advanced view
# once it is paired with a new arange slice on the batched dimensions.
# once it is paired with a
n
new arange slice on the batched dimensions.
# That's why core_v is (2, 3), and not (3, 2), in the case of advanced indexing
# That's why core_v is (2, 3), and not (3, 2), in the case of advanced indexing
core_out
=
core_a
[
0
,
:,
core_idx
]
.
set
(
core_v
)
core_out
=
core_a
[
0
,
:,
core_idx
]
.
set
(
core_v
)
...
...
tests/tensor/rewriting/test_subtensor_lift.py
浏览文件 @
cc6bed1a
...
@@ -32,6 +32,7 @@ from pytensor.tensor import (
...
@@ -32,6 +32,7 @@ from pytensor.tensor import (
lscalars
,
lscalars
,
matrix
,
matrix
,
shape
,
shape
,
slicetype
,
specify_shape
,
specify_shape
,
tensor
,
tensor
,
tensor3
,
tensor3
,
...
@@ -556,7 +557,7 @@ class TestLocalSubtensorSpecifyShapeLift:
...
@@ -556,7 +557,7 @@ class TestLocalSubtensorSpecifyShapeLift:
(
(
matrix
(),
matrix
(),
(
iscalar
(),
iscalar
()),
(
iscalar
(),
iscalar
()),
(
slice
(
iscalar
(),
iscalar
(),
iscalar
()
),),
(
slice
type
(
),),
),
),
(
(
matrix
(),
matrix
(),
...
@@ -788,12 +789,12 @@ def test_local_subtensor_shape_constant():
...
@@ -788,12 +789,12 @@ def test_local_subtensor_shape_constant():
(
lambda
x
:
x
[:,
[
0
,
1
]][
0
],
True
),
(
lambda
x
:
x
[:,
[
0
,
1
]][
0
],
True
),
(
lambda
x
:
x
[:,
[
0
,
1
],
[
0
,
0
]][
1
:],
True
),
(
lambda
x
:
x
[:,
[
0
,
1
],
[
0
,
0
]][
1
:],
True
),
(
lambda
x
:
x
[:,
[[
0
,
1
],
[
0
,
0
]]][
1
:],
True
),
(
lambda
x
:
x
[:,
[[
0
,
1
],
[
0
,
0
]]][
1
:],
True
),
(
lambda
x
:
x
[:,
None
,
[
0
,
1
]][
0
],
True
),
# Not supported, basic indexing on advanced indexing dim
# Not supported, basic indexing on advanced indexing dim
(
lambda
x
:
x
[[
0
,
1
]][
0
],
False
),
(
lambda
x
:
x
[[
0
,
1
]][
0
],
False
),
# Not
suppor
ted, basic indexing on the right of advanced indexing
# Not
implemen
ted, basic indexing on the right of advanced indexing
(
lambda
x
:
x
[[
0
,
1
]][:,
0
],
False
),
(
lambda
x
:
x
[[
0
,
1
]][:,
0
],
False
),
# Not implemented, complex flavors of advanced indexing
# Not implemented, complex flavors of advanced indexing
(
lambda
x
:
x
[:,
None
,
[
0
,
1
]][
0
],
False
),
(
lambda
x
:
x
[:,
5
:,
[
0
,
1
]][
0
],
False
),
(
lambda
x
:
x
[:,
5
:,
[
0
,
1
]][
0
],
False
),
(
lambda
x
:
x
[:,
:,
np
.
array
([
True
,
False
,
False
])][
0
],
False
),
(
lambda
x
:
x
[:,
:,
np
.
array
([
True
,
False
,
False
])][
0
],
False
),
(
lambda
x
:
x
[[
0
,
1
],
:,
[
0
,
1
]][:,
0
],
False
),
(
lambda
x
:
x
[[
0
,
1
],
:,
[
0
,
1
]][:,
0
],
False
),
...
...
tests/tensor/test_blockwise.py
浏览文件 @
cc6bed1a
...
@@ -31,8 +31,6 @@ from pytensor.tensor.blockwise import (
...
@@ -31,8 +31,6 @@ from pytensor.tensor.blockwise import (
vectorize_node_fallback
,
vectorize_node_fallback
,
)
)
from
pytensor.tensor.nlinalg
import
MatrixInverse
,
eig
from
pytensor.tensor.nlinalg
import
MatrixInverse
,
eig
from
pytensor.tensor.random
import
normal
from
pytensor.tensor.random.op
import
default_rng
from
pytensor.tensor.rewriting.blas
import
specialize_matmul_to_batched_dot
from
pytensor.tensor.rewriting.blas
import
specialize_matmul_to_batched_dot
from
pytensor.tensor.signal
import
convolve1d
from
pytensor.tensor.signal
import
convolve1d
from
pytensor.tensor.slinalg
import
(
from
pytensor.tensor.slinalg
import
(
...
@@ -116,18 +114,16 @@ def test_vectorize_blockwise():
...
@@ -116,18 +114,16 @@ def test_vectorize_blockwise():
def
test_vectorize_node_fallback_unsupported_type
():
def
test_vectorize_node_fallback_unsupported_type
():
rng
=
default_rng
(
)
x
=
tensor
(
"x"
,
shape
=
(
2
,
6
)
)
node
=
normal
(
rng
=
rng
)
.
owner
node
=
x
[:,
[
0
,
2
,
4
]]
.
owner
with
pytest
.
raises
(
with
pytest
.
raises
(
NotImplementedError
,
NotImplementedError
,
match
=
re
.
escape
(
match
=
re
.
escape
(
'Cannot vectorize node normal_rv{"(),()->()"}('
"Cannot vectorize node AdvancedSubtensor(x, MakeSlice.0, [0 2 4]) with input MakeSlice.0 of type slice"
"DefaultGeneratorMakerOp.0, NoneConst{None}, 0.0, 1.0)"
" with input DefaultGeneratorMakerOp.0 of type RandomGeneratorType"
),
),
):
):
vectorize_node_fallback
(
node
.
op
,
node
,
*
node
.
inputs
)
vectorize_node_fallback
(
node
.
op
,
node
,
node
.
inputs
)
def
check_blockwise_runtime_broadcasting
(
mode
):
def
check_blockwise_runtime_broadcasting
(
mode
):
...
...
tests/tensor/test_subtensor.py
浏览文件 @
cc6bed1a
差异被折叠。
点击展开。
tests/tensor/test_type_other.py
浏览文件 @
cc6bed1a
...
@@ -4,8 +4,30 @@ import pytensor
...
@@ -4,8 +4,30 @@ import pytensor
from
pytensor
import
as_symbolic
from
pytensor
import
as_symbolic
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.basic
import
Constant
from
pytensor.tensor.math
import
argmax
from
pytensor.tensor.math
import
argmax
from
pytensor.tensor.type
import
vector
from
pytensor.tensor.type
import
iscalar
,
vector
from
pytensor.tensor.type_other
import
NoneConst
,
NoneTypeT
from
pytensor.tensor.type_other
import
(
MakeSlice
,
NoneConst
,
NoneTypeT
,
SliceConstant
,
SliceType
,
make_slice
,
)
def
test_SliceType
():
st
=
SliceType
()
assert
st
==
st
.
clone
()
def
test_make_slice_merge
():
# In the past, this was crahsing during compilation.
i
=
iscalar
()
s1
=
make_slice
(
0
,
i
)
s2
=
make_slice
(
0
,
i
)
f
=
pytensor
.
function
([
i
],
[
s1
,
s2
])
nodes
=
f
.
maker
.
fgraph
.
apply_nodes
assert
len
([
n
for
n
in
nodes
if
isinstance
(
n
.
op
,
MakeSlice
)])
==
1
def
test_none_Constant
():
def
test_none_Constant
():
...
@@ -25,6 +47,8 @@ def test_none_Constant():
...
@@ -25,6 +47,8 @@ def test_none_Constant():
# This trigger equals that returned the wrong answer in the past.
# This trigger equals that returned the wrong answer in the past.
import
pickle
import
pickle
import
pytensor
x
=
vector
(
"x"
)
x
=
vector
(
"x"
)
y
=
argmax
(
x
)
y
=
argmax
(
x
)
kwargs
=
{}
kwargs
=
{}
...
@@ -36,18 +60,11 @@ def test_none_Constant():
...
@@ -36,18 +60,11 @@ def test_none_Constant():
def
test_as_symbolic
():
def
test_as_symbolic
():
# Remove this when xtensor is not using symbolic slices
from
pytensor.tensor.type
import
iscalar
from
pytensor.tensor.type_other
import
SliceConstant
,
slicetype
res
=
as_symbolic
(
None
)
res
=
as_symbolic
(
None
)
assert
res
is
NoneConst
assert
res
is
NoneConst
res
=
as_symbolic
(
slice
(
iscalar
()))
assert
res
.
owner
.
op
==
make_slice
res
=
as_symbolic
(
slice
(
1
,
2
))
res
=
as_symbolic
(
slice
(
1
,
2
))
assert
isinstance
(
res
,
SliceConstant
)
assert
isinstance
(
res
,
SliceConstant
)
assert
res
.
type
==
slicetype
assert
res
.
data
==
slice
(
1
,
2
)
i
=
iscalar
()
res
=
as_symbolic
(
slice
(
i
))
assert
res
.
owner
is
not
None
tests/tensor/test_variable.py
浏览文件 @
cc6bed1a
...
@@ -35,7 +35,7 @@ from pytensor.tensor.type import (
...
@@ -35,7 +35,7 @@ from pytensor.tensor.type import (
scalar
,
scalar
,
tensor3
,
tensor3
,
)
)
from
pytensor.tensor.type_other
import
NoneConst
from
pytensor.tensor.type_other
import
MakeSlice
,
NoneConst
from
pytensor.tensor.variable
import
(
from
pytensor.tensor.variable
import
(
DenseTensorConstant
,
DenseTensorConstant
,
DenseTensorVariable
,
DenseTensorVariable
,
...
@@ -232,11 +232,11 @@ def test__getitem__AdvancedSubtensor():
...
@@ -232,11 +232,11 @@ def test__getitem__AdvancedSubtensor():
z
=
x
[:,
i
]
z
=
x
[:,
i
]
op_types
=
[
type
(
node
.
op
)
for
node
in
io_toposort
([
x
,
i
],
[
z
])]
op_types
=
[
type
(
node
.
op
)
for
node
in
io_toposort
([
x
,
i
],
[
z
])]
assert
op_types
==
[
AdvancedSubtensor
]
assert
op_types
==
[
MakeSlice
,
AdvancedSubtensor
]
z
=
x
[
...
,
i
,
None
]
z
=
x
[
...
,
i
,
None
]
op_types
=
[
type
(
node
.
op
)
for
node
in
io_toposort
([
x
,
i
],
[
z
])]
op_types
=
[
type
(
node
.
op
)
for
node
in
io_toposort
([
x
,
i
],
[
z
])]
assert
op_types
==
[
DimShuffl
e
,
AdvancedSubtensor
]
assert
op_types
==
[
MakeSlic
e
,
AdvancedSubtensor
]
z
=
x
[
i
,
None
]
z
=
x
[
i
,
None
]
op_types
=
[
type
(
node
.
op
)
for
node
in
io_toposort
([
x
,
i
],
[
z
])]
op_types
=
[
type
(
node
.
op
)
for
node
in
io_toposort
([
x
,
i
],
[
z
])]
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论