Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
42e8490c
提交
42e8490c
authored
11月 25, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
11月 27, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add support for TypedList in numba backend
Note: Numba object mode fallback is not safe with lists
上级
ac51f01c
显示空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
273 行增加
和
1 行删除
+273
-1
__init__.py
pytensor/link/numba/dispatch/__init__.py
+1
-0
basic.py
pytensor/link/numba/dispatch/basic.py
+7
-1
typed_list.py
pytensor/link/numba/dispatch/typed_list.py
+219
-0
test_typed_list.py
tests/link/numba/test_typed_list.py
+46
-0
没有找到文件。
pytensor/link/numba/dispatch/__init__.py
浏览文件 @
42e8490c
...
@@ -17,6 +17,7 @@ import pytensor.link.numba.dispatch.sort
...
@@ -17,6 +17,7 @@ import pytensor.link.numba.dispatch.sort
import
pytensor.link.numba.dispatch.sparse
import
pytensor.link.numba.dispatch.sparse
import
pytensor.link.numba.dispatch.subtensor
import
pytensor.link.numba.dispatch.subtensor
import
pytensor.link.numba.dispatch.tensor_basic
import
pytensor.link.numba.dispatch.tensor_basic
import
pytensor.link.numba.dispatch.typed_list
# isort: on
# isort: on
pytensor/link/numba/dispatch/basic.py
浏览文件 @
42e8490c
...
@@ -23,6 +23,7 @@ from pytensor.sparse import SparseTensorType
...
@@ -23,6 +23,7 @@ from pytensor.sparse import SparseTensorType
from
pytensor.tensor.random.type
import
RandomGeneratorType
from
pytensor.tensor.random.type
import
RandomGeneratorType
from
pytensor.tensor.type
import
TensorType
from
pytensor.tensor.type
import
TensorType
from
pytensor.tensor.utils
import
hash_from_ndarray
from
pytensor.tensor.utils
import
hash_from_ndarray
from
pytensor.typed_list
import
TypedListType
def
_filter_numba_warnings
():
def
_filter_numba_warnings
():
...
@@ -132,6 +133,8 @@ def get_numba_type(
...
@@ -132,6 +133,8 @@ def get_numba_type(
return
CSCMatrixType
(
numba_dtype
)
return
CSCMatrixType
(
numba_dtype
)
elif
isinstance
(
pytensor_type
,
RandomGeneratorType
):
elif
isinstance
(
pytensor_type
,
RandomGeneratorType
):
return
numba
.
types
.
NumPyRandomGeneratorType
(
"NumPyRandomGeneratorType"
)
return
numba
.
types
.
NumPyRandomGeneratorType
(
"NumPyRandomGeneratorType"
)
elif
isinstance
(
pytensor_type
,
TypedListType
):
return
numba
.
types
.
List
(
get_numba_type
(
pytensor_type
.
ttype
))
else
:
else
:
raise
NotImplementedError
(
f
"Numba type not implemented for {pytensor_type}"
)
raise
NotImplementedError
(
f
"Numba type not implemented for {pytensor_type}"
)
...
@@ -260,7 +263,10 @@ def numba_typify(data, dtype=None, **kwargs):
...
@@ -260,7 +263,10 @@ def numba_typify(data, dtype=None, **kwargs):
def
generate_fallback_impl
(
op
,
node
,
storage_map
=
None
,
**
kwargs
):
def
generate_fallback_impl
(
op
,
node
,
storage_map
=
None
,
**
kwargs
):
"""Create a Numba compatible function from a Pytensor `Op`."""
"""Create a Numba compatible function from a Pytensor `Op`.
Note limitations: https://numba.pydata.org/numba-doc/dev/user/withobjmode.html#the-objmode-context-manager
"""
warnings
.
warn
(
warnings
.
warn
(
f
"Numba will use object mode to run {op}'s perform method. "
f
"Numba will use object mode to run {op}'s perform method. "
...
...
pytensor/link/numba/dispatch/typed_list.py
0 → 100644
浏览文件 @
42e8490c
import
numba
import
numpy
as
np
import
pytensor.link.numba.dispatch.basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
register_funcify_default_op_cache_key
from
pytensor.link.numba.dispatch.compile_ops
import
numba_deepcopy
from
pytensor.tensor.type_other
import
SliceType
from
pytensor.typed_list
import
(
Append
,
Count
,
Extend
,
GetItem
,
Index
,
Insert
,
Length
,
MakeList
,
Remove
,
Reverse
,
)
def
numba_all_equal
(
x
,
y
):
if
isinstance
(
x
,
np
.
ndarray
)
or
isinstance
(
y
,
np
.
ndarray
):
if
not
(
isinstance
(
x
,
np
.
ndarray
)
and
isinstance
(
y
,
np
.
ndarray
)):
return
False
return
(
x
==
y
)
.
all
()
if
isinstance
(
x
,
list
)
or
isinstance
(
y
,
list
):
if
not
(
isinstance
(
x
,
list
)
and
isinstance
(
y
,
list
)):
return
False
if
len
(
x
)
!=
len
(
y
):
return
False
return
all
(
numba_all_equal
(
xi
,
yi
)
for
xi
,
yi
in
zip
(
x
,
y
))
return
x
==
y
@numba.extending.overload
(
numba_all_equal
)
def
list_all_equal
(
x
,
y
):
all_equal
=
None
if
isinstance
(
x
,
numba
.
types
.
List
)
and
isinstance
(
y
,
numba
.
types
.
List
):
def
all_equal
(
x
,
y
):
if
len
(
x
)
!=
len
(
y
):
return
False
for
xi
,
yi
in
zip
(
x
,
y
):
if
not
numba_all_equal
(
xi
,
yi
):
return
False
return
True
if
isinstance
(
x
,
numba
.
types
.
Array
)
and
isinstance
(
y
,
numba
.
types
.
Array
):
def
all_equal
(
x
,
y
):
return
(
x
==
y
)
.
all
()
if
isinstance
(
x
,
numba
.
types
.
Number
)
and
isinstance
(
y
.
numba
.
types
.
Number
):
def
all_equal
(
x
,
y
):
return
x
==
y
return
all_equal
@numba.extending.overload
(
numba_deepcopy
)
def
numba_deepcopy_list
(
x
):
if
isinstance
(
x
,
numba
.
types
.
List
):
def
deepcopy_list
(
x
):
return
[
numba_deepcopy
(
xi
)
for
xi
in
x
]
return
deepcopy_list
@register_funcify_default_op_cache_key
(
MakeList
)
def
numba_funcify_make_list
(
op
,
node
,
**
kwargs
):
@numba_basic.numba_njit
def
make_list
(
*
args
):
return
[
numba_deepcopy
(
arg
)
for
arg
in
args
]
return
make_list
@register_funcify_default_op_cache_key
(
Length
)
def
numba_funcify_list_length
(
op
,
node
,
**
kwargs
):
@numba_basic.numba_njit
def
list_length
(
x
):
return
np
.
array
(
len
(
x
),
dtype
=
np
.
int64
)
return
list_length
@register_funcify_default_op_cache_key
(
GetItem
)
def
numba_funcify_list_get_item
(
op
,
node
,
**
kwargs
):
if
isinstance
(
node
.
inputs
[
1
]
.
type
,
SliceType
):
@numba_basic.numba_njit
def
list_get_item_slice
(
x
,
index
):
return
x
[
index
]
return
list_get_item_slice
else
:
@numba_basic.numba_njit
def
list_get_item_index
(
x
,
index
):
return
x
[
index
.
item
()]
return
list_get_item_index
@register_funcify_default_op_cache_key
(
Reverse
)
def
numba_funcify_list_reverse
(
op
,
node
,
**
kwargs
):
inplace
=
op
.
inplace
@numba_basic.numba_njit
def
list_reverse
(
x
):
if
inplace
:
z
=
x
else
:
z
=
numba_deepcopy
(
x
)
z
.
reverse
()
return
z
return
list_reverse
@register_funcify_default_op_cache_key
(
Append
)
def
numba_funcify_list_append
(
op
,
node
,
**
kwargs
):
inplace
=
op
.
inplace
@numba_basic.numba_njit
def
list_append
(
x
,
to_append
):
if
inplace
:
z
=
x
else
:
z
=
numba_deepcopy
(
x
)
z
.
append
(
numba_deepcopy
(
to_append
))
return
z
return
list_append
@register_funcify_default_op_cache_key
(
Extend
)
def
numba_funcify_list_extend
(
op
,
node
,
**
kwargs
):
inplace
=
op
.
inplace
@numba_basic.numba_njit
def
list_extend
(
x
,
to_append
):
if
inplace
:
z
=
x
else
:
z
=
numba_deepcopy
(
x
)
z
.
extend
(
numba_deepcopy
(
to_append
))
return
z
return
list_extend
@register_funcify_default_op_cache_key
(
Insert
)
def
numba_funcify_list_insert
(
op
,
node
,
**
kwargs
):
inplace
=
op
.
inplace
@numba_basic.numba_njit
def
list_insert
(
x
,
index
,
to_insert
):
if
inplace
:
z
=
x
else
:
z
=
numba_deepcopy
(
x
)
z
.
insert
(
index
.
item
(),
numba_deepcopy
(
to_insert
))
return
z
return
list_insert
@register_funcify_default_op_cache_key
(
Index
)
def
numba_funcify_list_index
(
op
,
node
,
**
kwargs
):
@numba_basic.numba_njit
def
list_index
(
x
,
elem
):
for
idx
,
xi
in
enumerate
(
x
):
if
numba_all_equal
(
xi
,
elem
):
break
return
np
.
array
(
idx
,
dtype
=
np
.
int64
)
return
list_index
@register_funcify_default_op_cache_key
(
Count
)
def
numba_funcify_list_count
(
op
,
node
,
**
kwargs
):
@numba_basic.numba_njit
def
list_count
(
x
,
elem
):
c
=
0
for
xi
in
x
:
if
numba_all_equal
(
xi
,
elem
):
c
+=
1
return
np
.
array
(
c
,
dtype
=
np
.
int64
)
return
list_count
@register_funcify_default_op_cache_key
(
Remove
)
def
numba_funcify_list_remove
(
op
,
node
,
**
kwargs
):
inplace
=
op
.
inplace
@numba_basic.numba_njit
def
list_remove
(
x
,
to_remove
):
if
inplace
:
z
=
x
else
:
z
=
numba_deepcopy
(
x
)
index_to_remove
=
-
1
for
i
,
zi
in
enumerate
(
z
):
if
numba_all_equal
(
zi
,
to_remove
):
index_to_remove
=
i
break
if
index_to_remove
==
-
1
:
raise
ValueError
(
"list.remove(x): x not in list"
)
z
.
pop
(
index_to_remove
)
return
z
return
list_remove
tests/link/numba/test_typed_list.py
0 → 100644
浏览文件 @
42e8490c
import
numpy
as
np
from
pytensor.tensor
import
matrix
from
pytensor.typed_list
import
make_list
from
tests.link.numba.test_basic
import
compare_numba_and_py
def
test_list_basic_ops
():
x
=
matrix
(
"x"
,
shape
=
(
3
,
None
),
dtype
=
"int64"
)
l
=
make_list
([
x
[
0
],
x
[
2
]])
x_test
=
np
.
arange
(
12
)
.
reshape
(
3
,
4
)
compare_numba_and_py
([
x
],
[
l
,
l
.
length
()],
[
x_test
])
# Test nested list
ll
=
make_list
([
l
,
l
,
l
])
compare_numba_and_py
([
x
],
[
ll
,
ll
.
length
()],
[
x_test
])
def
test_make_list_index_ops
():
x
=
matrix
(
"x"
,
shape
=
(
3
,
None
),
dtype
=
"int64"
)
l
=
make_list
([
x
[
0
],
x
[
2
]])
x_test
=
np
.
arange
(
12
)
.
reshape
(
3
,
4
)
compare_numba_and_py
([
x
],
[
l
[
-
1
],
l
[:
-
1
],
l
.
reverse
()],
[
x_test
])
def
test_make_list_extend_ops
():
x
=
matrix
(
"x"
,
shape
=
(
3
,
None
),
dtype
=
"int64"
)
l
=
make_list
([
x
[
0
],
x
[
2
]])
x_test
=
np
.
arange
(
12
)
.
reshape
(
3
,
4
)
compare_numba_and_py
(
[
x
],
[
l
.
append
(
x
[
1
]),
l
.
extend
(
l
),
l
.
insert
(
0
,
x
[
1
])],
[
x_test
]
)
def
test_make_list_find_ops
():
# Remove requires to first find it
x
=
matrix
(
"x"
,
shape
=
(
3
,
None
),
dtype
=
"int64"
)
y
=
x
[
0
]
.
type
(
"y"
)
l
=
make_list
([
x
[
0
],
x
[
2
],
x
[
0
],
x
[
2
]])
x_test
=
np
.
arange
(
12
)
.
reshape
(
3
,
4
)
test_y
=
x_test
[
2
]
compare_numba_and_py
([
x
,
y
],
[
l
.
ind
(
y
),
l
.
count
(
y
),
l
.
remove
(
y
)],
[
x_test
,
test_y
])
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论