Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
6ac5ab28
提交
6ac5ab28
authored
11月 03, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
11月 16, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Cache keys for numba Op dispatches
上级
74ab0383
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
17 个修改的文件
包含
419 行增加
和
240 行删除
+419
-240
blockwise.py
pytensor/link/numba/dispatch/blockwise.py
+51
-28
compile_ops.py
pytensor/link/numba/dispatch/compile_ops.py
+30
-9
elemwise.py
pytensor/link/numba/dispatch/elemwise.py
+0
-0
extra_ops.py
pytensor/link/numba/dispatch/extra_ops.py
+21
-15
nlinalg.py
pytensor/link/numba/dispatch/nlinalg.py
+13
-13
random.py
pytensor/link/numba/dispatch/random.py
+44
-23
scalar.py
pytensor/link/numba/dispatch/scalar.py
+59
-51
scan.py
pytensor/link/numba/dispatch/scan.py
+20
-6
shape.py
pytensor/link/numba/dispatch/shape.py
+9
-7
conv.py
pytensor/link/numba/dispatch/signal/conv.py
+4
-2
slinalg.py
pytensor/link/numba/dispatch/slinalg.py
+8
-6
sort.py
pytensor/link/numba/dispatch/sort.py
+5
-3
subtensor.py
pytensor/link/numba/dispatch/subtensor.py
+66
-19
tensor_basic.py
pytensor/link/numba/dispatch/tensor_basic.py
+36
-24
vectorize_codegen.py
pytensor/link/numba/dispatch/vectorize_codegen.py
+26
-23
test_conv.py
tests/link/numba/signal/test_conv.py
+17
-6
test_basic.py
tests/link/numba/test_basic.py
+10
-5
没有找到文件。
pytensor/link/numba/dispatch/blockwise.py
浏览文件 @
6ac5ab28
import
sys
from
hashlib
import
sha256
from
typing
import
cast
from
numba.core.extending
import
overload
from
numba.np.unsafe.ndarray
import
to_fixed_tuple
from
pytensor.link.numba.cache
import
compile_numba_function_src
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
(
numba_funcify_and_cache_key
,
register_funcify_and_cache_key
,
)
from
pytensor.link.numba.dispatch.vectorize_codegen
import
(
_jit_options
,
_vectorized
,
encode_literals
,
store_core_outputs
,
)
from
pytensor.link.utils
import
compile_function_src
from
pytensor.tensor
import
TensorVariable
,
get_vector_length
from
pytensor.tensor.blockwise
import
Blockwise
,
BlockwiseWithCoreShape
@
numba_funcify.register
(
BlockwiseWithCoreShape
)
@
register_funcify_and_cache_key
(
BlockwiseWithCoreShape
)
def
numba_funcify_Blockwise
(
op
:
BlockwiseWithCoreShape
,
node
,
**
kwargs
):
[
blockwise_node
]
=
op
.
fgraph
.
apply_nodes
blockwise_op
:
Blockwise
=
blockwise_node
.
op
...
...
@@ -30,7 +33,7 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
cast
(
tuple
[
TensorVariable
],
node
.
inputs
[:
nin
]),
propagate_unbatched_core_inputs
=
True
,
)
core_op_fn
=
numba_funcif
y
(
core_op_fn
,
core_op_key
=
numba_funcify_and_cache_ke
y
(
core_op
,
node
=
core_node
,
parent_node
=
node
,
...
...
@@ -58,36 +61,56 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
src
+=
")"
to_tuple
=
numba_basic
.
numba_njit
(
compile_function_src
(
compile_
numba_
function_src
(
src
,
"to_tuple"
,
global_env
=
{
"to_fixed_tuple"
:
to_fixed_tuple
},
),
# cache=True leads to a numba.cloudpickle dump failure in Python 3.10
# May be fine in Python 3.11, but I didn't test. It was fine in 3.12
cache
=
sys
.
version_info
>=
(
3
,
12
),
)
def
blockwise_wrapper
(
*
inputs_and_core_shapes
):
inputs
,
core_shapes
=
inputs_and_core_shapes
[:
nin
],
inputs_and_core_shapes
[
nin
:]
tuple_core_shapes
=
to_tuple
(
core_shapes
)
return
_vectorized
(
core_op_fn
,
input_bc_patterns
,
output_bc_patterns
,
output_dtypes
,
inplace_pattern
,
(),
# constant_inputs
inputs
,
tuple_core_shapes
,
None
,
# size
)
)
def
blockwise
(
*
inputs_and_core_shapes
):
raise
NotImplementedError
(
"Non-jitted BlockwiseWithCoreShape not implemented"
)
raise
NotImplementedError
(
"Numba implementation of Blockwise cannot be evaluated in Python (non-JIT) mode."
)
@overload
(
blockwise
,
jit_options
=
_jit_options
)
def
ov_blockwise
(
*
inputs_and_core_shapes
):
return
blockwise_wrapper
def
impl
(
*
inputs_and_core_shapes
):
inputs
,
core_shapes
=
(
inputs_and_core_shapes
[:
nin
],
inputs_and_core_shapes
[
nin
:],
)
tuple_core_shapes
=
to_tuple
(
core_shapes
)
return
_vectorized
(
core_op_fn
,
input_bc_patterns
,
output_bc_patterns
,
output_dtypes
,
inplace_pattern
,
(),
# constant_inputs
inputs
,
tuple_core_shapes
,
None
,
# size
)
return
impl
return
blockwise
if
core_op_key
is
None
:
# If the core op cannot be cached, the Blockwise wrapper cannot be cached either
blockwise_key
=
None
else
:
blockwise_key
=
"_"
.
join
(
map
(
str
,
(
type
(
op
),
type
(
blockwise_op
),
tuple
(
blockwise_op
.
destroy_map
.
items
()),
blockwise_op
.
signature
,
input_bc_patterns
,
core_op_key
,
),
)
)
blockwise_key
=
sha256
(
blockwise_key
.
encode
())
.
hexdigest
()
return
blockwise
,
blockwise_key
pytensor/link/numba/dispatch/compile_ops.py
浏览文件 @
6ac5ab28
from
hashlib
import
sha256
import
numpy
as
np
from
pytensor.compile.builders
import
OpFromGraph
...
...
@@ -8,14 +10,15 @@ from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
from
pytensor.ifelse
import
IfElse
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
(
numba_funcify
,
numba_njit
,
numba_funcify_and_cache_key
,
register_funcify_and_cache_key
,
register_funcify_default_op_cache_key
,
)
from
pytensor.raise_op
import
CheckAndRaise
from
pytensor.tensor.type
import
TensorType
@
numba_funcify.register
(
OpFromGraph
)
@
register_funcify_and_cache_key
(
OpFromGraph
)
def
numba_funcify_OpFromGraph
(
op
,
node
=
None
,
**
kwargs
):
_
=
kwargs
.
pop
(
"storage_map"
,
None
)
...
...
@@ -30,10 +33,27 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
accept_inplace
=
True
,
)
NUMBA
.
optimizer
(
fgraph
)
return
numba_funcify
(
op
.
fgraph
,
squeeze_output
=
True
,
**
kwargs
)
fgraph_fn
,
fgraph_cache_key
=
numba_funcify_and_cache_key
(
op
.
fgraph
,
squeeze_output
=
True
,
**
kwargs
)
if
fgraph_cache_key
is
None
:
# Can't cache the inner graph
ofg_cache_key
=
None
else
:
ofg_cache_key
=
sha256
(
str
(
(
type
(
op
),
fgraph_cache_key
,
)
)
.
encode
()
)
.
hexdigest
()
return
fgraph_fn
,
ofg_cache_key
@
numba_funcify.register
(
TypeCastingOp
)
@
register_funcify_default_op_cache_key
(
TypeCastingOp
)
def
numba_funcify_type_casting
(
op
,
**
kwargs
):
@numba_basic.numba_njit
def
identity
(
x
):
...
...
@@ -42,7 +62,7 @@ def numba_funcify_type_casting(op, **kwargs):
return
identity
@
numba_funcify.register
(
DeepCopyOp
)
@
register_funcify_default_op_cache_key
(
DeepCopyOp
)
def
numba_funcify_DeepCopyOp
(
op
,
node
,
**
kwargs
):
if
isinstance
(
node
.
inputs
[
0
]
.
type
,
TensorType
):
...
...
@@ -59,7 +79,7 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
return
deepcopy
@
numba_funcify.register
(
IfElse
)
@
register_funcify_default_op_cache_key
(
IfElse
)
def
numba_funcify_IfElse
(
op
,
**
kwargs
):
n_outs
=
op
.
n_outs
...
...
@@ -88,7 +108,7 @@ def numba_funcify_IfElse(op, **kwargs):
return
ifelse
@
numba_funcify.register
(
CheckAndRaise
)
@
register_funcify_and_cache_key
(
CheckAndRaise
)
def
numba_funcify_CheckAndRaise
(
op
,
node
,
**
kwargs
):
error
=
op
.
exc_type
msg
=
op
.
msg
...
...
@@ -100,4 +120,5 @@ def numba_funcify_CheckAndRaise(op, node, **kwargs):
raise
error
(
msg
)
return
x
return
check_and_raise
cache_key
=
sha256
(
str
((
type
(
op
),
error
,
msg
))
.
encode
())
.
hexdigest
()
return
check_and_raise
,
cache_key
pytensor/link/numba/dispatch/elemwise.py
浏览文件 @
6ac5ab28
差异被折叠。
点击展开。
pytensor/link/numba/dispatch/extra_ops.py
浏览文件 @
6ac5ab28
import
warnings
from
hashlib
import
sha256
from
typing
import
cast
import
numba
...
...
@@ -9,7 +10,8 @@ from pytensor.link.numba.dispatch import basic as numba_basic
from
pytensor.link.numba.dispatch.basic
import
(
generate_fallback_impl
,
get_numba_type
,
numba_funcify
,
register_funcify_and_cache_key
,
register_funcify_default_op_cache_key
,
)
from
pytensor.tensor
import
TensorVariable
from
pytensor.tensor.extra_ops
import
(
...
...
@@ -25,16 +27,16 @@ from pytensor.tensor.extra_ops import (
)
@
numba_funcify.register
(
Bartlett
)
@
register_funcify_default_op_cache_key
(
Bartlett
)
def
numba_funcify_Bartlett
(
op
,
**
kwargs
):
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
bartlett
(
x
):
return
np
.
bartlett
(
x
.
item
())
return
bartlett
@
numba_funcify.register
(
CumOp
)
@
register_funcify_default_op_cache_key
(
CumOp
)
def
numba_funcify_CumOp
(
op
:
CumOp
,
node
:
Apply
,
**
kwargs
):
axis
=
op
.
axis
mode
=
op
.
mode
...
...
@@ -94,7 +96,7 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
return
cumop
@
numba_funcify.register
(
FillDiagonal
)
@
register_funcify_default_op_cache_key
(
FillDiagonal
)
def
numba_funcify_FillDiagonal
(
op
,
**
kwargs
):
@numba_basic.numba_njit
def
filldiagonal
(
a
,
val
):
...
...
@@ -104,7 +106,7 @@ def numba_funcify_FillDiagonal(op, **kwargs):
return
filldiagonal
@
numba_funcify.register
(
FillDiagonalOffset
)
@
register_funcify_default_op_cache_key
(
FillDiagonalOffset
)
def
numba_funcify_FillDiagonalOffset
(
op
,
node
,
**
kwargs
):
@numba_basic.numba_njit
def
filldiagonaloffset
(
a
,
val
,
offset
):
...
...
@@ -129,7 +131,7 @@ def numba_funcify_FillDiagonalOffset(op, node, **kwargs):
return
filldiagonaloffset
@
numba_funcify.register
(
RavelMultiIndex
)
@
register_funcify_default_op_cache_key
(
RavelMultiIndex
)
def
numba_funcify_RavelMultiIndex
(
op
,
node
,
**
kwargs
):
mode
=
op
.
mode
order
=
op
.
order
...
...
@@ -194,7 +196,7 @@ def numba_funcify_RavelMultiIndex(op, node, **kwargs):
return
ravelmultiindex
@
numba_funcify.register
(
Repeat
)
@
register_funcify_default_op_cache_key
(
Repeat
)
def
numba_funcify_Repeat
(
op
,
node
,
**
kwargs
):
axis
=
op
.
axis
a
,
_
=
node
.
inputs
...
...
@@ -202,7 +204,7 @@ def numba_funcify_Repeat(op, node, **kwargs):
# Numba only supports axis=None, which in our case is when axis is 0 and the input is a vector
if
axis
==
0
and
a
.
type
.
ndim
==
1
:
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
repeatop
(
x
,
repeats
):
return
np
.
repeat
(
x
,
repeats
)
...
...
@@ -212,7 +214,7 @@ def numba_funcify_Repeat(op, node, **kwargs):
return
generate_fallback_impl
(
op
,
node
)
@
numba_funcify.register
(
Unique
)
@
register_funcify_default_op_cache_key
(
Unique
)
def
numba_funcify_Unique
(
op
,
node
,
**
kwargs
):
axis
=
op
.
axis
...
...
@@ -230,7 +232,7 @@ def numba_funcify_Unique(op, node, **kwargs):
if
not
use_python
:
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
unique
(
x
):
return
np
.
unique
(
x
)
...
...
@@ -257,7 +259,7 @@ def numba_funcify_Unique(op, node, **kwargs):
return
unique
@
numba_funcify.register
(
UnravelIndex
)
@
register_funcify_and_cache_key
(
UnravelIndex
)
def
numba_funcify_UnravelIndex
(
op
,
node
,
**
kwargs
):
order
=
op
.
order
...
...
@@ -289,10 +291,14 @@ def numba_funcify_UnravelIndex(op, node, **kwargs):
# unpacked into a `tuple`, so this discrepancy shouldn't really matter
return
((
maybe_expand_dim
(
arr
)
//
a
)
%
shape
)
.
T
return
unravelindex
cache_key
=
sha256
(
str
((
type
(
op
),
op
.
order
,
len
(
node
.
outputs
)))
.
encode
()
)
.
hexdigest
()
return
unravelindex
,
cache_key
@
numba_funcify.register
(
SearchsortedOp
)
@
register_funcify_default_op_cache_key
(
SearchsortedOp
)
def
numba_funcify_Searchsorted
(
op
,
node
,
**
kwargs
):
side
=
op
.
side
...
...
@@ -319,7 +325,7 @@ def numba_funcify_Searchsorted(op, node, **kwargs):
else
:
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
searchsorted
(
a
,
v
):
return
np
.
searchsorted
(
a
,
v
,
side
)
...
...
pytensor/link/numba/dispatch/nlinalg.py
浏览文件 @
6ac5ab28
...
...
@@ -3,11 +3,11 @@ import warnings
import
numba
import
numpy
as
np
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
import
pytensor.link.numba.dispatch.
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
(
get_numba_type
,
int_to_float_fn
,
numba_funcif
y
,
register_funcify_default_op_cache_ke
y
,
)
from
pytensor.tensor.nlinalg
import
(
SVD
,
...
...
@@ -20,7 +20,7 @@ from pytensor.tensor.nlinalg import (
)
@
numba_funcify.register
(
SVD
)
@
register_funcify_default_op_cache_key
(
SVD
)
def
numba_funcify_SVD
(
op
,
node
,
**
kwargs
):
full_matrices
=
op
.
full_matrices
compute_uv
=
op
.
compute_uv
...
...
@@ -44,19 +44,19 @@ def numba_funcify_SVD(op, node, **kwargs):
return
svd
@
numba_funcify.register
(
Det
)
@
register_funcify_default_op_cache_key
(
Det
)
def
numba_funcify_Det
(
op
,
node
,
**
kwargs
):
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
inputs_cast
=
int_to_float_fn
(
node
.
inputs
,
out_dtype
)
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
det
(
x
):
return
np
.
array
(
np
.
linalg
.
det
(
inputs_cast
(
x
)))
.
astype
(
out_dtype
)
return
det
@
numba_funcify.register
(
SLogDet
)
@
register_funcify_default_op_cache_key
(
SLogDet
)
def
numba_funcify_SLogDet
(
op
,
node
,
**
kwargs
):
out_dtype_1
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
out_dtype_2
=
node
.
outputs
[
1
]
.
type
.
numpy_dtype
...
...
@@ -74,7 +74,7 @@ def numba_funcify_SLogDet(op, node, **kwargs):
return
slogdet
@
numba_funcify.register
(
Eig
)
@
register_funcify_default_op_cache_key
(
Eig
)
def
numba_funcify_Eig
(
op
,
node
,
**
kwargs
):
w_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
inputs_cast
=
int_to_float_fn
(
node
.
inputs
,
w_dtype
)
...
...
@@ -86,7 +86,7 @@ def numba_funcify_Eig(op, node, **kwargs):
return
eig
@
numba_funcify.register
(
Eigh
)
@
register_funcify_default_op_cache_key
(
Eigh
)
def
numba_funcify_Eigh
(
op
,
node
,
**
kwargs
):
uplo
=
op
.
UPLO
...
...
@@ -113,31 +113,31 @@ def numba_funcify_Eigh(op, node, **kwargs):
else
:
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
eigh
(
x
):
return
np
.
linalg
.
eigh
(
x
)
return
eigh
@
numba_funcify.register
(
MatrixInverse
)
@
register_funcify_default_op_cache_key
(
MatrixInverse
)
def
numba_funcify_MatrixInverse
(
op
,
node
,
**
kwargs
):
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
inputs_cast
=
int_to_float_fn
(
node
.
inputs
,
out_dtype
)
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
matrix_inverse
(
x
):
return
np
.
linalg
.
inv
(
inputs_cast
(
x
))
.
astype
(
out_dtype
)
return
matrix_inverse
@
numba_funcify.register
(
MatrixPinv
)
@
register_funcify_default_op_cache_key
(
MatrixPinv
)
def
numba_funcify_MatrixPinv
(
op
,
node
,
**
kwargs
):
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
inputs_cast
=
int_to_float_fn
(
node
.
inputs
,
out_dtype
)
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
matrixpinv
(
x
):
return
np
.
linalg
.
pinv
(
inputs_cast
(
x
))
.
astype
(
out_dtype
)
...
...
pytensor/link/numba/dispatch/random.py
浏览文件 @
6ac5ab28
from
collections.abc
import
Callable
from
copy
import
copy
,
deepcopy
from
functools
import
singledispatch
from
hashlib
import
sha256
from
textwrap
import
dedent
import
numba
...
...
@@ -13,7 +14,11 @@ import pytensor.tensor.random.basic as ptr
from
pytensor.graph
import
Apply
from
pytensor.graph.op
import
Op
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
direct_cast
,
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
(
direct_cast
,
numba_funcify
,
register_funcify_and_cache_key
,
)
from
pytensor.link.numba.dispatch.vectorize_codegen
import
(
_jit_options
,
_vectorized
,
...
...
@@ -395,7 +400,7 @@ def numba_funcify_RandomVariable_core(op: RandomVariable, **kwargs):
)
@
numba_funcify.register
@
register_funcify_and_cache_key
(
RandomVariableWithCoreShape
)
def
numba_funcify_RandomVariable
(
op
:
RandomVariableWithCoreShape
,
node
,
**
kwargs
):
core_shape
=
node
.
inputs
[
0
]
...
...
@@ -423,28 +428,44 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
output_dtypes
=
encode_literals
((
rv_node
.
default_output
()
.
type
.
dtype
,))
inplace_pattern
=
encode_literals
(())
def
random_wrapper
(
core_shape
,
rng
,
size
,
*
dist_params
):
if
not
inplace
:
rng
=
copy
(
rng
)
draws
=
_vectorized
(
core_op_fn
,
input_bc_patterns
,
output_bc_patterns
,
output_dtypes
,
inplace_pattern
,
(
rng
,),
dist_params
,
(
numba_ndarray
.
to_fixed_tuple
(
core_shape
,
core_shape_len
),),
None
if
size_len
is
None
else
numba_ndarray
.
to_fixed_tuple
(
size
,
size_len
),
)
return
rng
,
draws
def
random
(
core_shape
,
rng
,
size
,
*
dist_params
):
raise
NotImplementedError
(
"Non-jitted random variable not implemented"
)
raise
NotImplementedError
(
"Numba implementation of RandomVariable cannot be evaluated in Python (non-JIT) mode"
)
@overload
(
random
,
jit_options
=
_jit_options
)
def
ov_random
(
core_shape
,
rng
,
size
,
*
dist_params
):
return
random_wrapper
return
random
def
impl
(
core_shape
,
rng
,
size
,
*
dist_params
):
if
not
inplace
:
rng
=
copy
(
rng
)
draws
=
_vectorized
(
core_op_fn
,
input_bc_patterns
,
output_bc_patterns
,
output_dtypes
,
inplace_pattern
,
(
rng
,),
dist_params
,
(
numba_ndarray
.
to_fixed_tuple
(
core_shape
,
core_shape_len
),),
None
if
size_len
is
None
else
numba_ndarray
.
to_fixed_tuple
(
size
,
size_len
),
)
return
rng
,
draws
return
impl
rv_op_props_dict
=
rv_op
.
props_dict
()
if
hasattr
(
rv_op
,
"props_dict"
)
else
{}
random_rv_key_contents
=
(
type
(
op
),
type
(
rv_op
),
rv_op
,
tuple
(
rv_op_props_dict
.
items
()),
size_len
,
core_shape_len
,
inplace
,
input_bc_patterns
,
)
random_rv_key
=
sha256
(
str
(
random_rv_key_contents
)
.
encode
())
.
hexdigest
()
return
random
,
random_rv_key
pytensor/link/numba/dispatch/scalar.py
浏览文件 @
6ac5ab28
import
math
from
hashlib
import
sha256
import
numpy
as
np
from
pytensor.graph.basic
import
Variable
from
pytensor.link.numba.cache
import
compile_numba_function_src
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
(
create_numba_signature
,
generate_fallback_impl
,
numba_funcify
,
numba_funcify_and_cache_key
,
register_funcify_and_cache_key
,
)
from
pytensor.link.numba.dispatch.cython_support
import
wrap_cython_function
from
pytensor.link.utils
import
(
compile_function_src
,
get_name_for_object
,
unique_name_generator
,
)
...
...
@@ -30,13 +31,16 @@ from pytensor.scalar.basic import (
from
pytensor.scalar.math
import
Erf
,
Erfc
,
GammaLn
,
Log1mexp
,
Sigmoid
,
Softplus
@numba_funcify.register
(
ScalarOp
)
def
numba_funcify_ScalarOp
(
op
,
node
,
**
kwargs
):
# TODO: Do we need to cache these functions so that we don't end up
# compiling the same Numba function over and over again?
def
scalar_op_cache_key
(
op
):
# Scalar Ops don't have _props, because of their weird outputs_types_preference function
# So we create hash differently
return
sha256
(
str
(
type
(
op
))
.
encode
())
.
hexdigest
()
@register_funcify_and_cache_key
(
ScalarOp
)
def
numba_funcify_ScalarOp
(
op
,
node
,
**
kwargs
):
if
not
hasattr
(
op
,
"nfunc_spec"
):
return
generate_fallback_impl
(
op
,
node
,
**
kwargs
)
return
generate_fallback_impl
(
op
,
node
=
node
,
**
kwargs
),
None
scalar_func_path
=
op
.
nfunc_spec
[
0
]
scalar_func_numba
=
None
...
...
@@ -58,6 +62,7 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
output_inner_dtype
=
None
# Cython functions might have an additional argument
cython_func
=
None
has_pyx_skip_dispatch
=
False
if
scalar_func_path
.
startswith
(
"scipy.special"
):
...
...
@@ -127,20 +132,18 @@ def {scalar_op_fn_name}({", ".join(input_names)}):
return direct_cast(scalar_func_numba({converted_call_args}, np.intc(1)), output_dtype)
"""
scalar_op_fn
=
compile_function_src
(
scalar_op_src
,
scalar_op_fn_name
,
{
**
globals
(),
**
global_env
}
scalar_op_fn
=
compile_numba_function_src
(
scalar_op_src
,
scalar_op_fn_name
,
{
**
globals
(),
**
global_env
},
)
signature
=
create_numba_signature
(
node
,
force_scalar
=
True
)
# Functions that call a function pointer can't be cached
cache_key
=
None
if
cython_func
else
scalar_op_cache_key
(
op
)
return
numba_basic
.
numba_njit
(
scalar_op_fn
),
cache_key
return
numba_basic
.
numba_njit
(
signature
,
# Functions that call a function pointer can't be cached
cache
=
False
,
)(
scalar_op_fn
)
@numba_funcify.register
(
Switch
)
@register_funcify_and_cache_key
(
Switch
)
def
numba_funcify_Switch
(
op
,
node
,
**
kwargs
):
@numba_basic.numba_njit
def
switch
(
condition
,
x
,
y
):
...
...
@@ -149,7 +152,7 @@ def numba_funcify_Switch(op, node, **kwargs):
else
:
return
y
return
switch
return
switch
,
scalar_op_cache_key
(
op
)
def
binary_to_nary_func
(
inputs
:
list
[
Variable
],
binary_op_name
:
str
,
binary_op
:
str
):
...
...
@@ -163,28 +166,26 @@ def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op:
def {binary_op_name}({input_signature}):
return {output_expr}
"""
nary_fn
=
compile_function_src
(
nary_src
,
binary_op_name
,
globals
())
nary_fn
=
compile_
numba_
function_src
(
nary_src
,
binary_op_name
,
globals
())
return
nary_fn
@
numba_funcify.register
(
Add
)
@
register_funcify_and_cache_key
(
Add
)
def
numba_funcify_Add
(
op
,
node
,
**
kwargs
):
signature
=
create_numba_signature
(
node
,
force_scalar
=
True
)
nary_add_fn
=
binary_to_nary_func
(
node
.
inputs
,
"add"
,
"+"
)
return
numba_basic
.
numba_njit
(
signature
)(
nary_add_fn
)
return
numba_basic
.
numba_njit
(
nary_add_fn
),
scalar_op_cache_key
(
op
)
@
numba_funcify.register
(
Mul
)
@
register_funcify_and_cache_key
(
Mul
)
def
numba_funcify_Mul
(
op
,
node
,
**
kwargs
):
signature
=
create_numba_signature
(
node
,
force_scalar
=
True
)
nary_add_fn
=
binary_to_nary_func
(
node
.
inputs
,
"mul"
,
"*"
)
nary_mul_fn
=
binary_to_nary_func
(
node
.
inputs
,
"mul"
,
"*"
)
return
numba_basic
.
numba_njit
(
signature
)(
nary_add_fn
)
return
numba_basic
.
numba_njit
(
nary_mul_fn
),
scalar_op_cache_key
(
op
)
@
numba_funcify.register
(
Cast
)
@
register_funcify_and_cache_key
(
Cast
)
def
numba_funcify_Cast
(
op
,
node
,
**
kwargs
):
dtype
=
np
.
dtype
(
op
.
o_type
.
dtype
)
...
...
@@ -192,19 +193,19 @@ def numba_funcify_Cast(op, node, **kwargs):
def
cast
(
x
):
return
numba_basic
.
direct_cast
(
x
,
dtype
)
return
cast
return
cast
,
sha256
(
str
((
type
(
op
),
op
.
o_type
.
dtype
))
.
encode
())
.
hexdigest
()
@
numba_funcify.register
(
Identity
)
@
register_funcify_and_cache_key
(
Identity
)
def
numba_funcify_type_casting
(
op
,
**
kwargs
):
@numba_basic.numba_njit
def
identity
(
x
):
return
x
return
identity
return
identity
,
scalar_op_cache_key
(
op
)
@
numba_funcify.register
(
Clip
)
@
register_funcify_and_cache_key
(
Clip
)
def
numba_funcify_Clip
(
op
,
**
kwargs
):
@numba_basic.numba_njit
def
clip
(
x
,
min_val
,
max_val
):
...
...
@@ -215,26 +216,33 @@ def numba_funcify_Clip(op, **kwargs):
else
:
return
x
return
clip
return
clip
,
scalar_op_cache_key
(
op
)
@
numba_funcify.register
(
Composite
)
@
register_funcify_and_cache_key
(
Composite
)
def
numba_funcify_Composite
(
op
,
node
,
**
kwargs
):
_
=
kwargs
.
pop
(
"storage_map"
,
None
)
return
numba_funcify
(
op
.
fgraph
,
squeeze_output
=
True
,
**
kwargs
)
composite_fn
,
fgraph_key
=
numba_funcify_and_cache_key
(
op
.
fgraph
,
squeeze_output
=
True
,
**
kwargs
)
if
fgraph_key
is
None
:
composite_key
=
None
else
:
composite_key
=
sha256
(
str
((
type
(
op
),
fgraph_key
))
.
encode
())
.
hexdigest
()
return
composite_fn
,
composite_key
@
numba_funcify.register
(
Second
)
@
register_funcify_and_cache_key
(
Second
)
def
numba_funcify_Second
(
op
,
node
,
**
kwargs
):
@numba_basic.numba_njit
def
second
(
x
,
y
):
return
y
return
second
return
second
,
scalar_op_cache_key
(
op
)
@
numba_funcify.register
(
Reciprocal
)
@
register_funcify_and_cache_key
(
Reciprocal
)
def
numba_funcify_Reciprocal
(
op
,
node
,
**
kwargs
):
@numba_basic.numba_njit
def
reciprocal
(
x
):
...
...
@@ -242,28 +250,28 @@ def numba_funcify_Reciprocal(op, node, **kwargs):
# `x` is an `int`
return
1
/
x
return
reciprocal
return
reciprocal
,
scalar_op_cache_key
(
op
)
@
numba_funcify.register
(
Sigmoid
)
@
register_funcify_and_cache_key
(
Sigmoid
)
def
numba_funcify_Sigmoid
(
op
,
node
,
**
kwargs
):
@numba_basic.numba_njit
def
sigmoid
(
x
):
return
1
/
(
1
+
np
.
exp
(
-
x
))
return
sigmoid
return
sigmoid
,
scalar_op_cache_key
(
op
)
@
numba_funcify.register
(
GammaLn
)
@
register_funcify_and_cache_key
(
GammaLn
)
def
numba_funcify_GammaLn
(
op
,
node
,
**
kwargs
):
@numba_basic.numba_njit
def
gammaln
(
x
):
return
math
.
lgamma
(
x
)
return
gammaln
return
gammaln
,
scalar_op_cache_key
(
op
)
@
numba_funcify.register
(
Log1mexp
)
@
register_funcify_and_cache_key
(
Log1mexp
)
def
numba_funcify_Log1mexp
(
op
,
node
,
**
kwargs
):
@numba_basic.numba_njit
def
logp1mexp
(
x
):
...
...
@@ -272,28 +280,28 @@ def numba_funcify_Log1mexp(op, node, **kwargs):
else
:
return
np
.
log
(
-
np
.
expm1
(
x
))
return
logp1mexp
return
logp1mexp
,
scalar_op_cache_key
(
op
)
@
numba_funcify.register
(
Erf
)
@
register_funcify_and_cache_key
(
Erf
)
def
numba_funcify_Erf
(
op
,
**
kwargs
):
@numba_basic.numba_njit
def
erf
(
x
):
return
math
.
erf
(
x
)
return
erf
return
erf
,
scalar_op_cache_key
(
op
)
@
numba_funcify.register
(
Erfc
)
@
register_funcify_and_cache_key
(
Erfc
)
def
numba_funcify_Erfc
(
op
,
**
kwargs
):
@numba_basic.numba_njit
def
erfc
(
x
):
return
math
.
erfc
(
x
)
return
erfc
return
erfc
,
scalar_op_cache_key
(
op
)
@
numba_funcify.register
(
Softplus
)
@
register_funcify_and_cache_key
(
Softplus
)
def
numba_funcify_Softplus
(
op
,
node
,
**
kwargs
):
out_dtype
=
np
.
dtype
(
node
.
outputs
[
0
]
.
type
.
dtype
)
...
...
@@ -309,4 +317,4 @@ def numba_funcify_Softplus(op, node, **kwargs):
value
=
x
return
numba_basic
.
direct_cast
(
value
,
out_dtype
)
return
softplus
return
softplus
,
scalar_op_cache_key
(
op
)
pytensor/link/numba/dispatch/scan.py
浏览文件 @
6ac5ab28
from
hashlib
import
sha256
from
textwrap
import
dedent
,
indent
import
numpy
as
np
...
...
@@ -7,13 +8,14 @@ from numba.extending import overload
from
pytensor
import
In
from
pytensor.compile.function.types
import
add_supervisor_to_fgraph
from
pytensor.compile.mode
import
NUMBA
,
get_mode
from
pytensor.link.numba.cache
import
compile_numba_function_src
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
(
create_arg_string
,
create_tuple_string
,
numba_funcify
,
numba_funcify_and_cache_key
,
register_funcify_and_cache_key
,
)
from
pytensor.link.utils
import
compile_function_src
from
pytensor.scan.op
import
Scan
from
pytensor.tensor.type
import
TensorType
...
...
@@ -54,7 +56,7 @@ def array0d_range(x):
return
range_arr
@
numba_funcify.register
(
Scan
)
@
register_funcify_and_cache_key
(
Scan
)
def
numba_funcify_Scan
(
op
:
Scan
,
node
,
**
kwargs
):
# Apply inner rewrites
# TODO: Not sure this is the right place to do this, should we have a rewrite that
...
...
@@ -97,7 +99,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
)
rewriter
(
fgraph
)
scan_inner_func
=
numba_funcif
y
(
op
.
fgraph
)
scan_inner_func
,
inner_func_cache_key
=
numba_funcify_and_cache_ke
y
(
op
.
fgraph
)
outer_in_names_to_vars
=
{
(
f
"outer_in_{i}"
if
i
>
0
else
"n_steps"
):
v
for
i
,
v
in
enumerate
(
node
.
inputs
)
...
...
@@ -439,6 +441,18 @@ def scan({", ".join(outer_in_names)}):
"scan_inner_func"
:
scan_inner_func
,
}
scan_op_fn
=
compile_function_src
(
scan_op_src
,
"scan"
,
{
**
globals
(),
**
global_env
})
scan_op_fn
=
compile_numba_function_src
(
scan_op_src
,
"scan"
,
{
**
globals
(),
**
global_env
},
)
if
inner_func_cache_key
is
None
:
# If we can't cache the inner function, we can't cache the Scan either
scan_cache_key
=
None
else
:
scan_cache_key
=
sha256
(
f
"({scan_op_src}, {inner_func_cache_key})"
.
encode
()
)
.
hexdigest
()
return
numba_basic
.
numba_njit
(
scan_op_fn
,
boundscheck
=
False
)
return
numba_basic
.
numba_njit
(
scan_op_fn
,
boundscheck
=
False
)
,
scan_cache_key
pytensor/link/numba/dispatch/shape.py
浏览文件 @
6ac5ab28
...
...
@@ -4,14 +4,16 @@ import numpy as np
from
numba.np.unsafe
import
ndarray
as
numba_ndarray
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
create_arg_string
,
numba_njit
from
pytensor.link.numba.dispatch.basic
import
(
create_arg_string
,
register_funcify_default_op_cache_key
,
)
from
pytensor.link.utils
import
compile_function_src
from
pytensor.tensor
import
NoneConst
from
pytensor.tensor.shape
import
Reshape
,
Shape
,
Shape_i
,
SpecifyShape
@
numba_funcify.register
(
Shape
)
@
register_funcify_default_op_cache_key
(
Shape
)
def
numba_funcify_Shape
(
op
,
**
kwargs
):
@numba_basic.numba_njit
def
shape
(
x
):
...
...
@@ -20,7 +22,7 @@ def numba_funcify_Shape(op, **kwargs):
return
shape
@
numba_funcify.register
(
Shape_i
)
@
register_funcify_default_op_cache_key
(
Shape_i
)
def
numba_funcify_Shape_i
(
op
,
**
kwargs
):
i
=
op
.
i
...
...
@@ -31,7 +33,7 @@ def numba_funcify_Shape_i(op, **kwargs):
return
shape_i
@
numba_funcify.register
(
SpecifyShape
)
@
register_funcify_default_op_cache_key
(
SpecifyShape
)
def
numba_funcify_SpecifyShape
(
op
,
node
,
**
kwargs
):
shape_inputs
=
node
.
inputs
[
1
:]
shape_input_names
=
[
"shape_"
+
str
(
i
)
for
i
in
range
(
len
(
shape_inputs
))]
...
...
@@ -53,10 +55,10 @@ def numba_funcify_SpecifyShape(op, node, **kwargs):
)
specify_shape
=
compile_function_src
(
func
,
"specify_shape"
,
globals
())
return
numba_njit
(
specify_shape
)
return
numba_
basic
.
numba_
njit
(
specify_shape
)
@
numba_funcify.register
(
Reshape
)
@
register_funcify_default_op_cache_key
(
Reshape
)
def
numba_funcify_Reshape
(
op
,
**
kwargs
):
ndim
=
op
.
ndim
...
...
pytensor/link/numba/dispatch/signal/conv.py
浏览文件 @
6ac5ab28
...
...
@@ -2,11 +2,13 @@ import numpy as np
from
numba.np.arraymath
import
_get_inner_prod
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
(
register_funcify_default_op_cache_key
,
)
from
pytensor.tensor.signal.conv
import
Convolve1d
@
numba_funcify.register
(
Convolve1d
)
@
register_funcify_default_op_cache_key
(
Convolve1d
)
def
numba_funcify_Convolve1d
(
op
,
node
,
**
kwargs
):
# This specialized version is faster than the overloaded numba np.convolve
a_dtype
,
b_dtype
=
node
.
inputs
[
0
]
.
type
.
dtype
,
node
.
inputs
[
1
]
.
type
.
dtype
...
...
pytensor/link/numba/dispatch/slinalg.py
浏览文件 @
6ac5ab28
...
...
@@ -4,7 +4,10 @@ import numpy as np
from
pytensor
import
config
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
(
numba_funcify
,
register_funcify_default_op_cache_key
,
)
from
pytensor.link.numba.dispatch.linalg.decomposition.cholesky
import
_cholesky
from
pytensor.link.numba.dispatch.linalg.decomposition.lu
import
(
_lu_1
,
...
...
@@ -91,7 +94,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
return
cholesky
@
numba_funcify.register
(
PivotToPermutations
)
@
register_funcify_default_op_cache_key
(
PivotToPermutations
)
def
pivot_to_permutation
(
op
,
node
,
**
kwargs
):
inverse
=
op
.
inverse
dtype
=
node
.
outputs
[
0
]
.
dtype
...
...
@@ -119,7 +122,7 @@ def numba_funcify_LU(op, node, **kwargs):
if
dtype
in
complex_dtypes
:
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
lu
(
a
):
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
...
...
@@ -181,11 +184,10 @@ def numba_funcify_LUFactor(op, node, **kwargs):
return
lu_factor
@
numba_funcify.register
(
BlockDiagonal
)
@
register_funcify_default_op_cache_key
(
BlockDiagonal
)
def
numba_funcify_BlockDiagonal
(
op
,
node
,
**
kwargs
):
dtype
=
node
.
outputs
[
0
]
.
dtype
# TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case.
@numba_basic.numba_njit
def
block_diag
(
*
arrs
):
shapes
=
np
.
array
([
a
.
shape
for
a
in
arrs
],
dtype
=
"int"
)
...
...
@@ -338,7 +340,7 @@ def numba_funcify_QR(op, node, **kwargs):
integer_input
=
dtype
in
integer_dtypes
in_dtype
=
config
.
floatX
if
integer_input
else
dtype
@numba_basic.numba_njit
(
cache
=
False
)
@numba_basic.numba_njit
def
qr
(
a
):
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
...
...
pytensor/link/numba/dispatch/sort.py
浏览文件 @
6ac5ab28
...
...
@@ -3,11 +3,13 @@ import warnings
import
numpy
as
np
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
(
register_funcify_default_op_cache_key
,
)
from
pytensor.tensor.sort
import
ArgSortOp
,
SortOp
@
numba_funcify.register
(
SortOp
)
@
register_funcify_default_op_cache_key
(
SortOp
)
def
numba_funcify_SortOp
(
op
,
node
,
**
kwargs
):
if
op
.
kind
!=
"quicksort"
:
warnings
.
warn
(
...
...
@@ -31,7 +33,7 @@ def numba_funcify_SortOp(op, node, **kwargs):
return
sort_f
@
numba_funcify.register
(
ArgSortOp
)
@
register_funcify_default_op_cache_key
(
ArgSortOp
)
def
numba_funcify_ArgSortOp
(
op
,
node
,
**
kwargs
):
kind
=
op
.
kind
...
...
pytensor/link/numba/dispatch/subtensor.py
浏览文件 @
6ac5ab28
import
operator
import
sys
from
hashlib
import
sha256
import
numba
import
numpy
as
np
...
...
@@ -7,11 +8,17 @@ from llvmlite import ir
from
numba
import
types
from
numba.core.pythonapi
import
box
import
pytensor.link.numba.dispatch.basic
as
numba_basic
from
pytensor.graph
import
Type
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
generate_fallback_impl
,
numba_njit
from
pytensor.link.utils
import
compile_function_src
,
unique_name_generator
from
pytensor.link.numba.cache
import
(
compile_numba_function_src
,
)
from
pytensor.link.numba.dispatch.basic
import
(
generate_fallback_impl
,
register_funcify_and_cache_key
,
register_funcify_default_op_cache_key
,
)
from
pytensor.link.utils
import
unique_name_generator
from
pytensor.tensor
import
TensorType
from
pytensor.tensor.rewriting.subtensor
import
is_full_slice
from
pytensor.tensor.subtensor
import
(
...
...
@@ -98,7 +105,7 @@ def enable_slice_boxing():
enable_slice_boxing
()
@
numba_funcify.register
(
MakeSlice
)
@
register_funcify_default_op_cache_key
(
MakeSlice
)
def
numba_funcify_MakeSlice
(
op
,
**
kwargs
):
@numba_basic.numba_njit
def
makeslice
(
*
x
):
...
...
@@ -107,9 +114,32 @@ def numba_funcify_MakeSlice(op, **kwargs):
return
makeslice
@numba_funcify.register
(
Subtensor
)
@numba_funcify.register
(
IncSubtensor
)
@numba_funcify.register
(
AdvancedSubtensor1
)
def
subtensor_op_cache_key
(
op
,
**
extra_fields
):
key_parts
=
[
type
(
op
),
tuple
(
extra_fields
.
items
())]
if
hasattr
(
op
,
"idx_list"
):
idx_parts
=
[]
for
idx
in
op
.
idx_list
:
if
isinstance
(
idx
,
slice
):
idx_parts
.
append
(
(
idx
.
start
is
None
,
idx
.
stop
is
None
,
idx
.
step
is
None
,
)
)
else
:
idx_parts
.
append
(
"i"
)
key_parts
.
append
(
tuple
(
idx_parts
))
if
isinstance
(
op
,
IncSubtensor
|
AdvancedIncSubtensor
|
AdvancedIncSubtensor1
):
key_parts
.
append
((
op
.
inplace
,
op
.
set_instead_of_inc
))
if
isinstance
(
op
,
AdvancedIncSubtensor
):
key_parts
.
append
(
op
.
ignore_duplicates
)
return
sha256
(
str
(
tuple
(
key_parts
))
.
encode
())
.
hexdigest
()
@register_funcify_and_cache_key
(
Subtensor
)
@register_funcify_and_cache_key
(
IncSubtensor
)
@register_funcify_and_cache_key
(
AdvancedSubtensor1
)
def
numba_funcify_default_subtensor
(
op
,
node
,
**
kwargs
):
"""Create a Python function that assembles and uses an index on an array."""
...
...
@@ -185,16 +215,17 @@ def {function_name}({", ".join(input_names)}):
return np.asarray(z)
"""
func
=
compile_function_src
(
func
=
compile_
numba_
function_src
(
subtensor_def_src
,
function_name
=
function_name
,
global_env
=
globals
()
|
{
"np"
:
np
},
)
return
numba_njit
(
func
,
boundscheck
=
True
)
cache_key
=
subtensor_op_cache_key
(
op
,
func
=
"numba_funcify_default_subtensor"
)
return
numba_basic
.
numba_njit
(
func
,
boundscheck
=
True
),
cache_key
@
numba_funcify.register
(
AdvancedSubtensor
)
@
numba_funcify.register
(
AdvancedIncSubtensor
)
@
register_funcify_and_cache_key
(
AdvancedSubtensor
)
@
register_funcify_and_cache_key
(
AdvancedIncSubtensor
)
def
numba_funcify_AdvancedSubtensor
(
op
,
node
,
**
kwargs
):
if
isinstance
(
op
,
AdvancedSubtensor
):
_x
,
_y
,
idxs
=
node
.
inputs
[
0
],
None
,
node
.
inputs
[
1
:]
...
...
@@ -255,7 +286,9 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
)
)
):
return
generate_fallback_impl
(
op
,
node
,
**
kwargs
)
return
generate_fallback_impl
(
op
,
node
,
**
kwargs
),
subtensor_op_cache_key
(
op
,
func
=
"fallback_impl"
)
# What's left should all be supported natively by numba
return
numba_funcify_default_subtensor
(
op
,
node
,
**
kwargs
)
...
...
@@ -295,6 +328,7 @@ def numba_funcify_multiple_integer_vector_indexing(
vector_indices
=
idxs
[
first_axis
:
after_last_axis
]
assert
all
(
v
.
type
.
broadcastable
==
(
False
,)
for
v
in
vector_indices
)
y_is_broadcasted
=
False
if
isinstance
(
op
,
AdvancedSubtensor
):
...
...
@@ -313,7 +347,7 @@ def numba_funcify_multiple_integer_vector_indexing(
out_buffer
[(
*
none_slices
,
i
)]
=
x
[(
*
none_slices
,
*
scalar_idxs
)]
return
out_buffer
ret
urn
advanced_subtensor_multiple_vector
ret
_func
=
advanced_subtensor_multiple_vector
else
:
inplace
=
op
.
inplace
...
...
@@ -347,7 +381,7 @@ def numba_funcify_multiple_integer_vector_indexing(
out
[(
*
outer
,
*
scalar_idxs
)]
=
y
[(
*
outer
,
i
)]
return
out
ret
urn
advanced_set_subtensor_multiple_vector
ret
_func
=
advanced_set_subtensor_multiple_vector
else
:
...
...
@@ -369,10 +403,17 @@ def numba_funcify_multiple_integer_vector_indexing(
out
[(
*
outer
,
*
scalar_idxs
)]
+=
y
[(
*
outer
,
i
)]
return
out
return
advanced_inc_subtensor_multiple_vector
ret_func
=
advanced_inc_subtensor_multiple_vector
cache_key
=
subtensor_op_cache_key
(
op
,
func
=
"multiple_integer_vector_indexing"
,
y_is_broadcasted
=
y_is_broadcasted
,
)
return
ret_func
,
cache_key
@
numba_funcify.register
(
AdvancedIncSubtensor1
)
@
register_funcify_and_cache_key
(
AdvancedIncSubtensor1
)
def
numba_funcify_AdvancedIncSubtensor1
(
op
,
node
,
**
kwargs
):
inplace
=
op
.
inplace
set_instead_of_inc
=
op
.
set_instead_of_inc
...
...
@@ -436,8 +477,14 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
x
[
idx
]
+=
val
return
x
cache_key
=
subtensor_op_cache_key
(
op
,
func
=
"numba_funcify_advancedincsubtensor1"
,
broadcast_with_index
=
broadcast_with_index
,
)
if
inplace
:
return
advancedincsubtensor1_inplace
return
advancedincsubtensor1_inplace
,
cache_key
else
:
...
...
@@ -446,4 +493,4 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
x
=
x
.
copy
()
return
advancedincsubtensor1_inplace
(
x
,
vals
,
idxs
)
return
advancedincsubtensor1
return
advancedincsubtensor1
,
cache_key
pytensor/link/numba/dispatch/tensor_basic.py
浏览文件 @
6ac5ab28
from
hashlib
import
sha256
from
textwrap
import
indent
import
numpy
as
np
from
pytensor.link.numba.cache
import
compile_numba_function_src
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
(
create_tuple_string
,
numba_funcify
,
register_funcify_and_cache_key
,
register_funcify_default_op_cache_key
,
)
from
pytensor.link.utils
import
compile_function_src
,
unique_name_generator
from
pytensor.link.utils
import
unique_name_generator
from
pytensor.tensor.basic
import
(
Alloc
,
AllocEmpty
,
...
...
@@ -23,7 +26,7 @@ from pytensor.tensor.basic import (
)
@
numba_funcify.register
(
AllocEmpty
)
@
register_funcify_default_op_cache_key
(
AllocEmpty
)
def
numba_funcify_AllocEmpty
(
op
,
node
,
**
kwargs
):
global_env
=
{
"np"
:
np
,
...
...
@@ -52,14 +55,14 @@ def allocempty({", ".join(shape_var_names)}):
return np.empty(scalar_shape, dtype)
"""
alloc_fn
=
compile_function_src
(
alloc_fn
=
compile_
numba_
function_src
(
alloc_def_src
,
"allocempty"
,
{
**
globals
(),
**
global_env
}
)
return
numba_basic
.
numba_njit
(
alloc_fn
)
@
numba_funcify.register
(
Alloc
)
@
register_funcify_and_cache_key
(
Alloc
)
def
numba_funcify_Alloc
(
op
,
node
,
**
kwargs
):
global_env
=
{
"np"
:
np
}
...
...
@@ -96,16 +99,23 @@ def alloc(val, {", ".join(shape_var_names)}):
res[...] = val
return res
"""
alloc_fn
=
compile_function_src
(
alloc_def_src
,
"alloc"
,
{
**
globals
(),
**
global_env
})
alloc_fn
=
compile_numba_function_src
(
alloc_def_src
,
"alloc"
,
{
**
globals
(),
**
global_env
},
)
return
numba_basic
.
numba_njit
(
alloc_fn
)
cache_key
=
sha256
(
str
((
type
(
op
),
node
.
inputs
[
0
]
.
type
.
broadcastable
))
.
encode
()
)
.
hexdigest
()
return
numba_basic
.
numba_njit
(
alloc_fn
),
cache_key
@
numba_funcify.register
(
ARange
)
@
register_funcify_default_op_cache_key
(
ARange
)
def
numba_funcify_ARange
(
op
,
**
kwargs
):
dtype
=
np
.
dtype
(
op
.
dtype
)
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
arange
(
start
,
stop
,
step
):
return
np
.
arange
(
start
.
item
(),
...
...
@@ -117,7 +127,7 @@ def numba_funcify_ARange(op, **kwargs):
return
arange
@
numba_funcify.register
(
Join
)
@
register_funcify_default_op_cache_key
(
Join
)
def
numba_funcify_Join
(
op
,
**
kwargs
):
@numba_basic.numba_njit
def
join
(
axis
,
*
tensors
):
...
...
@@ -126,7 +136,7 @@ def numba_funcify_Join(op, **kwargs):
return
join
@
numba_funcify.register
(
Split
)
@
register_funcify_default_op_cache_key
(
Split
)
def
numba_funcify_Split
(
op
,
**
kwargs
):
@numba_basic.numba_njit
def
split
(
tensor
,
axis
,
indices
):
...
...
@@ -135,14 +145,14 @@ def numba_funcify_Split(op, **kwargs):
return
split
@
numba_funcify.register
(
ExtractDiag
)
@
register_funcify_default_op_cache_key
(
ExtractDiag
)
def
numba_funcify_ExtractDiag
(
op
,
node
,
**
kwargs
):
view
=
op
.
view
axis1
,
axis2
,
offset
=
op
.
axis1
,
op
.
axis2
,
op
.
offset
if
node
.
inputs
[
0
]
.
type
.
ndim
==
2
:
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
extract_diag
(
x
):
out
=
np
.
diag
(
x
,
k
=
offset
)
...
...
@@ -157,7 +167,7 @@ def numba_funcify_ExtractDiag(op, node, **kwargs):
leading_dims
=
(
slice
(
None
),)
*
axis1
middle_dims
=
(
slice
(
None
),)
*
(
axis2
-
axis1
-
1
)
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
extract_diag
(
x
):
if
offset
>=
0
:
diag_len
=
min
(
x
.
shape
[
axis1
],
max
(
0
,
x
.
shape
[
axis2
]
-
offset
))
...
...
@@ -178,11 +188,11 @@ def numba_funcify_ExtractDiag(op, node, **kwargs):
return
extract_diag
@
numba_funcify.register
(
Eye
)
@
register_funcify_default_op_cache_key
(
Eye
)
def
numba_funcify_Eye
(
op
,
**
kwargs
):
dtype
=
np
.
dtype
(
op
.
dtype
)
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
eye
(
N
,
M
,
k
):
return
np
.
eye
(
N
.
item
(),
...
...
@@ -194,7 +204,7 @@ def numba_funcify_Eye(op, **kwargs):
return
eye
@
numba_funcify.register
(
MakeVector
)
@
register_funcify_default_op_cache_key
(
MakeVector
)
def
numba_funcify_MakeVector
(
op
,
node
,
**
kwargs
):
dtype
=
np
.
dtype
(
op
.
dtype
)
...
...
@@ -215,32 +225,34 @@ def makevector({", ".join(input_names)}):
return np.array({create_list_string(input_names)}, dtype=dtype)
"""
makevector_fn
=
compile_function_src
(
makevector_def_src
,
"makevector"
,
{
**
globals
(),
**
global_env
}
makevector_fn
=
compile_numba_function_src
(
makevector_def_src
,
"makevector"
,
{
**
globals
(),
**
global_env
},
)
return
numba_basic
.
numba_njit
(
makevector_fn
)
@
numba_funcify.register
(
TensorFromScalar
)
@
register_funcify_default_op_cache_key
(
TensorFromScalar
)
def
numba_funcify_TensorFromScalar
(
op
,
**
kwargs
):
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
tensor_from_scalar
(
x
):
return
np
.
array
(
x
)
return
tensor_from_scalar
@
numba_funcify.register
(
ScalarFromTensor
)
@
register_funcify_default_op_cache_key
(
ScalarFromTensor
)
def
numba_funcify_ScalarFromTensor
(
op
,
**
kwargs
):
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
scalar_from_tensor
(
x
):
return
x
.
item
()
return
scalar_from_tensor
@
numba_funcify.register
(
Nonzero
)
@
register_funcify_default_op_cache_key
(
Nonzero
)
def
numba_funcify_Nonzero
(
op
,
node
,
**
kwargs
):
@numba_basic.numba_njit
def
nonzero
(
a
):
...
...
pytensor/link/numba/dispatch/vectorize_codegen.py
浏览文件 @
6ac5ab28
...
...
@@ -4,7 +4,7 @@ import base64
import
pickle
from
collections.abc
import
Callable
,
Sequence
from
textwrap
import
indent
from
typing
import
Any
,
cast
from
typing
import
Any
import
numba
import
numpy
as
np
...
...
@@ -15,8 +15,8 @@ from numba.core.base import BaseContext
from
numba.core.types.misc
import
NoneType
from
numba.np
import
arrayobj
from
pytensor.link.numba.cache
import
compile_numba_function_src
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.utils
import
compile_function_src
def
encode_literals
(
literals
:
Sequence
)
->
str
:
...
...
@@ -52,10 +52,13 @@ def store_core_outputs({inp_signature}, {out_signature}):
{indent(store_outputs, " " * 4)}
"""
global_env
=
{
"core_op_fn"
:
core_op_fn
}
func
=
compile_function_src
(
func_src
,
"store_core_outputs"
,
{
**
globals
(),
**
global_env
}
func
=
compile_numba_function_src
(
func_src
,
"store_core_outputs"
,
{
**
globals
(),
**
global_env
},
)
return
cast
(
Callable
,
numba_basic
.
numba_njit
(
func
)
)
return
numba_basic
.
numba_njit
(
func
)
_jit_options
=
{
...
...
@@ -74,7 +77,7 @@ _jit_options = {
@numba.extending.intrinsic
(
jit_options
=
_jit_options
,
prefer_literal
=
True
)
def
_vectorized
(
typingctx
,
scalar
_func
,
core
_func
,
input_bc_patterns
,
output_bc_patterns
,
output_dtypes
,
...
...
@@ -85,7 +88,7 @@ def _vectorized(
size_type
,
):
arg_types
=
[
scalar
_func
,
core
_func
,
input_bc_patterns
,
output_bc_patterns
,
output_dtypes
,
...
...
@@ -173,16 +176,6 @@ def _vectorized(
)
out_types
[
output_idx
]
=
output_type
core_signature
=
typingctx
.
resolve_function_type
(
scalar_func
,
[
*
constant_inputs_types
,
*
core_input_types
,
*
core_out_types
,
],
{},
)
ret_type
=
types
.
Tuple
(
out_types
)
if
len
(
output_dtypes
)
==
1
:
...
...
@@ -239,11 +232,21 @@ def _vectorized(
output_core_shapes
,
)
core_signature
=
typingctx
.
resolve_function_type
(
core_func
,
[
*
constant_inputs_types
,
*
core_input_types
,
*
core_out_types
,
],
{},
)
make_loop_call
(
typingctx
,
ctx
,
builder
,
scalar
_func
,
core
_func
,
core_signature
,
iter_shape
,
constant_inputs
,
...
...
@@ -416,8 +419,8 @@ def make_loop_call(
typingctx
,
context
:
numba
.
core
.
base
.
BaseContext
,
builder
:
ir
.
IRBuilder
,
scalar
_func
:
Any
,
scalar
_signature
:
types
.
FunctionType
,
core
_func
:
Any
,
core
_signature
:
types
.
FunctionType
,
iter_shape
:
tuple
[
ir
.
Instruction
,
...
],
constant_inputs
:
tuple
[
ir
.
Instruction
,
...
],
inputs
:
tuple
[
ir
.
Instruction
,
...
],
...
...
@@ -557,10 +560,10 @@ def make_loop_call(
val
=
core_array
.
_getvalue
()
output_slices
.
append
(
val
)
inner_codegen
=
context
.
get_function
(
scalar_func
,
scalar
_signature
)
inner_codegen
=
context
.
get_function
(
core_func
,
core
_signature
)
if
isinstance
(
scalar
_signature
.
args
[
0
],
types
.
StarArgTuple
|
types
.
StarArgUniTuple
):
input_vals
=
[
context
.
make_tuple
(
builder
,
scalar
_signature
.
args
[
0
],
input_vals
)]
if
isinstance
(
core
_signature
.
args
[
0
],
types
.
StarArgTuple
|
types
.
StarArgUniTuple
):
input_vals
=
[
context
.
make_tuple
(
builder
,
core
_signature
.
args
[
0
],
input_vals
)]
inner_codegen
(
builder
,
[
*
constant_inputs
,
*
input_vals
,
*
output_slices
])
...
...
tests/link/numba/signal/test_conv.py
浏览文件 @
6ac5ab28
...
...
@@ -13,21 +13,32 @@ from tests.tensor.signal.test_conv import convolve1d_grad_benchmarker
pytestmark
=
pytest
.
mark
.
filterwarnings
(
"error"
)
@pytest.mark.parametrize
(
"bcast_order"
,
(
1
,
0
))
@pytest.mark.parametrize
(
"mode"
,
[
"full"
,
"valid"
,
"same"
])
@pytest.mark.parametrize
(
"x_smaller"
,
(
False
,
True
))
def
test_convolve1d
(
x_smaller
,
mode
):
def
test_convolve1d
(
mode
,
bcast_order
):
x
=
dmatrix
(
"x"
)
y
=
dmatrix
(
"y"
)
if
x_smaller
:
out
=
convolve1d
(
x
[
None
],
y
[:,
None
],
mode
=
mode
)
# Testing two orders because this revealed a bug in the past
if
bcast_order
==
0
:
out
=
convolve1d
(
x
[:,
None
],
y
[
None
,
:],
mode
=
mode
)
else
:
out
=
convolve1d
(
y
[:,
None
],
x
[
None
],
mode
=
mode
)
out
=
convolve1d
(
x
[
None
],
y
[:,
None
],
mode
=
mode
)
rng
=
np
.
random
.
default_rng
()
test_x
=
rng
.
normal
(
size
=
(
3
,
5
))
test_y
=
rng
.
normal
(
size
=
(
7
,
11
))
# Blockwise dispatch for numba can't be run on object mode
compare_numba_and_py
([
x
,
y
],
out
,
[
test_x
,
test_y
],
eval_obj_mode
=
False
)
numba_fn
,
res
=
compare_numba_and_py
(
[
x
,
y
],
out
,
[
test_x
,
test_y
],
eval_obj_mode
=
False
)
# Try other order of inputs, as implementation depends on it
# Result should be the same, just in different order, except for 'same' mode
if
mode
!=
"same"
:
np
.
testing
.
assert_allclose
(
np
.
swapaxes
(
numba_fn
(
test_y
,
test_x
),
0
,
1
),
res
,
)
@pytest.mark.parametrize
(
"mode"
,
(
"full"
,
"valid"
),
ids
=
lambda
x
:
f
"mode={x}"
)
...
...
tests/link/numba/test_basic.py
浏览文件 @
6ac5ab28
...
...
@@ -402,7 +402,9 @@ def test_config_options_fastmath():
with
config
.
change_flags
(
numba__fastmath
=
True
):
pytensor_numba_fn
=
function
([
x
],
pt
.
sum
(
x
),
mode
=
numba_mode
)
numba_sum_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
"impl_sum"
]
numba_sum_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
"jitable_func"
]
.
py_func
.
__globals__
[
"impl_sum"
]
assert
numba_sum_fn
.
targetoptions
[
"fastmath"
]
==
{
"afn"
,
"arcp"
,
...
...
@@ -413,7 +415,9 @@ def test_config_options_fastmath():
with
config
.
change_flags
(
numba__fastmath
=
False
):
pytensor_numba_fn
=
function
([
x
],
pt
.
sum
(
x
),
mode
=
numba_mode
)
numba_sum_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
"impl_sum"
]
numba_sum_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
"jitable_func"
]
.
py_func
.
__globals__
[
"impl_sum"
]
assert
numba_sum_fn
.
targetoptions
[
"fastmath"
]
is
False
...
...
@@ -422,9 +426,10 @@ def test_config_options_cached():
with
config
.
change_flags
(
numba__cache
=
True
):
pytensor_numba_fn
=
function
([
x
],
pt
.
sum
(
x
),
mode
=
numba_mode
)
numba_sum_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
"impl_sum"
]
# Caching is disabled unless the dispatched function returns an explicit cache key
assert
isinstance
(
numba_sum_fn
.
_cache
,
numba
.
core
.
caching
.
NullCache
)
numba_sum_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
"jitable_func"
]
.
py_func
.
__globals__
[
"impl_sum"
]
assert
not
isinstance
(
numba_sum_fn
.
_cache
,
numba
.
core
.
caching
.
NullCache
)
with
config
.
change_flags
(
numba__cache
=
False
):
pytensor_numba_fn
=
function
([
x
],
pt
.
sum
(
x
),
mode
=
numba_mode
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论