Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
1ebd078a
提交
1ebd078a
authored
6月 23, 2025
作者:
ricardoV94
提交者:
Ricardo Vieira
7月 02, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Reuse Elemwise inplace machinery for Blockwise
上级
1d94ed68
显示空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
169 行增加
和
81 行删除
+169
-81
blockwise.py
pytensor/tensor/rewriting/blockwise.py
+28
-18
elemwise.py
pytensor/tensor/rewriting/elemwise.py
+75
-61
test_blockwise.py
tests/tensor/test_blockwise.py
+66
-2
没有找到文件。
pytensor/tensor/rewriting/blockwise.py
浏览文件 @
1ebd078a
...
...
@@ -2,7 +2,7 @@ from pytensor.compile.mode import optdb
from
pytensor.graph
import
Constant
,
node_rewriter
from
pytensor.graph.destroyhandler
import
inplace_candidates
from
pytensor.graph.replace
import
vectorize_node
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
in2out
,
out2in
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
out2in
from
pytensor.tensor.basic
import
Alloc
,
ARange
,
alloc
,
shape_padleft
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.math
import
Dot
...
...
@@ -11,6 +11,7 @@ from pytensor.tensor.rewriting.basic import (
register_specialize
,
register_stabilize
,
)
from
pytensor.tensor.rewriting.elemwise
import
InplaceGraphOptimizer
from
pytensor.tensor.shape
import
Reshape
from
pytensor.tensor.subtensor
import
(
AdvancedIncSubtensor
,
...
...
@@ -260,19 +261,15 @@ def local_blockwise_of_subtensor(fgraph, node):
return
[
x
[(
*
none_slices
,
*
core_idxs
)]]
@node_rewriter
(
tracks
=
[
Blockwise
],
inplace
=
True
)
def
blockwise_inplace
(
fgraph
,
node
):
blockwise_op
=
node
.
op
if
blockwise_op
.
destroy_map
:
# Op already has inplace
return
class
InplaceBlockwiseOptimizer
(
InplaceGraphOptimizer
):
op
=
Blockwise
# Find out valid inputs for inplacing
def
filter_candidate_pairs
(
self
,
fgraph
,
node
,
protected_inputs
):
blockwise_op
=
node
.
op
batch_ndim
=
blockwise_op
.
batch_ndim
(
node
)
out_batch_bcast
=
node
.
outputs
[
0
]
.
type
.
broadcastable
[:
batch_ndim
]
inputs
=
node
.
inputs
candidate_inputs
=
set
(
inplace_candidates
(
fgraph
,
...
...
@@ -281,21 +278,36 @@ def blockwise_inplace(fgraph, node):
for
inp
in
inputs
if
inp
.
type
.
broadcastable
[:
batch_ndim
]
==
out_batch_bcast
],
protected_inputs
=
protected_inputs
,
)
)
allowed_inplace_inputs
=
[
i
for
i
,
inp
in
enumerate
(
inputs
)
if
inp
in
candidate_inputs
]
destroy_map
=
blockwise_op
.
core_op
.
inplace_on_inputs
(
allowed_inplace_inputs
=
allowed_inplace_inputs
)
.
destroy_map
if
not
allowed_inplace_inputs
:
return
None
if
not
destroy_map
:
return
[]
outputs
=
node
.
outputs
return
[
((
out_idx
,
outputs
[
out_idx
]),
(
inp_idx
,
inputs
[
inp_idx
]))
for
out_idx
,
inp_idxs
in
destroy_map
.
items
()
for
inp_idx
in
inp_idxs
]
def
create_inplace_node
(
self
,
node
,
inplace_pattern
):
blockwise_op
=
node
.
op
allowed_inplace_inputs
=
tuple
(
v
[
0
]
for
v
in
inplace_pattern
.
values
())
inplace_core_op
=
blockwise_op
.
core_op
.
inplace_on_inputs
(
allowed_inplace_inputs
=
allowed_inplace_inputs
)
if
not
inplace_core_op
.
destroy_map
:
return
None
return
node
# Check Op is not trying to inplace on non-candidate inputs
for
destroyed_inputs
in
inplace_core_op
.
destroy_map
.
values
():
...
...
@@ -306,7 +318,7 @@ def blockwise_inplace(fgraph, node):
)
# Recreate core_op with inplace
inplace_blockwise_op
=
Blockwise
(
inplace_blockwise_op
=
type
(
blockwise_op
)
(
core_op
=
inplace_core_op
,
signature
=
blockwise_op
.
signature
,
name
=
blockwise_op
.
name
,
...
...
@@ -314,14 +326,12 @@ def blockwise_inplace(fgraph, node):
destroy_map
=
inplace_core_op
.
destroy_map
,
)
out
=
inplace_blockwise_op
.
make_node
(
*
node
.
inputs
)
.
outputs
copy_stack_trace
(
node
.
outputs
,
out
)
return
out
return
inplace_blockwise_op
.
make_node
(
*
node
.
inputs
)
optdb
.
register
(
"blockwise_inplace"
,
in2out
(
blockwise_inplace
),
InplaceBlockwiseOptimizer
(
),
"fast_run"
,
"inplace"
,
position
=
50.1
,
...
...
pytensor/tensor/rewriting/elemwise.py
浏览文件 @
1ebd078a
import
abc
import
itertools
import
operator
import
sys
from
collections
import
defaultdict
,
deque
from
collections.abc
import
Generator
from
collections.abc
import
Generator
,
Sequence
from
functools
import
cache
,
reduce
from
typing
import
TypeVar
from
warnings
import
warn
...
...
@@ -12,7 +13,7 @@ from pytensor import clone_replace, compile
from
pytensor.compile.function.types
import
Supervisor
from
pytensor.compile.mode
import
get_target_language
from
pytensor.configdefaults
import
config
from
pytensor.graph
import
FunctionGraph
from
pytensor.graph
import
FunctionGraph
,
Op
from
pytensor.graph.basic
import
Apply
,
Variable
,
ancestors
from
pytensor.graph.destroyhandler
import
DestroyHandler
,
inplace_candidates
from
pytensor.graph.features
import
ReplaceValidate
...
...
@@ -47,22 +48,31 @@ from pytensor.tensor.shape import shape_padleft
from
pytensor.tensor.variable
import
TensorConstant
,
TensorVariable
class
InplaceElemwiseOptimizer
(
GraphRewriter
):
r"""
This is parameterized so that it works for `Elemwise` `Op`\s.
"""
class
InplaceGraphOptimizer
(
GraphRewriter
):
op
:
type
[
Op
]
def
add_requirements
(
self
,
fgraph
):
fgraph
.
attach_feature
(
DestroyHandler
())
@abc.abstractmethod
def
filter_candidate_pairs
(
self
,
fgraph
:
FunctionGraph
,
node
:
Apply
,
protected_inputs
:
Sequence
[
Variable
]
)
->
Sequence
[
tuple
[
tuple
[
int
,
Variable
],
tuple
[
int
,
Variable
]]]:
pass
@abc.abstractmethod
def
create_inplace_node
(
self
,
node
:
Apply
,
inplace_pattern
:
dict
[
int
,
Sequence
[
int
]]
)
->
Apply
:
pass
def
apply
(
self
,
fgraph
):
r"""
Attempts to replace all `Elemwise`\s by versions of them that operate
inplace. It operates greedily: for each `Elemwise` that is encountered,
for each output, it tries each input to see if it can operate inplace
on that input. If so, it makes the change and goes to the next output
or `Elemwise`.
Attempts to replace all `Op`\s by versions of them that operate
inplace. It operates greedily: for each `Op` that is encountered,
it tries to inplace all the valid inputs at once (if the Op supports it),
if that fails, it tries to inplace one input at a time.
Examples
--------
...
...
@@ -93,36 +103,13 @@ class InplaceElemwiseOptimizer(GraphRewriter):
# tackle them in a more general way. The whole try/except approach is probably suboptimal.
# We can consider restricting inputs with static shapes that are large enough.
def
create_inplace_node
(
node
,
inplace_pattern
):
op
=
node
.
op
scalar_op
=
op
.
scalar_op
inplace_pattern
=
{
i
:
o
for
i
,
[
o
]
in
inplace_pattern
.
items
()}
if
hasattr
(
scalar_op
,
"make_new_inplace"
):
new_scalar_op
=
scalar_op
.
make_new_inplace
(
ps
.
transfer_type
(
*
[
inplace_pattern
.
get
(
i
,
o
.
dtype
)
for
i
,
o
in
enumerate
(
node
.
outputs
)
]
)
)
else
:
new_scalar_op
=
type
(
scalar_op
)(
ps
.
transfer_type
(
*
[
inplace_pattern
.
get
(
i
,
None
)
for
i
in
range
(
len
(
node
.
outputs
))
]
)
)
return
type
(
op
)(
new_scalar_op
,
inplace_pattern
)
.
make_node
(
*
node
.
inputs
)
if
config
.
tensor__insert_inplace_optimizer_validate_nb
!=
-
1
:
warn
(
"tensor__insert_inplace_optimizer_validate_nb config is deprecated. Setting it will fail in a future release."
,
FutureWarning
,
)
reason
=
f
"{self.op}_inplace_optimizer"
prof
=
{
"opt"
:
self
,
"node_before"
:
len
(
fgraph
.
apply_nodes
),
...
...
@@ -140,6 +127,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
protected_inputs
.
update
(
fgraph
.
outputs
)
root_destroyer
=
fgraph
.
destroy_handler
.
root_destroyer
self_op
=
self
.
op
update_mapping
=
fgraph
.
update_mapping
or
{}
op_updates
:
dict
[
TensorVariable
,
TensorVariable
]
=
{
out
:
fgraph
.
inputs
[
update_mapping
[
out_idx
]]
...
...
@@ -147,36 +135,22 @@ class InplaceElemwiseOptimizer(GraphRewriter):
if
(
out_idx
in
update_mapping
and
out
.
owner
and
isinstance
(
out
.
owner
.
op
,
Elemwise
)
and
isinstance
(
out
.
owner
.
op
,
self_op
)
)
}
set_op_updates
=
set
(
op_updates
.
keys
())
for
node
in
fgraph
.
toposort
():
if
not
isinstance
(
node
.
op
,
Elemwise
)
or
node
.
op
.
destroy_map
:
if
not
isinstance
(
node
.
op
,
self_op
)
or
node
.
op
.
destroy_map
:
continue
# If big graph and the outputs are scalar, do not make it inplace.
if
large_graph
and
all
(
node
.
outputs
[
0
]
.
type
.
broadcastable
):
continue
candidate_inputs
=
[
(
node
.
inputs
.
index
(
inp
),
inp
)
for
inp
in
inplace_candidates
(
fgraph
,
node
.
inputs
,
protected_inputs
=
protected_inputs
,
candidate_pairs
=
self
.
filter_candidate_pairs
(
fgraph
,
node
,
protected_inputs
)
]
if
not
candidate_inputs
:
return
[]
candidate_pairs
=
[
((
o
,
out
),
(
i
,
inp
))
for
o
,
out
in
enumerate
(
node
.
outputs
)
for
i
,
inp
in
candidate_inputs
if
inp
.
type
==
out
.
type
]
if
not
candidate_pairs
:
continue
...
...
@@ -216,13 +190,11 @@ class InplaceElemwiseOptimizer(GraphRewriter):
inplace_pattern
[
o
]
=
[
i
]
tried_inputs
.
add
(
i
)
inplace_node
=
create_inplace_node
(
node
,
inplace_pattern
)
inplace_node
=
self
.
create_inplace_node
(
node
,
inplace_pattern
)
if
inplace_node
.
op
.
destroy_map
==
inplace_pattern
:
replacements
=
tuple
(
zip
(
node
.
outputs
,
inplace_node
.
outputs
))
try
:
fgraph
.
replace_all_validate
(
replacements
,
reason
=
"inplace_elemwise_optimizer"
)
fgraph
.
replace_all_validate
(
replacements
,
reason
=
reason
)
except
InconsistencyError
:
prof
[
"nb_eager_inconsistent"
]
+=
1
else
:
...
...
@@ -238,7 +210,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
inplace_pattern
[
o
]
=
[
i
]
tried_inputs
.
add
(
i
)
inplace_node
=
create_inplace_node
(
node
,
inplace_pattern
)
inplace_node
=
self
.
create_inplace_node
(
node
,
inplace_pattern
)
if
inplace_node
.
op
.
destroy_map
!=
inplace_pattern
:
# This Op can't respect this partial inplace pattern,
# We assume it can't support any other cases
...
...
@@ -246,9 +218,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
else
:
replacements
=
tuple
(
zip
(
node
.
outputs
,
inplace_node
.
outputs
))
try
:
fgraph
.
replace_all_validate
(
replacements
,
reason
=
"inplace_elemwise_optimizer"
)
fgraph
.
replace_all_validate
(
replacements
,
reason
=
reason
)
node
=
inplace_node
replaced
=
True
except
InconsistencyError
:
...
...
@@ -278,6 +248,50 @@ class InplaceElemwiseOptimizer(GraphRewriter):
)
class InplaceElemwiseOptimizer(InplaceGraphOptimizer):
    r"""`InplaceGraphOptimizer` specialisation whose target `op` is `Elemwise`."""

    op = Elemwise

    def filter_candidate_pairs(self, fgraph, node, protected_inputs):
        """List the (output, input) pairs of `node` that may be tried for inplacing.

        Each entry has the form ``((out_idx, out_var), (inp_idx, inp_var))``.
        Only inputs yielded by `inplace_candidates` are considered, and a pair
        is kept only when the input type compares equal to the output type.
        """
        indexed_candidates = []
        for candidate in inplace_candidates(
            fgraph,
            node.inputs,
            protected_inputs=protected_inputs,
        ):
            # `list.index` reports the first position at which the variable occurs.
            indexed_candidates.append((node.inputs.index(candidate), candidate))

        if not indexed_candidates:
            return []

        pairs = []
        for out_idx, out_var in enumerate(node.outputs):
            for inp_idx, inp_var in indexed_candidates:
                if inp_var.type == out_var.type:
                    pairs.append(((out_idx, out_var), (inp_idx, inp_var)))
        return pairs

    def create_inplace_node(self, node, inplace_pattern):
        """Build a node of the same `Op` type that follows `inplace_pattern`.

        `inplace_pattern` maps an output index to a one-element sequence holding
        the index of the input to overwrite; any other length fails to unpack.
        """
        elemwise_op = node.op
        scalar_op = elemwise_op.scalar_op

        # Flatten {out_idx: [inp_idx]} into {out_idx: inp_idx}.
        flat_pattern = {}
        for out_idx, (inp_idx,) in inplace_pattern.items():
            flat_pattern[out_idx] = inp_idx

        if hasattr(scalar_op, "make_new_inplace"):
            # Outputs absent from the pattern fall back to their own dtype.
            transfer_args = [
                flat_pattern.get(out_idx, out_var.dtype)
                for out_idx, out_var in enumerate(node.outputs)
            ]
            new_scalar_op = scalar_op.make_new_inplace(
                ps.transfer_type(*transfer_args)
            )
        else:
            # Outputs absent from the pattern are passed as None.
            transfer_args = [
                flat_pattern.get(out_idx) for out_idx in range(len(node.outputs))
            ]
            new_scalar_op = type(scalar_op)(ps.transfer_type(*transfer_args))

        inplace_op = type(elemwise_op)(new_scalar_op, flat_pattern)
        return inplace_op.make_node(*node.inputs)
compile
.
optdb
.
register
(
"inplace_elemwise"
,
InplaceElemwiseOptimizer
(),
...
...
tests/tensor/test_blockwise.py
浏览文件 @
1ebd078a
...
...
@@ -8,11 +8,21 @@ import scipy.linalg
import
pytensor
from
pytensor
import
In
,
config
,
function
,
scan
from
pytensor.compile
import
get_default_mode
,
get_mode
from
pytensor.compile.function.types
import
add_supervisor_to_fgraph
from
pytensor.gradient
import
grad
from
pytensor.graph
import
Apply
,
Op
from
pytensor.graph
import
Apply
,
FunctionGraph
,
Op
,
rewrite_graph
from
pytensor.graph.replace
import
vectorize_graph
,
vectorize_node
from
pytensor.raise_op
import
assert_op
from
pytensor.tensor
import
diagonal
,
dmatrix
,
log
,
ones_like
,
scalar
,
tensor
,
vector
from
pytensor.tensor
import
(
diagonal
,
dmatrix
,
log
,
matrices
,
ones_like
,
scalar
,
tensor
,
vector
,
)
from
pytensor.tensor.blockwise
import
Blockwise
,
vectorize_node_fallback
from
pytensor.tensor.nlinalg
import
MatrixInverse
from
pytensor.tensor.rewriting.blas
import
specialize_matmul_to_batched_dot
...
...
@@ -698,3 +708,57 @@ def test_scan_gradient_core_type():
grad_sit_sot0
.
eval
({
vec_seq
:
np
.
ones
((
4
,
n_steps
,
1
))}),
np
.
ones
((
4
,
n_steps
,
1
)),
)
def test_partial_inplace():
    """Check which inputs of a three-input Blockwise end up in its destroy_map."""

    class CoreOp(Op):
        # Core op that hands back its three inputs, copying those not listed in `inplace`.
        __props__ = ("inplace",)

        def __init__(self, inplace):
            self.inplace = tuple(inplace)
            self.destroy_map = {idx: [idx] for idx in inplace}

        def inplace_on_inputs(self, allowed_inplace_inputs):
            # Fresh op of the same class, inplace on exactly the allowed inputs.
            return type(self)(inplace=allowed_inplace_inputs)

        def make_node(self, x, y, z):
            op_inputs = [x, y, z]
            op_outputs = [x.type(), y.type(), z.type()]
            return Apply(self, op_inputs, op_outputs)

        def perform(self, node, inputs, outputs):
            first, second, third = inputs
            # Copy every value first, then write the results into the output storage.
            results = [
                value if idx in self.inplace else value.copy()
                for idx, value in enumerate((first, second, third))
            ]
            for idx, result in enumerate(results):
                outputs[idx][0] = result

    core_op = CoreOp(inplace=())
    blockwise_op = Blockwise(core_op, signature="(),(),()->(),(),()")
    x, y, z = matrices("xyz")

    def rewritten_destroy_map(op_inputs, mutable_flags):
        # Build the graph, declare input mutability, run the "inplace" rewrites
        # and report the destroy_map of the op that owns the first output.
        fgraph = FunctionGraph([x, y, z], blockwise_op(*op_inputs))
        supervised_inputs = [
            In(inp, mutable=flag) for inp, flag in zip(fgraph.inputs, mutable_flags)
        ]
        add_supervisor_to_fgraph(fgraph, supervised_inputs)
        rewrite_graph(fgraph, include=("inplace",))
        return fgraph.outputs[0].owner.op.destroy_map

    # Every graph input is mutable: all three inputs are inplaced.
    all_mutable = rewritten_destroy_map((x.T, y.T, z.T), (True, True, True))
    assert all_mutable == {0: [0], 1: [1], 2: [2]}

    # y is not mutable (protected): only x and z are inplaced.
    y_protected = rewritten_destroy_map((x.T, y.T, z.T), (True, False, True))
    assert y_protected == {0: [0], 2: [2]}

    # x.T is passed as both the first and third input: only y is inplaced.
    x_shared = rewritten_destroy_map((x.T, y.T, x.T), (True, True, True))
    assert x_shared == {1: [1]}
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论