Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
43d8e303
提交
43d8e303
authored
6月 09, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
6月 10, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rewrite ExtractDiagonal of AllocDiagonal
上级
f7c4e163
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
132 行增加
和
1 行删除
+132
-1
subtensor.py
pytensor/tensor/rewriting/subtensor.py
+95
-0
test_subtensor.py
tests/tensor/rewriting/test_subtensor.py
+37
-1
没有找到文件。
pytensor/tensor/rewriting/subtensor.py
浏览文件 @
43d8e303
...
@@ -19,6 +19,7 @@ from pytensor.scalar import Add, ScalarConstant, ScalarType
...
@@ -19,6 +19,7 @@ from pytensor.scalar import Add, ScalarConstant, ScalarType
from
pytensor.scalar
import
constant
as
scalar_constant
from
pytensor.scalar
import
constant
as
scalar_constant
from
pytensor.tensor.basic
import
(
from
pytensor.tensor.basic
import
(
Alloc
,
Alloc
,
ExtractDiag
,
Join
,
Join
,
ScalarFromTensor
,
ScalarFromTensor
,
TensorFromScalar
,
TensorFromScalar
,
...
@@ -26,6 +27,7 @@ from pytensor.tensor.basic import (
...
@@ -26,6 +27,7 @@ from pytensor.tensor.basic import (
cast
,
cast
,
concatenate
,
concatenate
,
expand_dims
,
expand_dims
,
full
,
get_scalar_constant_value
,
get_scalar_constant_value
,
get_underlying_scalar_constant_value
,
get_underlying_scalar_constant_value
,
register_infer_shape
,
register_infer_shape
,
...
@@ -1793,3 +1795,96 @@ optdb["specialize"].register(
...
@@ -1793,3 +1795,96 @@ optdb["specialize"].register(
"numba"
,
"numba"
,
use_db_name_as_tag
=
False
,
# Not included if only "specialize" is requested
use_db_name_as_tag
=
False
,
# Not included if only "specialize" is requested
)
)
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter
([
ExtractDiag
])
def
extract_diag_of_diagonal_set_subtensor
(
fgraph
,
node
):
"""Undo extract diagonal from a set diagonal
This rewrites the following pattern:
y = write_diagonal*(x, x_diag, offset=k1)
z = extract_diag(y, offset=k2)
as:
z = diag_x, if k1 == k2
z = x if k1 != k2
* write_diagonal is not an atomic operation, but a sequence of Arange/SetSubtensor operations.
"""
def
is_cosntant_arange
(
var
)
->
bool
:
if
not
(
isinstance
(
var
,
TensorConstant
)
and
var
.
type
.
ndim
==
1
):
return
False
data
=
var
.
data
start
,
stop
=
data
[
0
],
data
[
-
1
]
+
1
return
data
.
size
==
(
stop
-
start
)
and
(
data
==
np
.
arange
(
start
,
stop
))
.
all
()
# type: ignore
[
diag_x
]
=
node
.
inputs
if
not
(
diag_x
.
owner
is
not
None
and
isinstance
(
diag_x
.
owner
.
op
,
AdvancedIncSubtensor
)
and
diag_x
.
owner
.
op
.
set_instead_of_inc
):
return
None
x
,
y
,
*
idxs
=
diag_x
.
owner
.
inputs
if
not
(
x
.
type
.
ndim
>=
2
and
None
not
in
x
.
type
.
shape
[
-
2
:]
and
x
.
type
.
shape
[
-
2
]
==
x
.
type
.
shape
[
-
1
]
):
# TODO: for now we only support rewrite with static square shape for x
return
None
op
=
node
.
op
if
op
.
axis2
>
len
(
idxs
):
return
None
# Check all non-axis indices are full slices
axis
=
{
op
.
axis1
,
op
.
axis2
}
if
not
all
(
is_full_slice
(
idx
)
for
i
,
idx
in
enumerate
(
idxs
)
if
i
not
in
axis
):
return
None
# Check axis indices are arange we would expect from setting on the diagonal
axis1_idx
=
idxs
[
op
.
axis1
]
axis2_idx
=
idxs
[
op
.
axis2
]
if
not
(
is_cosntant_arange
(
axis1_idx
)
and
is_cosntant_arange
(
axis2_idx
)):
return
None
dim_length
=
x
.
type
.
shape
[
-
1
]
offset
=
op
.
offset
start_stop1
=
(
axis1_idx
.
data
[
0
],
axis1_idx
.
data
[
-
1
]
+
1
)
start_stop2
=
(
axis2_idx
.
data
[
0
],
axis2_idx
.
data
[
-
1
]
+
1
)
orig_start1
,
orig_start2
=
start_stop1
[
0
],
start_stop2
[
0
]
if
offset
<
0
:
# The logic for checking if we are selecting or not a diagonal for negative offset is the same
# as the one with positive offset but swapped axis
start_stop1
,
start_stop2
=
start_stop2
,
start_stop1
offset
=
-
offset
start1
,
stop1
=
start_stop1
start2
,
stop2
=
start_stop2
if
(
start1
==
0
and
start2
==
offset
and
stop1
==
dim_length
-
offset
and
stop2
==
dim_length
):
# We are extracting the just written diagonal
if
y
.
type
.
ndim
==
0
or
y
.
type
.
shape
[
-
1
]
==
1
:
# We may need to broadcast y
y
=
full
((
*
x
.
shape
[:
-
2
],
dim_length
-
offset
),
y
,
dtype
=
x
.
type
.
dtype
)
return
[
y
]
elif
(
orig_start2
-
orig_start1
)
!=
op
.
offset
:
# Some other diagonal was written, ignore it
return
[
op
(
x
)]
else
:
# A portion, but no the whole diagonal was written, don't do anything
return
None
tests/tensor/rewriting/test_subtensor.py
浏览文件 @
43d8e303
import
random
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
...
@@ -9,7 +11,7 @@ from pytensor.compile.function import function
...
@@ -9,7 +11,7 @@ from pytensor.compile.function import function
from
pytensor.compile.mode
import
Mode
,
get_default_mode
,
get_mode
from
pytensor.compile.mode
import
Mode
,
get_default_mode
,
get_mode
from
pytensor.compile.ops
import
DeepCopyOp
from
pytensor.compile.ops
import
DeepCopyOp
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph
import
vectorize_graph
from
pytensor.graph
import
rewrite_graph
,
vectorize_graph
from
pytensor.graph.basic
import
Constant
,
Variable
,
ancestors
,
equal_computations
from
pytensor.graph.basic
import
Constant
,
Variable
,
ancestors
,
equal_computations
from
pytensor.graph.rewriting.basic
import
check_stack_trace
from
pytensor.graph.rewriting.basic
import
check_stack_trace
from
pytensor.raise_op
import
Assert
from
pytensor.raise_op
import
Assert
...
@@ -1956,3 +1958,37 @@ class TestUselessSlice:
...
@@ -1956,3 +1958,37 @@ class TestUselessSlice:
f
(
test_x
,
-
2
),
f
(
test_x
,
-
2
),
test_x
[
0
:
3
:
-
2
,
-
1
:
-
6
:
2
,
::],
test_x
[
0
:
3
:
-
2
,
-
1
:
-
6
:
2
,
::],
)
)
def
test_extract_diag_of_diagonal_set_subtensor
():
A
=
pt
.
full
((
2
,
6
,
6
),
np
.
nan
)
rows
=
pt
.
arange
(
A
.
shape
[
-
2
])
cols
=
pt
.
arange
(
A
.
shape
[
-
1
])
write_offsets
=
[
-
2
,
-
1
,
0
,
1
,
2
]
# Randomize order of write operations, to make sure rewrite is not sensitive to it
random
.
shuffle
(
write_offsets
)
for
offset
in
write_offsets
:
value
=
offset
+
0.1
*
offset
if
offset
==
0
:
A
=
A
[
...
,
rows
,
cols
]
.
set
(
value
)
elif
offset
>
0
:
A
=
A
[
...
,
rows
[:
-
offset
],
cols
[
offset
:]]
.
set
(
value
)
else
:
offset
=
-
offset
A
=
A
[
...
,
rows
[
offset
:],
cols
[:
-
offset
]]
.
set
(
value
)
# Add a partial diagonal along offset 3
A
=
A
[
...
,
rows
[
1
:
-
3
],
cols
[
4
:]]
.
set
(
np
.
pi
)
read_offsets
=
[
-
2
,
-
1
,
0
,
1
,
2
,
3
]
outs
=
[
A
.
diagonal
(
offset
=
offset
,
axis1
=-
2
,
axis2
=-
1
)
for
offset
in
read_offsets
]
rewritten_outs
=
rewrite_graph
(
outs
,
include
=
(
"ShapeOpt"
,
"canonicalize"
))
# Every output should just be an Alloc with value
expected_outs
=
[]
for
offset
in
read_offsets
[:
-
1
]:
value
=
np
.
asarray
(
offset
+
0.1
*
offset
,
dtype
=
A
.
type
.
dtype
)
expected_outs
.
append
(
pt
.
full
((
np
.
int64
(
2
),
np
.
int8
(
6
-
abs
(
offset
))),
value
))
# The partial diagonal shouldn't be rewritten
expected_outs
.
append
(
outs
[
-
1
])
assert
equal_computations
(
rewritten_outs
,
expected_outs
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论