Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
ba4cca3f
提交
ba4cca3f
authored
11月 19, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
11月 25, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix numba FunctionGraph cache key
It's necessary to encode the edge information, not only the nodes and their ordering
上级
bf60f22f
显示空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
122 行增加
和
31 行删除
+122
-31
basic.py
pytensor/link/numba/dispatch/basic.py
+32
-19
utils.py
pytensor/link/utils.py
+7
-8
test_basic.py
tests/link/numba/test_basic.py
+83
-4
没有找到文件。
pytensor/link/numba/dispatch/basic.py
浏览文件 @
ba4cca3f
...
...
@@ -9,7 +9,7 @@ from numba import njit as _njit
from
numba.cpython.unsafe.tuple
import
tuple_setitem
# noqa: F401
from
pytensor
import
config
from
pytensor.graph.basic
import
Apply
,
Constant
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.type
import
Type
from
pytensor.link.numba.cache
import
compile_numba_function_src
,
hash_from_pickle_dump
...
...
@@ -498,36 +498,46 @@ def numba_funcify_FunctionGraph(
):
# Collect cache keys of every Op/Constant in the FunctionGraph
# so we can create a global cache key for the whole FunctionGraph
fgraph_can_be_cached
=
[
True
]
cache_keys
=
[]
toposort
=
fgraph
.
toposort
()
clients
=
fgraph
.
clients
toposort_indices
=
{
node
:
i
for
i
,
node
in
enumerate
(
toposort
)}
# Add dummy output clients which are not included of the toposort
toposort_indices
|=
{
clients
[
out
][
0
][
0
]:
i
for
i
,
out
in
enumerate
(
fgraph
.
outputs
,
start
=
len
(
toposort
))
toposort_coords
:
dict
[
Variable
,
tuple
[
int
,
int
|
str
]]
=
{
inp
:
(
0
,
i
)
for
i
,
inp
in
enumerate
(
fgraph
.
inputs
)
}
toposort_coords
|=
{
out
:
(
i
,
j
)
for
i
,
node
in
enumerate
(
toposort
,
start
=
1
)
for
j
,
out
in
enumerate
(
node
.
outputs
)
}
def
op_conversion_and_key_collection
(
*
args
,
**
kwargs
):
def
op_conversion_and_key_collection
(
op
,
*
args
,
node
,
**
kwargs
):
# Convert an Op to a funcified function and store the cache_key
# We also Cache each Op so Numba can do less work next time it sees it
func
,
key
=
numba_funcify_ensure_cache
(
*
args
,
**
kwargs
)
cache_keys
.
append
(
key
)
func
,
key
=
numba_funcify_ensure_cache
(
op
,
node
=
node
,
*
args
,
**
kwargs
)
if
key
is
None
:
fgraph_can_be_cached
[
0
]
=
False
else
:
# Add graph coordinate information (input edges and node location)
cache_keys
.
append
(
(
tuple
(
toposort_coords
[
inp
]
for
inp
in
node
.
inputs
),
key
,
)
)
return
func
def
type_conversion_and_key_collection
(
value
,
variable
,
**
kwargs
):
# Convert a constant type to a numba compatible one and compute a cache key for it
# We need to know where in the graph the constants are used
# Otherwise we would hash stack(x, 5.0, 7.0), and stack(5.0, x, 7.0) the same
# FIXME: It doesn't make sense to call type_conversion on non-constants,
# but that's what fgraph_to_python currently does. We appease it, but don't consider for caching
# but that's what fgraph_to_python currently does.
# We appease it, but don't consider for caching
if
isinstance
(
variable
,
Constant
):
client_indices
=
tuple
(
(
toposort_indices
[
node
],
inp_idx
)
for
node
,
inp_idx
in
clients
[
variable
]
)
cache_keys
.
append
((
client_indices
,
cache_key_for_constant
(
value
))
)
# Store unique key in toposort_coords. It will be included by whichever nodes make use of the constant
constant_cache_key
=
cache_key_for_constant
(
value
)
assert
constant_cache_key
is
not
None
toposort_coords
[
variable
]
=
(
-
1
,
constant_cache_key
)
return
numba_typify
(
value
,
variable
=
variable
,
**
kwargs
)
py_func
=
fgraph_to_python
(
...
...
@@ -537,12 +547,15 @@ def numba_funcify_FunctionGraph(
fgraph_name
=
fgraph_name
,
**
kwargs
,
)
if
any
(
key
is
None
for
key
in
cache_keys
)
:
if
not
fgraph_can_be_cached
[
0
]
:
# If a single element couldn't be cached, we can't cache the whole FunctionGraph either
fgraph_key
=
None
else
:
# Add graph coordinate information for fgraph outputs
fgraph_output_ancestors
=
tuple
(
toposort_coords
[
out
]
for
out
in
fgraph
.
outputs
)
# Compose individual cache_keys into a global key for the FunctionGraph
fgraph_key
=
sha256
(
f
"({type(fgraph)}, {tuple(cache_keys)}, {len(fgraph.inputs)}, {
len(fgraph.outputs)
})"
.
encode
()
f
"({type(fgraph)}, {tuple(cache_keys)}, {len(fgraph.inputs)}, {
fgraph_output_ancestors
})"
.
encode
()
)
.
hexdigest
()
return
numba_njit
(
py_func
),
fgraph_key
pytensor/link/utils.py
浏览文件 @
ba4cca3f
...
...
@@ -735,14 +735,6 @@ def fgraph_to_python(
body_assigns
=
[]
for
node
in
order
:
compiled_func
=
op_conversion_fn
(
node
.
op
,
node
=
node
,
storage_map
=
storage_map
,
**
kwargs
)
# Create a local alias with a unique name
local_compiled_func_name
=
unique_name
(
compiled_func
)
global_env
[
local_compiled_func_name
]
=
compiled_func
node_input_names
=
[]
for
inp
in
node
.
inputs
:
local_input_name
=
unique_name
(
inp
)
...
...
@@ -772,6 +764,13 @@ def fgraph_to_python(
node_output_names
=
[
unique_name
(
v
)
for
v
in
node
.
outputs
]
compiled_func
=
op_conversion_fn
(
node
.
op
,
node
=
node
,
storage_map
=
storage_map
,
**
kwargs
)
# Create a local alias with a unique name
local_compiled_func_name
=
unique_name
(
compiled_func
)
global_env
[
local_compiled_func_name
]
=
compiled_func
assign_str
=
f
"{', '.join(node_output_names)} = {local_compiled_func_name}({', '.join(node_input_names)})"
assign_comment_str
=
f
"{indent(str(node), '# ')}"
assign_block_str
=
f
"{assign_comment_str}
\n
{assign_str}"
...
...
tests/link/numba/test_basic.py
浏览文件 @
ba4cca3f
...
...
@@ -7,8 +7,7 @@ import numpy as np
import
pytest
import
scipy
from
pytensor.compile
import
SymbolicInput
from
pytensor.tensor.utils
import
hash_from_ndarray
from
pytensor.tensor
import
scalar_from_tensor
numba
=
pytest
.
importorskip
(
"numba"
)
...
...
@@ -16,17 +15,23 @@ numba = pytest.importorskip("numba")
import
pytensor.scalar
as
ps
import
pytensor.tensor
as
pt
from
pytensor
import
config
,
shared
from
pytensor.compile
import
SymbolicInput
from
pytensor.compile.function
import
function
from
pytensor.compile.mode
import
Mode
from
pytensor.graph.basic
import
Apply
,
Variable
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.op
import
Op
from
pytensor.graph.rewriting.db
import
RewriteDatabaseQuery
from
pytensor.graph.type
import
Type
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
cache_key_for_constant
from
pytensor.link.numba.dispatch.basic
import
(
cache_key_for_constant
,
numba_funcify_and_cache_key
,
)
from
pytensor.link.numba.linker
import
NumbaLinker
from
pytensor.scalar.basic
import
ScalarOp
,
as_scalar
from
pytensor.scalar.basic
import
Composite
,
ScalarOp
,
as_scalar
from
pytensor.tensor.elemwise
import
Elemwise
from
pytensor.tensor.utils
import
hash_from_ndarray
if
TYPE_CHECKING
:
...
...
@@ -652,3 +657,77 @@ def test_funcify_dispatch_interop():
outs
[
2
]
.
owner
.
op
,
outs
[
2
]
.
owner
)
assert
numba
.
njit
(
lambda
x
:
fn2_def_cached
(
x
))(
test_x
)
==
2
class
TestFgraphCacheKey
:
@staticmethod
def
generate_and_validate_key
(
fg
):
_
,
key
=
numba_funcify_and_cache_key
(
fg
)
assert
key
is
not
None
_
,
key_again
=
numba_funcify_and_cache_key
(
fg
)
assert
key
==
key_again
# Check it's stable
return
key
def
test_node_order
(
self
):
x
=
pt
.
scalar
(
"x"
)
log_x
=
pt
.
log
(
x
)
graphs
=
[
pt
.
exp
(
x
)
/
log_x
,
log_x
/
pt
.
exp
(
x
),
pt
.
exp
(
log_x
)
/
x
,
x
/
pt
.
exp
(
log_x
),
pt
.
exp
(
log_x
)
/
log_x
,
log_x
/
pt
.
exp
(
log_x
),
]
keys
=
[]
for
graph
in
graphs
:
fg
=
FunctionGraph
([
x
],
[
graph
],
clone
=
False
)
keys
.
append
(
self
.
generate_and_validate_key
(
fg
))
# Check keys are unique
assert
len
(
set
(
keys
))
==
len
(
graphs
)
# Extra unused input should alter the key, because it changes the function signature
y
=
pt
.
scalar
(
"y"
)
for
inputs
in
[[
x
,
y
],
[
y
,
x
]]:
fg
=
FunctionGraph
(
inputs
,
[
graphs
[
0
]],
clone
=
False
)
keys
.
append
(
self
.
generate_and_validate_key
(
fg
))
assert
len
(
set
(
keys
))
==
len
(
graphs
)
+
2
# Adding an input as an output should also change the key
for
outputs
in
[
[
graphs
[
0
],
x
],
[
x
,
graphs
[
0
]],
[
x
,
x
,
graphs
[
0
]],
[
x
,
graphs
[
0
],
x
],
[
graphs
[
0
],
x
,
x
],
]:
fg
=
FunctionGraph
([
x
],
outputs
,
clone
=
False
)
keys
.
append
(
self
.
generate_and_validate_key
(
fg
))
assert
len
(
set
(
keys
))
==
len
(
graphs
)
+
2
+
5
def
test_multi_output
(
self
):
x
=
pt
.
scalar
(
"x"
)
xs
=
scalar_from_tensor
(
x
)
out0
,
out1
=
Elemwise
(
Composite
([
xs
],
[
xs
*
2
,
xs
-
2
]))(
x
)
test_outs
=
[
[
out0
],
[
out1
],
[
out0
,
out1
],
[
out1
,
out0
],
]
keys
=
[]
for
test_out
in
test_outs
:
fg
=
FunctionGraph
([
x
],
test_out
,
clone
=
False
)
keys
.
append
(
self
.
generate_and_validate_key
(
fg
))
assert
len
(
set
(
keys
))
==
len
(
test_outs
)
def
test_constant_output
(
self
):
fg_pi
=
FunctionGraph
([],
[
pt
.
constant
(
np
.
pi
)])
fg_e
=
FunctionGraph
([],
[
pt
.
constant
(
np
.
e
)])
assert
self
.
generate_and_validate_key
(
fg_pi
)
!=
self
.
generate_and_validate_key
(
fg_e
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论