Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
69efc68b
提交
69efc68b
authored
2月 25, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
2月 27, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Handle inplace rewrites correctly in dispatch of OpFromGraph and Scan
JAX needs no special handling because it excludes inplace rewrites.
上级
ad1af2ea
显示空白字符变更
内嵌
并排
正在显示
8 个修改的文件
包含
131 行增加
和
45 行删除
+131
-45
types.py
pytensor/compile/function/types.py
+56
-18
mode.py
pytensor/compile/mode.py
+5
-1
scan.py
pytensor/link/jax/dispatch/scan.py
+4
-2
basic.py
pytensor/link/numba/dispatch/basic.py
+9
-2
scan.py
pytensor/link/numba/dispatch/scan.py
+12
-3
basic.py
pytensor/link/pytorch/dispatch/basic.py
+9
-0
op.py
pytensor/scan/op.py
+3
-19
test_basic.py
tests/link/numba/test_basic.py
+33
-0
没有找到文件。
pytensor/compile/function/types.py
浏览文件 @
69efc68b
...
@@ -5,6 +5,7 @@ import copyreg
...
@@ -5,6 +5,7 @@ import copyreg
import
logging
import
logging
import
time
import
time
import
warnings
import
warnings
from
collections.abc
import
Sequence
from
itertools
import
chain
from
itertools
import
chain
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
...
@@ -168,6 +169,59 @@ class Supervisor(Feature):
...
@@ -168,6 +169,59 @@ class Supervisor(Feature):
raise
InconsistencyError
(
f
"Trying to destroy a protected variable: {r}"
)
raise
InconsistencyError
(
f
"Trying to destroy a protected variable: {r}"
)
def
add_supervisor_to_fgraph
(
fgraph
:
FunctionGraph
,
input_specs
:
Sequence
[
SymbolicInput
],
accept_inplace
:
bool
=
False
,
)
->
None
:
"""Setup Supervisor Feature in a FunctionGraph, so that inplace rewrites can be used.
Parameters
----------
fgraph: FunctionGraph
The FunctionGraph to setup the Supervisor Feature in.
input_specs: Sequence of SymbolicInput
The input specifications for the FunctionGraph.
Inputs with the attribute `mutable=False` and which are not already destroyed by an inplace operation
(if `accept_inplace` is True) will be protected from inplace operations.
Otherwise, they will be allowed to be destroyed.
accept_inplace: bool
Whether to allow inplace operations to already be present in the graph.
Raises
------
TypeError
If inplace operations are not allowed and the graph already contains inplace operations.
"""
has_destroy_handler
=
hasattr
(
fgraph
,
"destroyers"
)
if
not
(
has_destroy_handler
and
accept_inplace
):
# Check if fgraph already contains destructive operations,
# in which case we need to add a DestroyHandler or raise an error
for
node
in
fgraph
.
apply_nodes
:
if
node
.
op
.
destroy_map
:
if
not
accept_inplace
:
raise
TypeError
(
f
"Graph must not contain inplace operations: {node}"
)
else
:
has_destroy_handler
=
True
fgraph
.
attach_feature
(
DestroyHandler
())
break
# Protect all immutable inputs from inplace operations.
fgraph
.
attach_feature
(
Supervisor
(
input
for
spec
,
input
in
zip
(
input_specs
,
fgraph
.
inputs
,
strict
=
True
)
if
not
(
spec
.
mutable
or
has_destroy_handler
and
fgraph
.
has_destroyers
([
input
])
)
)
)
def
std_fgraph
(
def
std_fgraph
(
input_specs
:
list
[
SymbolicInput
],
input_specs
:
list
[
SymbolicInput
],
output_specs
:
list
[
SymbolicOutput
],
output_specs
:
list
[
SymbolicOutput
],
...
@@ -229,24 +283,8 @@ def std_fgraph(
...
@@ -229,24 +283,8 @@ def std_fgraph(
found_updates
.
extend
(
map
(
SymbolicOutput
,
updates
))
found_updates
.
extend
(
map
(
SymbolicOutput
,
updates
))
for
node
in
fgraph
.
apply_nodes
:
add_supervisor_to_fgraph
(
if
node
.
op
.
destroy_map
:
fgraph
=
fgraph
,
input_specs
=
input_specs
,
accept_inplace
=
accept_inplace
if
not
accept_inplace
:
raise
TypeError
(
f
"Graph must not contain inplace operations: {node}"
)
else
:
fgraph
.
attach_feature
(
DestroyHandler
())
break
# We need to protect all immutable inputs from inplace operations.
fgraph
.
attach_feature
(
Supervisor
(
input
for
spec
,
input
in
zip
(
input_specs
,
fgraph
.
inputs
,
strict
=
True
)
if
not
(
spec
.
mutable
or
(
hasattr
(
fgraph
,
"destroyers"
)
and
fgraph
.
has_destroyers
([
input
]))
)
)
)
)
# If named nodes are replaced, keep the name
# If named nodes are replaced, keep the name
...
...
pytensor/compile/mode.py
浏览文件 @
69efc68b
...
@@ -138,7 +138,11 @@ class AddDestroyHandler(GraphRewriter):
...
@@ -138,7 +138,11 @@ class AddDestroyHandler(GraphRewriter):
break
break
if
not
supervisor_added
:
if
not
supervisor_added
:
warnings
.
warn
(
warnings
.
warn
(
f
"A Supervisor feature is missing from {fgraph}."
,
(
f
"A Supervisor feature is missing from {fgraph}.
\n
"
"This is needed for inplace rewrites. Either exclude inplace rewrites or add a Supervisor feature.
\n
"
"A Supervisor feature can be added via `pytensor.compile.function.types.add_supervisor_to_fgraph`."
),
stacklevel
=
3
,
stacklevel
=
3
,
)
)
...
...
pytensor/link/jax/dispatch/scan.py
浏览文件 @
69efc68b
import
jax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
pytensor.compile.mode
import
JAX
from
pytensor.compile.mode
import
JAX
,
get_mode
from
pytensor.link.jax.dispatch.basic
import
jax_funcify
from
pytensor.link.jax.dispatch.basic
import
jax_funcify
from
pytensor.scan.op
import
Scan
from
pytensor.scan.op
import
Scan
...
@@ -19,7 +19,9 @@ def jax_funcify_Scan(op: Scan, **kwargs):
...
@@ -19,7 +19,9 @@ def jax_funcify_Scan(op: Scan, **kwargs):
)
)
# Optimize inner graph (exclude any defalut rewrites that are incompatible with JAX mode)
# Optimize inner graph (exclude any defalut rewrites that are incompatible with JAX mode)
rewriter
=
op
.
mode_instance
.
excluding
(
*
JAX
.
_optimizer
.
exclude
)
.
optimizer
rewriter
=
(
get_mode
(
op
.
mode
)
.
including
(
"jax"
)
.
excluding
(
*
JAX
.
_optimizer
.
exclude
)
.
optimizer
)
rewriter
(
op
.
fgraph
)
rewriter
(
op
.
fgraph
)
scan_inner_func
=
jax_funcify
(
op
.
fgraph
,
**
kwargs
)
scan_inner_func
=
jax_funcify
(
op
.
fgraph
,
**
kwargs
)
...
...
pytensor/link/numba/dispatch/basic.py
浏览文件 @
69efc68b
...
@@ -16,9 +16,10 @@ from numba.core.errors import NumbaWarning, TypingError
...
@@ -16,9 +16,10 @@ from numba.core.errors import NumbaWarning, TypingError
from
numba.cpython.unsafe.tuple
import
tuple_setitem
# noqa: F401
from
numba.cpython.unsafe.tuple
import
tuple_setitem
# noqa: F401
from
numba.extending
import
box
,
overload
from
numba.extending
import
box
,
overload
from
pytensor
import
config
from
pytensor
import
In
,
config
from
pytensor.compile
import
NUMBA
from
pytensor.compile
import
NUMBA
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.compile.function.types
import
add_supervisor_to_fgraph
from
pytensor.compile.ops
import
DeepCopyOp
from
pytensor.compile.ops
import
DeepCopyOp
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
...
@@ -430,7 +431,13 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
...
@@ -430,7 +431,13 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
# TODO: Not sure this is the right place to do this, should we have a rewrite that
# TODO: Not sure this is the right place to do this, should we have a rewrite that
# explicitly triggers the optimization of the inner graphs of OpFromGraph?
# explicitly triggers the optimization of the inner graphs of OpFromGraph?
# The C-code defers it to the make_thunk phase
# The C-code defers it to the make_thunk phase
NUMBA
.
optimizer
(
op
.
fgraph
)
fgraph
=
op
.
fgraph
add_supervisor_to_fgraph
(
fgraph
=
fgraph
,
input_specs
=
[
In
(
x
,
borrow
=
True
,
mutable
=
False
)
for
x
in
fgraph
.
inputs
],
accept_inplace
=
True
,
)
NUMBA
.
optimizer
(
fgraph
)
fgraph_fn
=
numba_njit
(
numba_funcify
(
op
.
fgraph
,
**
kwargs
))
fgraph_fn
=
numba_njit
(
numba_funcify
(
op
.
fgraph
,
**
kwargs
))
if
len
(
op
.
fgraph
.
outputs
)
==
1
:
if
len
(
op
.
fgraph
.
outputs
)
==
1
:
...
...
pytensor/link/numba/dispatch/scan.py
浏览文件 @
69efc68b
...
@@ -4,7 +4,9 @@ import numpy as np
...
@@ -4,7 +4,9 @@ import numpy as np
from
numba
import
types
from
numba
import
types
from
numba.extending
import
overload
from
numba.extending
import
overload
from
pytensor.compile.mode
import
NUMBA
from
pytensor
import
In
from
pytensor.compile.function.types
import
add_supervisor_to_fgraph
from
pytensor.compile.mode
import
NUMBA
,
get_mode
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
(
from
pytensor.link.numba.dispatch.basic
import
(
create_arg_string
,
create_arg_string
,
...
@@ -59,11 +61,18 @@ def numba_funcify_Scan(op, node, **kwargs):
...
@@ -59,11 +61,18 @@ def numba_funcify_Scan(op, node, **kwargs):
# explicitly triggers the optimization of the inner graphs of Scan?
# explicitly triggers the optimization of the inner graphs of Scan?
# The C-code defers it to the make_thunk phase
# The C-code defers it to the make_thunk phase
rewriter
=
(
rewriter
=
(
op
.
mode_instance
.
including
(
"numba"
)
get_mode
(
op
.
mode
)
.
including
(
"numba"
)
.
excluding
(
*
NUMBA
.
_optimizer
.
exclude
)
.
excluding
(
*
NUMBA
.
_optimizer
.
exclude
)
.
optimizer
.
optimizer
)
)
rewriter
(
op
.
fgraph
)
fgraph
=
op
.
fgraph
add_supervisor_to_fgraph
(
fgraph
=
fgraph
,
input_specs
=
[
In
(
x
,
borrow
=
True
,
mutable
=
False
)
for
x
in
fgraph
.
inputs
],
accept_inplace
=
True
,
)
rewriter
(
fgraph
)
scan_inner_func
=
numba_basic
.
numba_njit
(
numba_funcify
(
op
.
fgraph
))
scan_inner_func
=
numba_basic
.
numba_njit
(
numba_funcify
(
op
.
fgraph
))
...
...
pytensor/link/pytorch/dispatch/basic.py
浏览文件 @
69efc68b
...
@@ -5,8 +5,10 @@ import numpy as np
...
@@ -5,8 +5,10 @@ import numpy as np
import
torch
import
torch
import
torch.compiler
import
torch.compiler
from
pytensor
import
In
from
pytensor.compile
import
PYTORCH
from
pytensor.compile
import
PYTORCH
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.compile.function.types
import
add_supervisor_to_fgraph
from
pytensor.compile.ops
import
DeepCopyOp
from
pytensor.compile.ops
import
DeepCopyOp
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
...
@@ -185,6 +187,13 @@ def pytorch_funcify_OpFromGraph(op, node, **kwargs):
...
@@ -185,6 +187,13 @@ def pytorch_funcify_OpFromGraph(op, node, **kwargs):
kwargs
.
pop
(
"storage_map"
,
None
)
kwargs
.
pop
(
"storage_map"
,
None
)
# Apply inner rewrites
# Apply inner rewrites
PYTORCH
.
optimizer
(
op
.
fgraph
)
PYTORCH
.
optimizer
(
op
.
fgraph
)
fgraph
=
op
.
fgraph
add_supervisor_to_fgraph
(
fgraph
=
fgraph
,
input_specs
=
[
In
(
x
,
borrow
=
True
,
mutable
=
False
)
for
x
in
fgraph
.
inputs
],
accept_inplace
=
True
,
)
PYTORCH
.
optimizer
(
fgraph
)
fgraph_fn
=
pytorch_funcify
(
op
.
fgraph
,
**
kwargs
,
squeeze_output
=
True
)
fgraph_fn
=
pytorch_funcify
(
op
.
fgraph
,
**
kwargs
,
squeeze_output
=
True
)
return
fgraph_fn
return
fgraph_fn
...
...
pytensor/scan/op.py
浏览文件 @
69efc68b
...
@@ -57,6 +57,7 @@ import pytensor.link.utils as link_utils
...
@@ -57,6 +57,7 @@ import pytensor.link.utils as link_utils
from
pytensor
import
tensor
as
pt
from
pytensor
import
tensor
as
pt
from
pytensor.compile.builders
import
construct_nominal_fgraph
,
infer_shape
from
pytensor.compile.builders
import
construct_nominal_fgraph
,
infer_shape
from
pytensor.compile.function.pfunc
import
pfunc
from
pytensor.compile.function.pfunc
import
pfunc
from
pytensor.compile.function.types
import
add_supervisor_to_fgraph
from
pytensor.compile.io
import
In
,
Out
from
pytensor.compile.io
import
In
,
Out
from
pytensor.compile.mode
import
Mode
,
get_mode
from
pytensor.compile.mode
import
Mode
,
get_mode
from
pytensor.compile.profiling
import
register_profiler_printer
from
pytensor.compile.profiling
import
register_profiler_printer
...
@@ -834,8 +835,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -834,8 +835,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self
.
n_outer_inputs
=
info
.
n_outer_inputs
self
.
n_outer_inputs
=
info
.
n_outer_inputs
self
.
n_outer_outputs
=
info
.
n_outer_outputs
self
.
n_outer_outputs
=
info
.
n_outer_outputs
_
=
self
.
prepare_fgraph
(
self
.
fgraph
)
if
any
(
node
.
op
.
destroy_map
for
node
in
self
.
fgraph
.
apply_nodes
):
if
any
(
node
.
op
.
destroy_map
for
node
in
self
.
fgraph
.
apply_nodes
):
raise
InconsistencyError
(
raise
InconsistencyError
(
"Inner-graphs must not contain in-place operations."
"Inner-graphs must not contain in-place operations."
...
@@ -1394,23 +1393,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -1394,23 +1393,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
fgraph
.
update_mapping
=
update_mapping
fgraph
.
update_mapping
=
update_mapping
from
pytensor.compile.function.types
import
Supervisor
add_supervisor_to_fgraph
(
from
pytensor.graph.destroyhandler
import
DestroyHandler
fgraph
=
fgraph
,
input_specs
=
wrapped_inputs
,
accept_inplace
=
True
for
node
in
fgraph
.
apply_nodes
:
if
node
.
op
.
destroy_map
:
fgraph
.
attach_feature
(
DestroyHandler
())
break
fgraph
.
attach_feature
(
Supervisor
(
inp
for
spec
,
inp
in
zip
(
wrapped_inputs
,
fgraph
.
inputs
,
strict
=
True
)
if
not
(
getattr
(
spec
,
"mutable"
,
None
)
or
(
hasattr
(
fgraph
,
"destroyers"
)
and
fgraph
.
has_destroyers
([
inp
]))
)
)
)
)
return
wrapped_inputs
,
wrapped_outputs
return
wrapped_inputs
,
wrapped_outputs
...
...
tests/link/numba/test_basic.py
浏览文件 @
69efc68b
...
@@ -835,6 +835,39 @@ def test_OpFromGraph():
...
@@ -835,6 +835,39 @@ def test_OpFromGraph():
compare_numba_and_py
([
x
,
y
,
z
],
[
out
],
[
xv
,
yv
,
zv
])
compare_numba_and_py
([
x
,
y
,
z
],
[
out
],
[
xv
,
yv
,
zv
])
@pytest.mark.filterwarnings
(
"error"
)
def
test_ofg_inner_inplace
():
x
=
pt
.
vector
(
"x"
)
set0
=
x
[
0
]
.
set
(
1
)
# SetSubtensor should not inplace on x
exp_x
=
pt
.
exp
(
x
)
set1
=
exp_x
[
0
]
.
set
(
1
)
# SetSubtensor should inplace on exp_x
ofg0
=
OpFromGraph
([
x
],
[
set0
])
ofg1
=
OpFromGraph
([
x
],
[
set1
])
y
,
z
=
pt
.
vectors
(
"y"
,
"z"
)
fn
=
function
([
y
,
z
],
[
ofg0
(
y
),
ofg1
(
z
)],
mode
=
"NUMBA"
)
fn_ofg0
=
fn
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
.
op
assert
isinstance
(
fn_ofg0
,
OpFromGraph
)
fn_set0
=
fn_ofg0
.
fgraph
.
outputs
[
0
]
assert
fn_set0
.
owner
.
op
.
destroy_map
==
{}
fn_ofg1
=
fn
.
maker
.
fgraph
.
outputs
[
1
]
.
owner
.
op
assert
isinstance
(
fn_ofg1
,
OpFromGraph
)
fn_set1
=
fn_ofg1
.
fgraph
.
outputs
[
0
]
assert
fn_set1
.
owner
.
op
.
destroy_map
==
{
0
:
[
0
]}
x_test
=
np
.
array
([
0
,
1
,
1
],
dtype
=
config
.
floatX
)
y_test
=
np
.
array
([
0
,
1
,
1
],
dtype
=
config
.
floatX
)
res0
,
res1
=
fn
(
x_test
,
y_test
)
# Check inputs were not mutated
np
.
testing
.
assert_allclose
(
x_test
,
[
0
,
1
,
1
])
np
.
testing
.
assert_allclose
(
y_test
,
[
0
,
1
,
1
])
# Check outputs are correct
np
.
testing
.
assert_allclose
(
res0
,
[
1
,
1
,
1
])
np
.
testing
.
assert_allclose
(
res1
,
[
1
,
np
.
e
,
np
.
e
])
@pytest.mark.filterwarnings
(
"error"
)
@pytest.mark.filterwarnings
(
"error"
)
def
test_cache_warning_suppressed
():
def
test_cache_warning_suppressed
():
x
=
pt
.
vector
(
"x"
,
shape
=
(
5
,),
dtype
=
"float64"
)
x
=
pt
.
vector
(
"x"
,
shape
=
(
5
,),
dtype
=
"float64"
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论