Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
6ac5ab28
提交
6ac5ab28
authored
11月 03, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
11月 16, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Cache keys for numba Op dispatches
上级
74ab0383
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
17 个修改的文件
包含
419 行增加
和
240 行删除
+419
-240
blockwise.py
pytensor/link/numba/dispatch/blockwise.py
+51
-28
compile_ops.py
pytensor/link/numba/dispatch/compile_ops.py
+30
-9
elemwise.py
pytensor/link/numba/dispatch/elemwise.py
+0
-0
extra_ops.py
pytensor/link/numba/dispatch/extra_ops.py
+21
-15
nlinalg.py
pytensor/link/numba/dispatch/nlinalg.py
+13
-13
random.py
pytensor/link/numba/dispatch/random.py
+44
-23
scalar.py
pytensor/link/numba/dispatch/scalar.py
+59
-51
scan.py
pytensor/link/numba/dispatch/scan.py
+20
-6
shape.py
pytensor/link/numba/dispatch/shape.py
+9
-7
conv.py
pytensor/link/numba/dispatch/signal/conv.py
+4
-2
slinalg.py
pytensor/link/numba/dispatch/slinalg.py
+8
-6
sort.py
pytensor/link/numba/dispatch/sort.py
+5
-3
subtensor.py
pytensor/link/numba/dispatch/subtensor.py
+66
-19
tensor_basic.py
pytensor/link/numba/dispatch/tensor_basic.py
+36
-24
vectorize_codegen.py
pytensor/link/numba/dispatch/vectorize_codegen.py
+26
-23
test_conv.py
tests/link/numba/signal/test_conv.py
+17
-6
test_basic.py
tests/link/numba/test_basic.py
+10
-5
没有找到文件。
pytensor/link/numba/dispatch/blockwise.py
浏览文件 @
6ac5ab28
import
sys
from
hashlib
import
sha256
from
typing
import
cast
from
typing
import
cast
from
numba.core.extending
import
overload
from
numba.core.extending
import
overload
from
numba.np.unsafe.ndarray
import
to_fixed_tuple
from
numba.np.unsafe.ndarray
import
to_fixed_tuple
from
pytensor.link.numba.cache
import
compile_numba_function_src
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
(
numba_funcify_and_cache_key
,
register_funcify_and_cache_key
,
)
from
pytensor.link.numba.dispatch.vectorize_codegen
import
(
from
pytensor.link.numba.dispatch.vectorize_codegen
import
(
_jit_options
,
_jit_options
,
_vectorized
,
_vectorized
,
encode_literals
,
encode_literals
,
store_core_outputs
,
store_core_outputs
,
)
)
from
pytensor.link.utils
import
compile_function_src
from
pytensor.tensor
import
TensorVariable
,
get_vector_length
from
pytensor.tensor
import
TensorVariable
,
get_vector_length
from
pytensor.tensor.blockwise
import
Blockwise
,
BlockwiseWithCoreShape
from
pytensor.tensor.blockwise
import
Blockwise
,
BlockwiseWithCoreShape
@
numba_funcify.register
(
BlockwiseWithCoreShape
)
@
register_funcify_and_cache_key
(
BlockwiseWithCoreShape
)
def
numba_funcify_Blockwise
(
op
:
BlockwiseWithCoreShape
,
node
,
**
kwargs
):
def
numba_funcify_Blockwise
(
op
:
BlockwiseWithCoreShape
,
node
,
**
kwargs
):
[
blockwise_node
]
=
op
.
fgraph
.
apply_nodes
[
blockwise_node
]
=
op
.
fgraph
.
apply_nodes
blockwise_op
:
Blockwise
=
blockwise_node
.
op
blockwise_op
:
Blockwise
=
blockwise_node
.
op
...
@@ -30,7 +33,7 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
...
@@ -30,7 +33,7 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
cast
(
tuple
[
TensorVariable
],
node
.
inputs
[:
nin
]),
cast
(
tuple
[
TensorVariable
],
node
.
inputs
[:
nin
]),
propagate_unbatched_core_inputs
=
True
,
propagate_unbatched_core_inputs
=
True
,
)
)
core_op_fn
=
numba_funcif
y
(
core_op_fn
,
core_op_key
=
numba_funcify_and_cache_ke
y
(
core_op
,
core_op
,
node
=
core_node
,
node
=
core_node
,
parent_node
=
node
,
parent_node
=
node
,
...
@@ -58,36 +61,56 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
...
@@ -58,36 +61,56 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
src
+=
")"
src
+=
")"
to_tuple
=
numba_basic
.
numba_njit
(
to_tuple
=
numba_basic
.
numba_njit
(
compile_function_src
(
compile_
numba_
function_src
(
src
,
src
,
"to_tuple"
,
"to_tuple"
,
global_env
=
{
"to_fixed_tuple"
:
to_fixed_tuple
},
global_env
=
{
"to_fixed_tuple"
:
to_fixed_tuple
},
),
# cache=True leads to a numba.cloudpickle dump failure in Python 3.10
# May be fine in Python 3.11, but I didn't test. It was fine in 3.12
cache
=
sys
.
version_info
>=
(
3
,
12
),
)
def
blockwise_wrapper
(
*
inputs_and_core_shapes
):
inputs
,
core_shapes
=
inputs_and_core_shapes
[:
nin
],
inputs_and_core_shapes
[
nin
:]
tuple_core_shapes
=
to_tuple
(
core_shapes
)
return
_vectorized
(
core_op_fn
,
input_bc_patterns
,
output_bc_patterns
,
output_dtypes
,
inplace_pattern
,
(),
# constant_inputs
inputs
,
tuple_core_shapes
,
None
,
# size
)
)
)
def
blockwise
(
*
inputs_and_core_shapes
):
def
blockwise
(
*
inputs_and_core_shapes
):
raise
NotImplementedError
(
"Non-jitted BlockwiseWithCoreShape not implemented"
)
raise
NotImplementedError
(
"Numba implementation of Blockwise cannot be evaluated in Python (non-JIT) mode."
)
@overload
(
blockwise
,
jit_options
=
_jit_options
)
@overload
(
blockwise
,
jit_options
=
_jit_options
)
def
ov_blockwise
(
*
inputs_and_core_shapes
):
def
ov_blockwise
(
*
inputs_and_core_shapes
):
return
blockwise_wrapper
def
impl
(
*
inputs_and_core_shapes
):
inputs
,
core_shapes
=
(
inputs_and_core_shapes
[:
nin
],
inputs_and_core_shapes
[
nin
:],
)
tuple_core_shapes
=
to_tuple
(
core_shapes
)
return
_vectorized
(
core_op_fn
,
input_bc_patterns
,
output_bc_patterns
,
output_dtypes
,
inplace_pattern
,
(),
# constant_inputs
inputs
,
tuple_core_shapes
,
None
,
# size
)
return
impl
return
blockwise
if
core_op_key
is
None
:
# If the core op cannot be cached, the Blockwise wrapper cannot be cached either
blockwise_key
=
None
else
:
blockwise_key
=
"_"
.
join
(
map
(
str
,
(
type
(
op
),
type
(
blockwise_op
),
tuple
(
blockwise_op
.
destroy_map
.
items
()),
blockwise_op
.
signature
,
input_bc_patterns
,
core_op_key
,
),
)
)
blockwise_key
=
sha256
(
blockwise_key
.
encode
())
.
hexdigest
()
return
blockwise
,
blockwise_key
pytensor/link/numba/dispatch/compile_ops.py
浏览文件 @
6ac5ab28
from
hashlib
import
sha256
import
numpy
as
np
import
numpy
as
np
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.compile.builders
import
OpFromGraph
...
@@ -8,14 +10,15 @@ from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
...
@@ -8,14 +10,15 @@ from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
from
pytensor.ifelse
import
IfElse
from
pytensor.ifelse
import
IfElse
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
(
from
pytensor.link.numba.dispatch.basic
import
(
numba_funcify
,
numba_funcify_and_cache_key
,
numba_njit
,
register_funcify_and_cache_key
,
register_funcify_default_op_cache_key
,
)
)
from
pytensor.raise_op
import
CheckAndRaise
from
pytensor.raise_op
import
CheckAndRaise
from
pytensor.tensor.type
import
TensorType
from
pytensor.tensor.type
import
TensorType
@
numba_funcify.register
(
OpFromGraph
)
@
register_funcify_and_cache_key
(
OpFromGraph
)
def
numba_funcify_OpFromGraph
(
op
,
node
=
None
,
**
kwargs
):
def
numba_funcify_OpFromGraph
(
op
,
node
=
None
,
**
kwargs
):
_
=
kwargs
.
pop
(
"storage_map"
,
None
)
_
=
kwargs
.
pop
(
"storage_map"
,
None
)
...
@@ -30,10 +33,27 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
...
@@ -30,10 +33,27 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
accept_inplace
=
True
,
accept_inplace
=
True
,
)
)
NUMBA
.
optimizer
(
fgraph
)
NUMBA
.
optimizer
(
fgraph
)
return
numba_funcify
(
op
.
fgraph
,
squeeze_output
=
True
,
**
kwargs
)
fgraph_fn
,
fgraph_cache_key
=
numba_funcify_and_cache_key
(
op
.
fgraph
,
squeeze_output
=
True
,
**
kwargs
)
if
fgraph_cache_key
is
None
:
# Can't cache the inner graph
ofg_cache_key
=
None
else
:
ofg_cache_key
=
sha256
(
str
(
(
type
(
op
),
fgraph_cache_key
,
)
)
.
encode
()
)
.
hexdigest
()
return
fgraph_fn
,
ofg_cache_key
@
numba_funcify.register
(
TypeCastingOp
)
@
register_funcify_default_op_cache_key
(
TypeCastingOp
)
def
numba_funcify_type_casting
(
op
,
**
kwargs
):
def
numba_funcify_type_casting
(
op
,
**
kwargs
):
@numba_basic.numba_njit
@numba_basic.numba_njit
def
identity
(
x
):
def
identity
(
x
):
...
@@ -42,7 +62,7 @@ def numba_funcify_type_casting(op, **kwargs):
...
@@ -42,7 +62,7 @@ def numba_funcify_type_casting(op, **kwargs):
return
identity
return
identity
@
numba_funcify.register
(
DeepCopyOp
)
@
register_funcify_default_op_cache_key
(
DeepCopyOp
)
def
numba_funcify_DeepCopyOp
(
op
,
node
,
**
kwargs
):
def
numba_funcify_DeepCopyOp
(
op
,
node
,
**
kwargs
):
if
isinstance
(
node
.
inputs
[
0
]
.
type
,
TensorType
):
if
isinstance
(
node
.
inputs
[
0
]
.
type
,
TensorType
):
...
@@ -59,7 +79,7 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
...
@@ -59,7 +79,7 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
return
deepcopy
return
deepcopy
@
numba_funcify.register
(
IfElse
)
@
register_funcify_default_op_cache_key
(
IfElse
)
def
numba_funcify_IfElse
(
op
,
**
kwargs
):
def
numba_funcify_IfElse
(
op
,
**
kwargs
):
n_outs
=
op
.
n_outs
n_outs
=
op
.
n_outs
...
@@ -88,7 +108,7 @@ def numba_funcify_IfElse(op, **kwargs):
...
@@ -88,7 +108,7 @@ def numba_funcify_IfElse(op, **kwargs):
return
ifelse
return
ifelse
@
numba_funcify.register
(
CheckAndRaise
)
@
register_funcify_and_cache_key
(
CheckAndRaise
)
def
numba_funcify_CheckAndRaise
(
op
,
node
,
**
kwargs
):
def
numba_funcify_CheckAndRaise
(
op
,
node
,
**
kwargs
):
error
=
op
.
exc_type
error
=
op
.
exc_type
msg
=
op
.
msg
msg
=
op
.
msg
...
@@ -100,4 +120,5 @@ def numba_funcify_CheckAndRaise(op, node, **kwargs):
...
@@ -100,4 +120,5 @@ def numba_funcify_CheckAndRaise(op, node, **kwargs):
raise
error
(
msg
)
raise
error
(
msg
)
return
x
return
x
return
check_and_raise
cache_key
=
sha256
(
str
((
type
(
op
),
error
,
msg
))
.
encode
())
.
hexdigest
()
return
check_and_raise
,
cache_key
pytensor/link/numba/dispatch/elemwise.py
浏览文件 @
6ac5ab28
差异被折叠。
点击展开。
pytensor/link/numba/dispatch/extra_ops.py
浏览文件 @
6ac5ab28
import
warnings
import
warnings
from
hashlib
import
sha256
from
typing
import
cast
from
typing
import
cast
import
numba
import
numba
...
@@ -9,7 +10,8 @@ from pytensor.link.numba.dispatch import basic as numba_basic
...
@@ -9,7 +10,8 @@ from pytensor.link.numba.dispatch import basic as numba_basic
from
pytensor.link.numba.dispatch.basic
import
(
from
pytensor.link.numba.dispatch.basic
import
(
generate_fallback_impl
,
generate_fallback_impl
,
get_numba_type
,
get_numba_type
,
numba_funcify
,
register_funcify_and_cache_key
,
register_funcify_default_op_cache_key
,
)
)
from
pytensor.tensor
import
TensorVariable
from
pytensor.tensor
import
TensorVariable
from
pytensor.tensor.extra_ops
import
(
from
pytensor.tensor.extra_ops
import
(
...
@@ -25,16 +27,16 @@ from pytensor.tensor.extra_ops import (
...
@@ -25,16 +27,16 @@ from pytensor.tensor.extra_ops import (
)
)
@
numba_funcify.register
(
Bartlett
)
@
register_funcify_default_op_cache_key
(
Bartlett
)
def
numba_funcify_Bartlett
(
op
,
**
kwargs
):
def
numba_funcify_Bartlett
(
op
,
**
kwargs
):
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
bartlett
(
x
):
def
bartlett
(
x
):
return
np
.
bartlett
(
x
.
item
())
return
np
.
bartlett
(
x
.
item
())
return
bartlett
return
bartlett
@
numba_funcify.register
(
CumOp
)
@
register_funcify_default_op_cache_key
(
CumOp
)
def
numba_funcify_CumOp
(
op
:
CumOp
,
node
:
Apply
,
**
kwargs
):
def
numba_funcify_CumOp
(
op
:
CumOp
,
node
:
Apply
,
**
kwargs
):
axis
=
op
.
axis
axis
=
op
.
axis
mode
=
op
.
mode
mode
=
op
.
mode
...
@@ -94,7 +96,7 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
...
@@ -94,7 +96,7 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
return
cumop
return
cumop
@
numba_funcify.register
(
FillDiagonal
)
@
register_funcify_default_op_cache_key
(
FillDiagonal
)
def
numba_funcify_FillDiagonal
(
op
,
**
kwargs
):
def
numba_funcify_FillDiagonal
(
op
,
**
kwargs
):
@numba_basic.numba_njit
@numba_basic.numba_njit
def
filldiagonal
(
a
,
val
):
def
filldiagonal
(
a
,
val
):
...
@@ -104,7 +106,7 @@ def numba_funcify_FillDiagonal(op, **kwargs):
...
@@ -104,7 +106,7 @@ def numba_funcify_FillDiagonal(op, **kwargs):
return
filldiagonal
return
filldiagonal
@
numba_funcify.register
(
FillDiagonalOffset
)
@
register_funcify_default_op_cache_key
(
FillDiagonalOffset
)
def
numba_funcify_FillDiagonalOffset
(
op
,
node
,
**
kwargs
):
def
numba_funcify_FillDiagonalOffset
(
op
,
node
,
**
kwargs
):
@numba_basic.numba_njit
@numba_basic.numba_njit
def
filldiagonaloffset
(
a
,
val
,
offset
):
def
filldiagonaloffset
(
a
,
val
,
offset
):
...
@@ -129,7 +131,7 @@ def numba_funcify_FillDiagonalOffset(op, node, **kwargs):
...
@@ -129,7 +131,7 @@ def numba_funcify_FillDiagonalOffset(op, node, **kwargs):
return
filldiagonaloffset
return
filldiagonaloffset
@
numba_funcify.register
(
RavelMultiIndex
)
@
register_funcify_default_op_cache_key
(
RavelMultiIndex
)
def
numba_funcify_RavelMultiIndex
(
op
,
node
,
**
kwargs
):
def
numba_funcify_RavelMultiIndex
(
op
,
node
,
**
kwargs
):
mode
=
op
.
mode
mode
=
op
.
mode
order
=
op
.
order
order
=
op
.
order
...
@@ -194,7 +196,7 @@ def numba_funcify_RavelMultiIndex(op, node, **kwargs):
...
@@ -194,7 +196,7 @@ def numba_funcify_RavelMultiIndex(op, node, **kwargs):
return
ravelmultiindex
return
ravelmultiindex
@
numba_funcify.register
(
Repeat
)
@
register_funcify_default_op_cache_key
(
Repeat
)
def
numba_funcify_Repeat
(
op
,
node
,
**
kwargs
):
def
numba_funcify_Repeat
(
op
,
node
,
**
kwargs
):
axis
=
op
.
axis
axis
=
op
.
axis
a
,
_
=
node
.
inputs
a
,
_
=
node
.
inputs
...
@@ -202,7 +204,7 @@ def numba_funcify_Repeat(op, node, **kwargs):
...
@@ -202,7 +204,7 @@ def numba_funcify_Repeat(op, node, **kwargs):
# Numba only supports axis=None, which in our case is when axis is 0 and the input is a vector
# Numba only supports axis=None, which in our case is when axis is 0 and the input is a vector
if
axis
==
0
and
a
.
type
.
ndim
==
1
:
if
axis
==
0
and
a
.
type
.
ndim
==
1
:
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
repeatop
(
x
,
repeats
):
def
repeatop
(
x
,
repeats
):
return
np
.
repeat
(
x
,
repeats
)
return
np
.
repeat
(
x
,
repeats
)
...
@@ -212,7 +214,7 @@ def numba_funcify_Repeat(op, node, **kwargs):
...
@@ -212,7 +214,7 @@ def numba_funcify_Repeat(op, node, **kwargs):
return
generate_fallback_impl
(
op
,
node
)
return
generate_fallback_impl
(
op
,
node
)
@
numba_funcify.register
(
Unique
)
@
register_funcify_default_op_cache_key
(
Unique
)
def
numba_funcify_Unique
(
op
,
node
,
**
kwargs
):
def
numba_funcify_Unique
(
op
,
node
,
**
kwargs
):
axis
=
op
.
axis
axis
=
op
.
axis
...
@@ -230,7 +232,7 @@ def numba_funcify_Unique(op, node, **kwargs):
...
@@ -230,7 +232,7 @@ def numba_funcify_Unique(op, node, **kwargs):
if
not
use_python
:
if
not
use_python
:
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
unique
(
x
):
def
unique
(
x
):
return
np
.
unique
(
x
)
return
np
.
unique
(
x
)
...
@@ -257,7 +259,7 @@ def numba_funcify_Unique(op, node, **kwargs):
...
@@ -257,7 +259,7 @@ def numba_funcify_Unique(op, node, **kwargs):
return
unique
return
unique
@
numba_funcify.register
(
UnravelIndex
)
@
register_funcify_and_cache_key
(
UnravelIndex
)
def
numba_funcify_UnravelIndex
(
op
,
node
,
**
kwargs
):
def
numba_funcify_UnravelIndex
(
op
,
node
,
**
kwargs
):
order
=
op
.
order
order
=
op
.
order
...
@@ -289,10 +291,14 @@ def numba_funcify_UnravelIndex(op, node, **kwargs):
...
@@ -289,10 +291,14 @@ def numba_funcify_UnravelIndex(op, node, **kwargs):
# unpacked into a `tuple`, so this discrepancy shouldn't really matter
# unpacked into a `tuple`, so this discrepancy shouldn't really matter
return
((
maybe_expand_dim
(
arr
)
//
a
)
%
shape
)
.
T
return
((
maybe_expand_dim
(
arr
)
//
a
)
%
shape
)
.
T
return
unravelindex
cache_key
=
sha256
(
str
((
type
(
op
),
op
.
order
,
len
(
node
.
outputs
)))
.
encode
()
)
.
hexdigest
()
return
unravelindex
,
cache_key
@
numba_funcify.register
(
SearchsortedOp
)
@
register_funcify_default_op_cache_key
(
SearchsortedOp
)
def
numba_funcify_Searchsorted
(
op
,
node
,
**
kwargs
):
def
numba_funcify_Searchsorted
(
op
,
node
,
**
kwargs
):
side
=
op
.
side
side
=
op
.
side
...
@@ -319,7 +325,7 @@ def numba_funcify_Searchsorted(op, node, **kwargs):
...
@@ -319,7 +325,7 @@ def numba_funcify_Searchsorted(op, node, **kwargs):
else
:
else
:
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
searchsorted
(
a
,
v
):
def
searchsorted
(
a
,
v
):
return
np
.
searchsorted
(
a
,
v
,
side
)
return
np
.
searchsorted
(
a
,
v
,
side
)
...
...
pytensor/link/numba/dispatch/nlinalg.py
浏览文件 @
6ac5ab28
...
@@ -3,11 +3,11 @@ import warnings
...
@@ -3,11 +3,11 @@ import warnings
import
numba
import
numba
import
numpy
as
np
import
numpy
as
np
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
import
pytensor.link.numba.dispatch.
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
(
from
pytensor.link.numba.dispatch.basic
import
(
get_numba_type
,
get_numba_type
,
int_to_float_fn
,
int_to_float_fn
,
numba_funcif
y
,
register_funcify_default_op_cache_ke
y
,
)
)
from
pytensor.tensor.nlinalg
import
(
from
pytensor.tensor.nlinalg
import
(
SVD
,
SVD
,
...
@@ -20,7 +20,7 @@ from pytensor.tensor.nlinalg import (
...
@@ -20,7 +20,7 @@ from pytensor.tensor.nlinalg import (
)
)
@
numba_funcify.register
(
SVD
)
@
register_funcify_default_op_cache_key
(
SVD
)
def
numba_funcify_SVD
(
op
,
node
,
**
kwargs
):
def
numba_funcify_SVD
(
op
,
node
,
**
kwargs
):
full_matrices
=
op
.
full_matrices
full_matrices
=
op
.
full_matrices
compute_uv
=
op
.
compute_uv
compute_uv
=
op
.
compute_uv
...
@@ -44,19 +44,19 @@ def numba_funcify_SVD(op, node, **kwargs):
...
@@ -44,19 +44,19 @@ def numba_funcify_SVD(op, node, **kwargs):
return
svd
return
svd
@
numba_funcify.register
(
Det
)
@
register_funcify_default_op_cache_key
(
Det
)
def
numba_funcify_Det
(
op
,
node
,
**
kwargs
):
def
numba_funcify_Det
(
op
,
node
,
**
kwargs
):
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
inputs_cast
=
int_to_float_fn
(
node
.
inputs
,
out_dtype
)
inputs_cast
=
int_to_float_fn
(
node
.
inputs
,
out_dtype
)
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
det
(
x
):
def
det
(
x
):
return
np
.
array
(
np
.
linalg
.
det
(
inputs_cast
(
x
)))
.
astype
(
out_dtype
)
return
np
.
array
(
np
.
linalg
.
det
(
inputs_cast
(
x
)))
.
astype
(
out_dtype
)
return
det
return
det
@
numba_funcify.register
(
SLogDet
)
@
register_funcify_default_op_cache_key
(
SLogDet
)
def
numba_funcify_SLogDet
(
op
,
node
,
**
kwargs
):
def
numba_funcify_SLogDet
(
op
,
node
,
**
kwargs
):
out_dtype_1
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
out_dtype_1
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
out_dtype_2
=
node
.
outputs
[
1
]
.
type
.
numpy_dtype
out_dtype_2
=
node
.
outputs
[
1
]
.
type
.
numpy_dtype
...
@@ -74,7 +74,7 @@ def numba_funcify_SLogDet(op, node, **kwargs):
...
@@ -74,7 +74,7 @@ def numba_funcify_SLogDet(op, node, **kwargs):
return
slogdet
return
slogdet
@
numba_funcify.register
(
Eig
)
@
register_funcify_default_op_cache_key
(
Eig
)
def
numba_funcify_Eig
(
op
,
node
,
**
kwargs
):
def
numba_funcify_Eig
(
op
,
node
,
**
kwargs
):
w_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
w_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
inputs_cast
=
int_to_float_fn
(
node
.
inputs
,
w_dtype
)
inputs_cast
=
int_to_float_fn
(
node
.
inputs
,
w_dtype
)
...
@@ -86,7 +86,7 @@ def numba_funcify_Eig(op, node, **kwargs):
...
@@ -86,7 +86,7 @@ def numba_funcify_Eig(op, node, **kwargs):
return
eig
return
eig
@
numba_funcify.register
(
Eigh
)
@
register_funcify_default_op_cache_key
(
Eigh
)
def
numba_funcify_Eigh
(
op
,
node
,
**
kwargs
):
def
numba_funcify_Eigh
(
op
,
node
,
**
kwargs
):
uplo
=
op
.
UPLO
uplo
=
op
.
UPLO
...
@@ -113,31 +113,31 @@ def numba_funcify_Eigh(op, node, **kwargs):
...
@@ -113,31 +113,31 @@ def numba_funcify_Eigh(op, node, **kwargs):
else
:
else
:
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
eigh
(
x
):
def
eigh
(
x
):
return
np
.
linalg
.
eigh
(
x
)
return
np
.
linalg
.
eigh
(
x
)
return
eigh
return
eigh
@
numba_funcify.register
(
MatrixInverse
)
@
register_funcify_default_op_cache_key
(
MatrixInverse
)
def
numba_funcify_MatrixInverse
(
op
,
node
,
**
kwargs
):
def
numba_funcify_MatrixInverse
(
op
,
node
,
**
kwargs
):
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
inputs_cast
=
int_to_float_fn
(
node
.
inputs
,
out_dtype
)
inputs_cast
=
int_to_float_fn
(
node
.
inputs
,
out_dtype
)
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
matrix_inverse
(
x
):
def
matrix_inverse
(
x
):
return
np
.
linalg
.
inv
(
inputs_cast
(
x
))
.
astype
(
out_dtype
)
return
np
.
linalg
.
inv
(
inputs_cast
(
x
))
.
astype
(
out_dtype
)
return
matrix_inverse
return
matrix_inverse
@
numba_funcify.register
(
MatrixPinv
)
@
register_funcify_default_op_cache_key
(
MatrixPinv
)
def
numba_funcify_MatrixPinv
(
op
,
node
,
**
kwargs
):
def
numba_funcify_MatrixPinv
(
op
,
node
,
**
kwargs
):
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
inputs_cast
=
int_to_float_fn
(
node
.
inputs
,
out_dtype
)
inputs_cast
=
int_to_float_fn
(
node
.
inputs
,
out_dtype
)
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
matrixpinv
(
x
):
def
matrixpinv
(
x
):
return
np
.
linalg
.
pinv
(
inputs_cast
(
x
))
.
astype
(
out_dtype
)
return
np
.
linalg
.
pinv
(
inputs_cast
(
x
))
.
astype
(
out_dtype
)
...
...
pytensor/link/numba/dispatch/random.py
浏览文件 @
6ac5ab28
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
copy
import
copy
,
deepcopy
from
copy
import
copy
,
deepcopy
from
functools
import
singledispatch
from
functools
import
singledispatch
from
hashlib
import
sha256
from
textwrap
import
dedent
from
textwrap
import
dedent
import
numba
import
numba
...
@@ -13,7 +14,11 @@ import pytensor.tensor.random.basic as ptr
...
@@ -13,7 +14,11 @@ import pytensor.tensor.random.basic as ptr
from
pytensor.graph
import
Apply
from
pytensor.graph
import
Apply
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
direct_cast
,
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
(
direct_cast
,
numba_funcify
,
register_funcify_and_cache_key
,
)
from
pytensor.link.numba.dispatch.vectorize_codegen
import
(
from
pytensor.link.numba.dispatch.vectorize_codegen
import
(
_jit_options
,
_jit_options
,
_vectorized
,
_vectorized
,
...
@@ -395,7 +400,7 @@ def numba_funcify_RandomVariable_core(op: RandomVariable, **kwargs):
...
@@ -395,7 +400,7 @@ def numba_funcify_RandomVariable_core(op: RandomVariable, **kwargs):
)
)
@
numba_funcify.register
@
register_funcify_and_cache_key
(
RandomVariableWithCoreShape
)
def
numba_funcify_RandomVariable
(
op
:
RandomVariableWithCoreShape
,
node
,
**
kwargs
):
def
numba_funcify_RandomVariable
(
op
:
RandomVariableWithCoreShape
,
node
,
**
kwargs
):
core_shape
=
node
.
inputs
[
0
]
core_shape
=
node
.
inputs
[
0
]
...
@@ -423,28 +428,44 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
...
@@ -423,28 +428,44 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
output_dtypes
=
encode_literals
((
rv_node
.
default_output
()
.
type
.
dtype
,))
output_dtypes
=
encode_literals
((
rv_node
.
default_output
()
.
type
.
dtype
,))
inplace_pattern
=
encode_literals
(())
inplace_pattern
=
encode_literals
(())
def
random_wrapper
(
core_shape
,
rng
,
size
,
*
dist_params
):
if
not
inplace
:
rng
=
copy
(
rng
)
draws
=
_vectorized
(
core_op_fn
,
input_bc_patterns
,
output_bc_patterns
,
output_dtypes
,
inplace_pattern
,
(
rng
,),
dist_params
,
(
numba_ndarray
.
to_fixed_tuple
(
core_shape
,
core_shape_len
),),
None
if
size_len
is
None
else
numba_ndarray
.
to_fixed_tuple
(
size
,
size_len
),
)
return
rng
,
draws
def
random
(
core_shape
,
rng
,
size
,
*
dist_params
):
def
random
(
core_shape
,
rng
,
size
,
*
dist_params
):
raise
NotImplementedError
(
"Non-jitted random variable not implemented"
)
raise
NotImplementedError
(
"Numba implementation of RandomVariable cannot be evaluated in Python (non-JIT) mode"
)
@overload
(
random
,
jit_options
=
_jit_options
)
@overload
(
random
,
jit_options
=
_jit_options
)
def
ov_random
(
core_shape
,
rng
,
size
,
*
dist_params
):
def
ov_random
(
core_shape
,
rng
,
size
,
*
dist_params
):
return
random_wrapper
def
impl
(
core_shape
,
rng
,
size
,
*
dist_params
):
if
not
inplace
:
return
random
rng
=
copy
(
rng
)
draws
=
_vectorized
(
core_op_fn
,
input_bc_patterns
,
output_bc_patterns
,
output_dtypes
,
inplace_pattern
,
(
rng
,),
dist_params
,
(
numba_ndarray
.
to_fixed_tuple
(
core_shape
,
core_shape_len
),),
None
if
size_len
is
None
else
numba_ndarray
.
to_fixed_tuple
(
size
,
size_len
),
)
return
rng
,
draws
return
impl
rv_op_props_dict
=
rv_op
.
props_dict
()
if
hasattr
(
rv_op
,
"props_dict"
)
else
{}
random_rv_key_contents
=
(
type
(
op
),
type
(
rv_op
),
rv_op
,
tuple
(
rv_op_props_dict
.
items
()),
size_len
,
core_shape_len
,
inplace
,
input_bc_patterns
,
)
random_rv_key
=
sha256
(
str
(
random_rv_key_contents
)
.
encode
())
.
hexdigest
()
return
random
,
random_rv_key
pytensor/link/numba/dispatch/scalar.py
浏览文件 @
6ac5ab28
import
math
import
math
from
hashlib
import
sha256
import
numpy
as
np
import
numpy
as
np
from
pytensor.graph.basic
import
Variable
from
pytensor.graph.basic
import
Variable
from
pytensor.link.numba.cache
import
compile_numba_function_src
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
(
from
pytensor.link.numba.dispatch.basic
import
(
create_numba_signature
,
generate_fallback_impl
,
generate_fallback_impl
,
numba_funcify
,
numba_funcify_and_cache_key
,
register_funcify_and_cache_key
,
)
)
from
pytensor.link.numba.dispatch.cython_support
import
wrap_cython_function
from
pytensor.link.numba.dispatch.cython_support
import
wrap_cython_function
from
pytensor.link.utils
import
(
from
pytensor.link.utils
import
(
compile_function_src
,
get_name_for_object
,
get_name_for_object
,
unique_name_generator
,
unique_name_generator
,
)
)
...
@@ -30,13 +31,16 @@ from pytensor.scalar.basic import (
...
@@ -30,13 +31,16 @@ from pytensor.scalar.basic import (
from
pytensor.scalar.math
import
Erf
,
Erfc
,
GammaLn
,
Log1mexp
,
Sigmoid
,
Softplus
from
pytensor.scalar.math
import
Erf
,
Erfc
,
GammaLn
,
Log1mexp
,
Sigmoid
,
Softplus
@numba_funcify.register
(
ScalarOp
)
def
scalar_op_cache_key
(
op
):
def
numba_funcify_ScalarOp
(
op
,
node
,
**
kwargs
):
# Scalar Ops don't have _props, because of their weird outputs_types_preference function
# TODO: Do we need to cache these functions so that we don't end up
# So we create hash differently
# compiling the same Numba function over and over again?
return
sha256
(
str
(
type
(
op
))
.
encode
())
.
hexdigest
()
@register_funcify_and_cache_key
(
ScalarOp
)
def
numba_funcify_ScalarOp
(
op
,
node
,
**
kwargs
):
if
not
hasattr
(
op
,
"nfunc_spec"
):
if
not
hasattr
(
op
,
"nfunc_spec"
):
return
generate_fallback_impl
(
op
,
node
,
**
kwargs
)
return
generate_fallback_impl
(
op
,
node
=
node
,
**
kwargs
),
None
scalar_func_path
=
op
.
nfunc_spec
[
0
]
scalar_func_path
=
op
.
nfunc_spec
[
0
]
scalar_func_numba
=
None
scalar_func_numba
=
None
...
@@ -58,6 +62,7 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
...
@@ -58,6 +62,7 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
output_inner_dtype
=
None
output_inner_dtype
=
None
# Cython functions might have an additional argument
# Cython functions might have an additional argument
cython_func
=
None
has_pyx_skip_dispatch
=
False
has_pyx_skip_dispatch
=
False
if
scalar_func_path
.
startswith
(
"scipy.special"
):
if
scalar_func_path
.
startswith
(
"scipy.special"
):
...
@@ -127,20 +132,18 @@ def {scalar_op_fn_name}({", ".join(input_names)}):
...
@@ -127,20 +132,18 @@ def {scalar_op_fn_name}({", ".join(input_names)}):
return direct_cast(scalar_func_numba({converted_call_args}, np.intc(1)), output_dtype)
return direct_cast(scalar_func_numba({converted_call_args}, np.intc(1)), output_dtype)
"""
"""
scalar_op_fn
=
compile_function_src
(
scalar_op_fn
=
compile_numba_function_src
(
scalar_op_src
,
scalar_op_fn_name
,
{
**
globals
(),
**
global_env
}
scalar_op_src
,
scalar_op_fn_name
,
{
**
globals
(),
**
global_env
},
)
)
signature
=
create_numba_signature
(
node
,
force_scalar
=
True
)
# Functions that call a function pointer can't be cached
cache_key
=
None
if
cython_func
else
scalar_op_cache_key
(
op
)
return
numba_basic
.
numba_njit
(
scalar_op_fn
),
cache_key
return
numba_basic
.
numba_njit
(
signature
,
# Functions that call a function pointer can't be cached
cache
=
False
,
)(
scalar_op_fn
)
@register_funcify_and_cache_key
(
Switch
)
@numba_funcify.register
(
Switch
)
def
numba_funcify_Switch
(
op
,
node
,
**
kwargs
):
def
numba_funcify_Switch
(
op
,
node
,
**
kwargs
):
@numba_basic.numba_njit
@numba_basic.numba_njit
def
switch
(
condition
,
x
,
y
):
def
switch
(
condition
,
x
,
y
):
...
@@ -149,7 +152,7 @@ def numba_funcify_Switch(op, node, **kwargs):
...
@@ -149,7 +152,7 @@ def numba_funcify_Switch(op, node, **kwargs):
else
:
else
:
return
y
return
y
return
switch
return
switch
,
scalar_op_cache_key
(
op
)
def
binary_to_nary_func
(
inputs
:
list
[
Variable
],
binary_op_name
:
str
,
binary_op
:
str
):
def
binary_to_nary_func
(
inputs
:
list
[
Variable
],
binary_op_name
:
str
,
binary_op
:
str
):
...
@@ -163,28 +166,26 @@ def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op:
...
@@ -163,28 +166,26 @@ def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op:
def {binary_op_name}({input_signature}):
def {binary_op_name}({input_signature}):
return {output_expr}
return {output_expr}
"""
"""
nary_fn
=
compile_function_src
(
nary_src
,
binary_op_name
,
globals
())
nary_fn
=
compile_
numba_
function_src
(
nary_src
,
binary_op_name
,
globals
())
return
nary_fn
return
nary_fn
@
numba_funcify.register
(
Add
)
@
register_funcify_and_cache_key
(
Add
)
def
numba_funcify_Add
(
op
,
node
,
**
kwargs
):
def
numba_funcify_Add
(
op
,
node
,
**
kwargs
):
signature
=
create_numba_signature
(
node
,
force_scalar
=
True
)
nary_add_fn
=
binary_to_nary_func
(
node
.
inputs
,
"add"
,
"+"
)
nary_add_fn
=
binary_to_nary_func
(
node
.
inputs
,
"add"
,
"+"
)
return
numba_basic
.
numba_njit
(
signature
)(
nary_add_fn
)
return
numba_basic
.
numba_njit
(
nary_add_fn
),
scalar_op_cache_key
(
op
)
@
numba_funcify.register
(
Mul
)
@
register_funcify_and_cache_key
(
Mul
)
def
numba_funcify_Mul
(
op
,
node
,
**
kwargs
):
def
numba_funcify_Mul
(
op
,
node
,
**
kwargs
):
signature
=
create_numba_signature
(
node
,
force_scalar
=
True
)
nary_mul_fn
=
binary_to_nary_func
(
node
.
inputs
,
"mul"
,
"*"
)
nary_add_fn
=
binary_to_nary_func
(
node
.
inputs
,
"mul"
,
"*"
)
return
numba_basic
.
numba_njit
(
signature
)(
nary_add_fn
)
return
numba_basic
.
numba_njit
(
nary_mul_fn
),
scalar_op_cache_key
(
op
)
@
numba_funcify.register
(
Cast
)
@
register_funcify_and_cache_key
(
Cast
)
def
numba_funcify_Cast
(
op
,
node
,
**
kwargs
):
def
numba_funcify_Cast
(
op
,
node
,
**
kwargs
):
dtype
=
np
.
dtype
(
op
.
o_type
.
dtype
)
dtype
=
np
.
dtype
(
op
.
o_type
.
dtype
)
...
@@ -192,19 +193,19 @@ def numba_funcify_Cast(op, node, **kwargs):
...
@@ -192,19 +193,19 @@ def numba_funcify_Cast(op, node, **kwargs):
def
cast
(
x
):
def
cast
(
x
):
return
numba_basic
.
direct_cast
(
x
,
dtype
)
return
numba_basic
.
direct_cast
(
x
,
dtype
)
return
cast
return
cast
,
sha256
(
str
((
type
(
op
),
op
.
o_type
.
dtype
))
.
encode
())
.
hexdigest
()
@
numba_funcify.register
(
Identity
)
@
register_funcify_and_cache_key
(
Identity
)
def
numba_funcify_type_casting
(
op
,
**
kwargs
):
def
numba_funcify_type_casting
(
op
,
**
kwargs
):
@numba_basic.numba_njit
@numba_basic.numba_njit
def
identity
(
x
):
def
identity
(
x
):
return
x
return
x
return
identity
return
identity
,
scalar_op_cache_key
(
op
)
@
numba_funcify.register
(
Clip
)
@
register_funcify_and_cache_key
(
Clip
)
def
numba_funcify_Clip
(
op
,
**
kwargs
):
def
numba_funcify_Clip
(
op
,
**
kwargs
):
@numba_basic.numba_njit
@numba_basic.numba_njit
def
clip
(
x
,
min_val
,
max_val
):
def
clip
(
x
,
min_val
,
max_val
):
...
@@ -215,26 +216,33 @@ def numba_funcify_Clip(op, **kwargs):
...
@@ -215,26 +216,33 @@ def numba_funcify_Clip(op, **kwargs):
else
:
else
:
return
x
return
x
return
clip
return
clip
,
scalar_op_cache_key
(
op
)
@
numba_funcify.register
(
Composite
)
@
register_funcify_and_cache_key
(
Composite
)
def
numba_funcify_Composite
(
op
,
node
,
**
kwargs
):
def
numba_funcify_Composite
(
op
,
node
,
**
kwargs
):
_
=
kwargs
.
pop
(
"storage_map"
,
None
)
_
=
kwargs
.
pop
(
"storage_map"
,
None
)
return
numba_funcify
(
op
.
fgraph
,
squeeze_output
=
True
,
**
kwargs
)
composite_fn
,
fgraph_key
=
numba_funcify_and_cache_key
(
op
.
fgraph
,
squeeze_output
=
True
,
**
kwargs
)
if
fgraph_key
is
None
:
composite_key
=
None
else
:
composite_key
=
sha256
(
str
((
type
(
op
),
fgraph_key
))
.
encode
())
.
hexdigest
()
return
composite_fn
,
composite_key
@
numba_funcify.register
(
Second
)
@
register_funcify_and_cache_key
(
Second
)
def
numba_funcify_Second
(
op
,
node
,
**
kwargs
):
def
numba_funcify_Second
(
op
,
node
,
**
kwargs
):
@numba_basic.numba_njit
@numba_basic.numba_njit
def
second
(
x
,
y
):
def
second
(
x
,
y
):
return
y
return
y
return
second
return
second
,
scalar_op_cache_key
(
op
)
@
numba_funcify.register
(
Reciprocal
)
@
register_funcify_and_cache_key
(
Reciprocal
)
def
numba_funcify_Reciprocal
(
op
,
node
,
**
kwargs
):
def
numba_funcify_Reciprocal
(
op
,
node
,
**
kwargs
):
@numba_basic.numba_njit
@numba_basic.numba_njit
def
reciprocal
(
x
):
def
reciprocal
(
x
):
...
@@ -242,28 +250,28 @@ def numba_funcify_Reciprocal(op, node, **kwargs):
...
@@ -242,28 +250,28 @@ def numba_funcify_Reciprocal(op, node, **kwargs):
# `x` is an `int`
# `x` is an `int`
return
1
/
x
return
1
/
x
return
reciprocal
return
reciprocal
,
scalar_op_cache_key
(
op
)
@
numba_funcify.register
(
Sigmoid
)
@
register_funcify_and_cache_key
(
Sigmoid
)
def
numba_funcify_Sigmoid
(
op
,
node
,
**
kwargs
):
def
numba_funcify_Sigmoid
(
op
,
node
,
**
kwargs
):
@numba_basic.numba_njit
@numba_basic.numba_njit
def
sigmoid
(
x
):
def
sigmoid
(
x
):
return
1
/
(
1
+
np
.
exp
(
-
x
))
return
1
/
(
1
+
np
.
exp
(
-
x
))
return
sigmoid
return
sigmoid
,
scalar_op_cache_key
(
op
)
@
numba_funcify.register
(
GammaLn
)
@
register_funcify_and_cache_key
(
GammaLn
)
def
numba_funcify_GammaLn
(
op
,
node
,
**
kwargs
):
def
numba_funcify_GammaLn
(
op
,
node
,
**
kwargs
):
@numba_basic.numba_njit
@numba_basic.numba_njit
def
gammaln
(
x
):
def
gammaln
(
x
):
return
math
.
lgamma
(
x
)
return
math
.
lgamma
(
x
)
return
gammaln
return
gammaln
,
scalar_op_cache_key
(
op
)
@
numba_funcify.register
(
Log1mexp
)
@
register_funcify_and_cache_key
(
Log1mexp
)
def
numba_funcify_Log1mexp
(
op
,
node
,
**
kwargs
):
def
numba_funcify_Log1mexp
(
op
,
node
,
**
kwargs
):
@numba_basic.numba_njit
@numba_basic.numba_njit
def
logp1mexp
(
x
):
def
logp1mexp
(
x
):
...
@@ -272,28 +280,28 @@ def numba_funcify_Log1mexp(op, node, **kwargs):
...
@@ -272,28 +280,28 @@ def numba_funcify_Log1mexp(op, node, **kwargs):
else
:
else
:
return
np
.
log
(
-
np
.
expm1
(
x
))
return
np
.
log
(
-
np
.
expm1
(
x
))
return
logp1mexp
return
logp1mexp
,
scalar_op_cache_key
(
op
)
@
numba_funcify.register
(
Erf
)
@
register_funcify_and_cache_key
(
Erf
)
def
numba_funcify_Erf
(
op
,
**
kwargs
):
def
numba_funcify_Erf
(
op
,
**
kwargs
):
@numba_basic.numba_njit
@numba_basic.numba_njit
def
erf
(
x
):
def
erf
(
x
):
return
math
.
erf
(
x
)
return
math
.
erf
(
x
)
return
erf
return
erf
,
scalar_op_cache_key
(
op
)
@
numba_funcify.register
(
Erfc
)
@
register_funcify_and_cache_key
(
Erfc
)
def
numba_funcify_Erfc
(
op
,
**
kwargs
):
def
numba_funcify_Erfc
(
op
,
**
kwargs
):
@numba_basic.numba_njit
@numba_basic.numba_njit
def
erfc
(
x
):
def
erfc
(
x
):
return
math
.
erfc
(
x
)
return
math
.
erfc
(
x
)
return
erfc
return
erfc
,
scalar_op_cache_key
(
op
)
@
numba_funcify.register
(
Softplus
)
@
register_funcify_and_cache_key
(
Softplus
)
def
numba_funcify_Softplus
(
op
,
node
,
**
kwargs
):
def
numba_funcify_Softplus
(
op
,
node
,
**
kwargs
):
out_dtype
=
np
.
dtype
(
node
.
outputs
[
0
]
.
type
.
dtype
)
out_dtype
=
np
.
dtype
(
node
.
outputs
[
0
]
.
type
.
dtype
)
...
@@ -309,4 +317,4 @@ def numba_funcify_Softplus(op, node, **kwargs):
...
@@ -309,4 +317,4 @@ def numba_funcify_Softplus(op, node, **kwargs):
value
=
x
value
=
x
return
numba_basic
.
direct_cast
(
value
,
out_dtype
)
return
numba_basic
.
direct_cast
(
value
,
out_dtype
)
return
softplus
return
softplus
,
scalar_op_cache_key
(
op
)
pytensor/link/numba/dispatch/scan.py
浏览文件 @
6ac5ab28
from
hashlib
import
sha256
from
textwrap
import
dedent
,
indent
from
textwrap
import
dedent
,
indent
import
numpy
as
np
import
numpy
as
np
...
@@ -7,13 +8,14 @@ from numba.extending import overload
...
@@ -7,13 +8,14 @@ from numba.extending import overload
from
pytensor
import
In
from
pytensor
import
In
from
pytensor.compile.function.types
import
add_supervisor_to_fgraph
from
pytensor.compile.function.types
import
add_supervisor_to_fgraph
from
pytensor.compile.mode
import
NUMBA
,
get_mode
from
pytensor.compile.mode
import
NUMBA
,
get_mode
from
pytensor.link.numba.cache
import
compile_numba_function_src
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
(
from
pytensor.link.numba.dispatch.basic
import
(
create_arg_string
,
create_arg_string
,
create_tuple_string
,
create_tuple_string
,
numba_funcify
,
numba_funcify_and_cache_key
,
register_funcify_and_cache_key
,
)
)
from
pytensor.link.utils
import
compile_function_src
from
pytensor.scan.op
import
Scan
from
pytensor.scan.op
import
Scan
from
pytensor.tensor.type
import
TensorType
from
pytensor.tensor.type
import
TensorType
...
@@ -54,7 +56,7 @@ def array0d_range(x):
...
@@ -54,7 +56,7 @@ def array0d_range(x):
return
range_arr
return
range_arr
@
numba_funcify.register
(
Scan
)
@
register_funcify_and_cache_key
(
Scan
)
def
numba_funcify_Scan
(
op
:
Scan
,
node
,
**
kwargs
):
def
numba_funcify_Scan
(
op
:
Scan
,
node
,
**
kwargs
):
# Apply inner rewrites
# Apply inner rewrites
# TODO: Not sure this is the right place to do this, should we have a rewrite that
# TODO: Not sure this is the right place to do this, should we have a rewrite that
...
@@ -97,7 +99,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
...
@@ -97,7 +99,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
)
)
rewriter
(
fgraph
)
rewriter
(
fgraph
)
scan_inner_func
=
numba_funcif
y
(
op
.
fgraph
)
scan_inner_func
,
inner_func_cache_key
=
numba_funcify_and_cache_ke
y
(
op
.
fgraph
)
outer_in_names_to_vars
=
{
outer_in_names_to_vars
=
{
(
f
"outer_in_{i}"
if
i
>
0
else
"n_steps"
):
v
for
i
,
v
in
enumerate
(
node
.
inputs
)
(
f
"outer_in_{i}"
if
i
>
0
else
"n_steps"
):
v
for
i
,
v
in
enumerate
(
node
.
inputs
)
...
@@ -439,6 +441,18 @@ def scan({", ".join(outer_in_names)}):
...
@@ -439,6 +441,18 @@ def scan({", ".join(outer_in_names)}):
"scan_inner_func"
:
scan_inner_func
,
"scan_inner_func"
:
scan_inner_func
,
}
}
scan_op_fn
=
compile_function_src
(
scan_op_src
,
"scan"
,
{
**
globals
(),
**
global_env
})
scan_op_fn
=
compile_numba_function_src
(
scan_op_src
,
"scan"
,
{
**
globals
(),
**
global_env
},
)
if
inner_func_cache_key
is
None
:
# If we can't cache the inner function, we can't cache the Scan either
scan_cache_key
=
None
else
:
scan_cache_key
=
sha256
(
f
"({scan_op_src}, {inner_func_cache_key})"
.
encode
()
)
.
hexdigest
()
return
numba_basic
.
numba_njit
(
scan_op_fn
,
boundscheck
=
False
)
return
numba_basic
.
numba_njit
(
scan_op_fn
,
boundscheck
=
False
)
,
scan_cache_key
pytensor/link/numba/dispatch/shape.py
浏览文件 @
6ac5ab28
...
@@ -4,14 +4,16 @@ import numpy as np
...
@@ -4,14 +4,16 @@ import numpy as np
from
numba.np.unsafe
import
ndarray
as
numba_ndarray
from
numba.np.unsafe
import
ndarray
as
numba_ndarray
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
(
from
pytensor.link.numba.dispatch.basic
import
create_arg_string
,
numba_njit
create_arg_string
,
register_funcify_default_op_cache_key
,
)
from
pytensor.link.utils
import
compile_function_src
from
pytensor.link.utils
import
compile_function_src
from
pytensor.tensor
import
NoneConst
from
pytensor.tensor
import
NoneConst
from
pytensor.tensor.shape
import
Reshape
,
Shape
,
Shape_i
,
SpecifyShape
from
pytensor.tensor.shape
import
Reshape
,
Shape
,
Shape_i
,
SpecifyShape
@
numba_funcify.register
(
Shape
)
@
register_funcify_default_op_cache_key
(
Shape
)
def
numba_funcify_Shape
(
op
,
**
kwargs
):
def
numba_funcify_Shape
(
op
,
**
kwargs
):
@numba_basic.numba_njit
@numba_basic.numba_njit
def
shape
(
x
):
def
shape
(
x
):
...
@@ -20,7 +22,7 @@ def numba_funcify_Shape(op, **kwargs):
...
@@ -20,7 +22,7 @@ def numba_funcify_Shape(op, **kwargs):
return
shape
return
shape
@
numba_funcify.register
(
Shape_i
)
@
register_funcify_default_op_cache_key
(
Shape_i
)
def
numba_funcify_Shape_i
(
op
,
**
kwargs
):
def
numba_funcify_Shape_i
(
op
,
**
kwargs
):
i
=
op
.
i
i
=
op
.
i
...
@@ -31,7 +33,7 @@ def numba_funcify_Shape_i(op, **kwargs):
...
@@ -31,7 +33,7 @@ def numba_funcify_Shape_i(op, **kwargs):
return
shape_i
return
shape_i
@
numba_funcify.register
(
SpecifyShape
)
@
register_funcify_default_op_cache_key
(
SpecifyShape
)
def
numba_funcify_SpecifyShape
(
op
,
node
,
**
kwargs
):
def
numba_funcify_SpecifyShape
(
op
,
node
,
**
kwargs
):
shape_inputs
=
node
.
inputs
[
1
:]
shape_inputs
=
node
.
inputs
[
1
:]
shape_input_names
=
[
"shape_"
+
str
(
i
)
for
i
in
range
(
len
(
shape_inputs
))]
shape_input_names
=
[
"shape_"
+
str
(
i
)
for
i
in
range
(
len
(
shape_inputs
))]
...
@@ -53,10 +55,10 @@ def numba_funcify_SpecifyShape(op, node, **kwargs):
...
@@ -53,10 +55,10 @@ def numba_funcify_SpecifyShape(op, node, **kwargs):
)
)
specify_shape
=
compile_function_src
(
func
,
"specify_shape"
,
globals
())
specify_shape
=
compile_function_src
(
func
,
"specify_shape"
,
globals
())
return
numba_njit
(
specify_shape
)
return
numba_
basic
.
numba_
njit
(
specify_shape
)
@
numba_funcify.register
(
Reshape
)
@
register_funcify_default_op_cache_key
(
Reshape
)
def
numba_funcify_Reshape
(
op
,
**
kwargs
):
def
numba_funcify_Reshape
(
op
,
**
kwargs
):
ndim
=
op
.
ndim
ndim
=
op
.
ndim
...
...
pytensor/link/numba/dispatch/signal/conv.py
浏览文件 @
6ac5ab28
...
@@ -2,11 +2,13 @@ import numpy as np
...
@@ -2,11 +2,13 @@ import numpy as np
from
numba.np.arraymath
import
_get_inner_prod
from
numba.np.arraymath
import
_get_inner_prod
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
(
register_funcify_default_op_cache_key
,
)
from
pytensor.tensor.signal.conv
import
Convolve1d
from
pytensor.tensor.signal.conv
import
Convolve1d
@
numba_funcify.register
(
Convolve1d
)
@
register_funcify_default_op_cache_key
(
Convolve1d
)
def
numba_funcify_Convolve1d
(
op
,
node
,
**
kwargs
):
def
numba_funcify_Convolve1d
(
op
,
node
,
**
kwargs
):
# This specialized version is faster than the overloaded numba np.convolve
# This specialized version is faster than the overloaded numba np.convolve
a_dtype
,
b_dtype
=
node
.
inputs
[
0
]
.
type
.
dtype
,
node
.
inputs
[
1
]
.
type
.
dtype
a_dtype
,
b_dtype
=
node
.
inputs
[
0
]
.
type
.
dtype
,
node
.
inputs
[
1
]
.
type
.
dtype
...
...
pytensor/link/numba/dispatch/slinalg.py
浏览文件 @
6ac5ab28
...
@@ -4,7 +4,10 @@ import numpy as np
...
@@ -4,7 +4,10 @@ import numpy as np
from
pytensor
import
config
from
pytensor
import
config
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
(
numba_funcify
,
register_funcify_default_op_cache_key
,
)
from
pytensor.link.numba.dispatch.linalg.decomposition.cholesky
import
_cholesky
from
pytensor.link.numba.dispatch.linalg.decomposition.cholesky
import
_cholesky
from
pytensor.link.numba.dispatch.linalg.decomposition.lu
import
(
from
pytensor.link.numba.dispatch.linalg.decomposition.lu
import
(
_lu_1
,
_lu_1
,
...
@@ -91,7 +94,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
...
@@ -91,7 +94,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
return
cholesky
return
cholesky
@
numba_funcify.register
(
PivotToPermutations
)
@
register_funcify_default_op_cache_key
(
PivotToPermutations
)
def
pivot_to_permutation
(
op
,
node
,
**
kwargs
):
def
pivot_to_permutation
(
op
,
node
,
**
kwargs
):
inverse
=
op
.
inverse
inverse
=
op
.
inverse
dtype
=
node
.
outputs
[
0
]
.
dtype
dtype
=
node
.
outputs
[
0
]
.
dtype
...
@@ -119,7 +122,7 @@ def numba_funcify_LU(op, node, **kwargs):
...
@@ -119,7 +122,7 @@ def numba_funcify_LU(op, node, **kwargs):
if
dtype
in
complex_dtypes
:
if
dtype
in
complex_dtypes
:
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
NotImplementedError
(
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG
.
format
(
op
=
op
))
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
lu
(
a
):
def
lu
(
a
):
if
check_finite
:
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
...
@@ -181,11 +184,10 @@ def numba_funcify_LUFactor(op, node, **kwargs):
...
@@ -181,11 +184,10 @@ def numba_funcify_LUFactor(op, node, **kwargs):
return
lu_factor
return
lu_factor
@
numba_funcify.register
(
BlockDiagonal
)
@
register_funcify_default_op_cache_key
(
BlockDiagonal
)
def
numba_funcify_BlockDiagonal
(
op
,
node
,
**
kwargs
):
def
numba_funcify_BlockDiagonal
(
op
,
node
,
**
kwargs
):
dtype
=
node
.
outputs
[
0
]
.
dtype
dtype
=
node
.
outputs
[
0
]
.
dtype
# TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case.
@numba_basic.numba_njit
@numba_basic.numba_njit
def
block_diag
(
*
arrs
):
def
block_diag
(
*
arrs
):
shapes
=
np
.
array
([
a
.
shape
for
a
in
arrs
],
dtype
=
"int"
)
shapes
=
np
.
array
([
a
.
shape
for
a
in
arrs
],
dtype
=
"int"
)
...
@@ -338,7 +340,7 @@ def numba_funcify_QR(op, node, **kwargs):
...
@@ -338,7 +340,7 @@ def numba_funcify_QR(op, node, **kwargs):
integer_input
=
dtype
in
integer_dtypes
integer_input
=
dtype
in
integer_dtypes
in_dtype
=
config
.
floatX
if
integer_input
else
dtype
in_dtype
=
config
.
floatX
if
integer_input
else
dtype
@numba_basic.numba_njit
(
cache
=
False
)
@numba_basic.numba_njit
def
qr
(
a
):
def
qr
(
a
):
if
check_finite
:
if
check_finite
:
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
if
np
.
any
(
np
.
bitwise_or
(
np
.
isinf
(
a
),
np
.
isnan
(
a
))):
...
...
pytensor/link/numba/dispatch/sort.py
浏览文件 @
6ac5ab28
...
@@ -3,11 +3,13 @@ import warnings
...
@@ -3,11 +3,13 @@ import warnings
import
numpy
as
np
import
numpy
as
np
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
(
register_funcify_default_op_cache_key
,
)
from
pytensor.tensor.sort
import
ArgSortOp
,
SortOp
from
pytensor.tensor.sort
import
ArgSortOp
,
SortOp
@
numba_funcify.register
(
SortOp
)
@
register_funcify_default_op_cache_key
(
SortOp
)
def
numba_funcify_SortOp
(
op
,
node
,
**
kwargs
):
def
numba_funcify_SortOp
(
op
,
node
,
**
kwargs
):
if
op
.
kind
!=
"quicksort"
:
if
op
.
kind
!=
"quicksort"
:
warnings
.
warn
(
warnings
.
warn
(
...
@@ -31,7 +33,7 @@ def numba_funcify_SortOp(op, node, **kwargs):
...
@@ -31,7 +33,7 @@ def numba_funcify_SortOp(op, node, **kwargs):
return
sort_f
return
sort_f
@
numba_funcify.register
(
ArgSortOp
)
@
register_funcify_default_op_cache_key
(
ArgSortOp
)
def
numba_funcify_ArgSortOp
(
op
,
node
,
**
kwargs
):
def
numba_funcify_ArgSortOp
(
op
,
node
,
**
kwargs
):
kind
=
op
.
kind
kind
=
op
.
kind
...
...
pytensor/link/numba/dispatch/subtensor.py
浏览文件 @
6ac5ab28
import
operator
import
operator
import
sys
import
sys
from
hashlib
import
sha256
import
numba
import
numba
import
numpy
as
np
import
numpy
as
np
...
@@ -7,11 +8,17 @@ from llvmlite import ir
...
@@ -7,11 +8,17 @@ from llvmlite import ir
from
numba
import
types
from
numba
import
types
from
numba.core.pythonapi
import
box
from
numba.core.pythonapi
import
box
import
pytensor.link.numba.dispatch.basic
as
numba_basic
from
pytensor.graph
import
Type
from
pytensor.graph
import
Type
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.cache
import
(
from
pytensor.link.numba.dispatch
import
numba_funcify
compile_numba_function_src
,
from
pytensor.link.numba.dispatch.basic
import
generate_fallback_impl
,
numba_njit
)
from
pytensor.link.utils
import
compile_function_src
,
unique_name_generator
from
pytensor.link.numba.dispatch.basic
import
(
generate_fallback_impl
,
register_funcify_and_cache_key
,
register_funcify_default_op_cache_key
,
)
from
pytensor.link.utils
import
unique_name_generator
from
pytensor.tensor
import
TensorType
from
pytensor.tensor
import
TensorType
from
pytensor.tensor.rewriting.subtensor
import
is_full_slice
from
pytensor.tensor.rewriting.subtensor
import
is_full_slice
from
pytensor.tensor.subtensor
import
(
from
pytensor.tensor.subtensor
import
(
...
@@ -98,7 +105,7 @@ def enable_slice_boxing():
...
@@ -98,7 +105,7 @@ def enable_slice_boxing():
enable_slice_boxing
()
enable_slice_boxing
()
@
numba_funcify.register
(
MakeSlice
)
@
register_funcify_default_op_cache_key
(
MakeSlice
)
def
numba_funcify_MakeSlice
(
op
,
**
kwargs
):
def
numba_funcify_MakeSlice
(
op
,
**
kwargs
):
@numba_basic.numba_njit
@numba_basic.numba_njit
def
makeslice
(
*
x
):
def
makeslice
(
*
x
):
...
@@ -107,9 +114,32 @@ def numba_funcify_MakeSlice(op, **kwargs):
...
@@ -107,9 +114,32 @@ def numba_funcify_MakeSlice(op, **kwargs):
return
makeslice
return
makeslice
@numba_funcify.register
(
Subtensor
)
def
subtensor_op_cache_key
(
op
,
**
extra_fields
):
@numba_funcify.register
(
IncSubtensor
)
key_parts
=
[
type
(
op
),
tuple
(
extra_fields
.
items
())]
@numba_funcify.register
(
AdvancedSubtensor1
)
if
hasattr
(
op
,
"idx_list"
):
idx_parts
=
[]
for
idx
in
op
.
idx_list
:
if
isinstance
(
idx
,
slice
):
idx_parts
.
append
(
(
idx
.
start
is
None
,
idx
.
stop
is
None
,
idx
.
step
is
None
,
)
)
else
:
idx_parts
.
append
(
"i"
)
key_parts
.
append
(
tuple
(
idx_parts
))
if
isinstance
(
op
,
IncSubtensor
|
AdvancedIncSubtensor
|
AdvancedIncSubtensor1
):
key_parts
.
append
((
op
.
inplace
,
op
.
set_instead_of_inc
))
if
isinstance
(
op
,
AdvancedIncSubtensor
):
key_parts
.
append
(
op
.
ignore_duplicates
)
return
sha256
(
str
(
tuple
(
key_parts
))
.
encode
())
.
hexdigest
()
@register_funcify_and_cache_key
(
Subtensor
)
@register_funcify_and_cache_key
(
IncSubtensor
)
@register_funcify_and_cache_key
(
AdvancedSubtensor1
)
def
numba_funcify_default_subtensor
(
op
,
node
,
**
kwargs
):
def
numba_funcify_default_subtensor
(
op
,
node
,
**
kwargs
):
"""Create a Python function that assembles and uses an index on an array."""
"""Create a Python function that assembles and uses an index on an array."""
...
@@ -185,16 +215,17 @@ def {function_name}({", ".join(input_names)}):
...
@@ -185,16 +215,17 @@ def {function_name}({", ".join(input_names)}):
return np.asarray(z)
return np.asarray(z)
"""
"""
func
=
compile_function_src
(
func
=
compile_
numba_
function_src
(
subtensor_def_src
,
subtensor_def_src
,
function_name
=
function_name
,
function_name
=
function_name
,
global_env
=
globals
()
|
{
"np"
:
np
},
global_env
=
globals
()
|
{
"np"
:
np
},
)
)
return
numba_njit
(
func
,
boundscheck
=
True
)
cache_key
=
subtensor_op_cache_key
(
op
,
func
=
"numba_funcify_default_subtensor"
)
return
numba_basic
.
numba_njit
(
func
,
boundscheck
=
True
),
cache_key
@
numba_funcify.register
(
AdvancedSubtensor
)
@
register_funcify_and_cache_key
(
AdvancedSubtensor
)
@
numba_funcify.register
(
AdvancedIncSubtensor
)
@
register_funcify_and_cache_key
(
AdvancedIncSubtensor
)
def
numba_funcify_AdvancedSubtensor
(
op
,
node
,
**
kwargs
):
def
numba_funcify_AdvancedSubtensor
(
op
,
node
,
**
kwargs
):
if
isinstance
(
op
,
AdvancedSubtensor
):
if
isinstance
(
op
,
AdvancedSubtensor
):
_x
,
_y
,
idxs
=
node
.
inputs
[
0
],
None
,
node
.
inputs
[
1
:]
_x
,
_y
,
idxs
=
node
.
inputs
[
0
],
None
,
node
.
inputs
[
1
:]
...
@@ -255,7 +286,9 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
...
@@ -255,7 +286,9 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
)
)
)
)
):
):
return
generate_fallback_impl
(
op
,
node
,
**
kwargs
)
return
generate_fallback_impl
(
op
,
node
,
**
kwargs
),
subtensor_op_cache_key
(
op
,
func
=
"fallback_impl"
)
# What's left should all be supported natively by numba
# What's left should all be supported natively by numba
return
numba_funcify_default_subtensor
(
op
,
node
,
**
kwargs
)
return
numba_funcify_default_subtensor
(
op
,
node
,
**
kwargs
)
...
@@ -295,6 +328,7 @@ def numba_funcify_multiple_integer_vector_indexing(
...
@@ -295,6 +328,7 @@ def numba_funcify_multiple_integer_vector_indexing(
vector_indices
=
idxs
[
first_axis
:
after_last_axis
]
vector_indices
=
idxs
[
first_axis
:
after_last_axis
]
assert
all
(
v
.
type
.
broadcastable
==
(
False
,)
for
v
in
vector_indices
)
assert
all
(
v
.
type
.
broadcastable
==
(
False
,)
for
v
in
vector_indices
)
y_is_broadcasted
=
False
if
isinstance
(
op
,
AdvancedSubtensor
):
if
isinstance
(
op
,
AdvancedSubtensor
):
...
@@ -313,7 +347,7 @@ def numba_funcify_multiple_integer_vector_indexing(
...
@@ -313,7 +347,7 @@ def numba_funcify_multiple_integer_vector_indexing(
out_buffer
[(
*
none_slices
,
i
)]
=
x
[(
*
none_slices
,
*
scalar_idxs
)]
out_buffer
[(
*
none_slices
,
i
)]
=
x
[(
*
none_slices
,
*
scalar_idxs
)]
return
out_buffer
return
out_buffer
ret
urn
advanced_subtensor_multiple_vector
ret
_func
=
advanced_subtensor_multiple_vector
else
:
else
:
inplace
=
op
.
inplace
inplace
=
op
.
inplace
...
@@ -347,7 +381,7 @@ def numba_funcify_multiple_integer_vector_indexing(
...
@@ -347,7 +381,7 @@ def numba_funcify_multiple_integer_vector_indexing(
out
[(
*
outer
,
*
scalar_idxs
)]
=
y
[(
*
outer
,
i
)]
out
[(
*
outer
,
*
scalar_idxs
)]
=
y
[(
*
outer
,
i
)]
return
out
return
out
ret
urn
advanced_set_subtensor_multiple_vector
ret
_func
=
advanced_set_subtensor_multiple_vector
else
:
else
:
...
@@ -369,10 +403,17 @@ def numba_funcify_multiple_integer_vector_indexing(
...
@@ -369,10 +403,17 @@ def numba_funcify_multiple_integer_vector_indexing(
out
[(
*
outer
,
*
scalar_idxs
)]
+=
y
[(
*
outer
,
i
)]
out
[(
*
outer
,
*
scalar_idxs
)]
+=
y
[(
*
outer
,
i
)]
return
out
return
out
return
advanced_inc_subtensor_multiple_vector
ret_func
=
advanced_inc_subtensor_multiple_vector
cache_key
=
subtensor_op_cache_key
(
op
,
func
=
"multiple_integer_vector_indexing"
,
y_is_broadcasted
=
y_is_broadcasted
,
)
return
ret_func
,
cache_key
@
numba_funcify.register
(
AdvancedIncSubtensor1
)
@
register_funcify_and_cache_key
(
AdvancedIncSubtensor1
)
def
numba_funcify_AdvancedIncSubtensor1
(
op
,
node
,
**
kwargs
):
def
numba_funcify_AdvancedIncSubtensor1
(
op
,
node
,
**
kwargs
):
inplace
=
op
.
inplace
inplace
=
op
.
inplace
set_instead_of_inc
=
op
.
set_instead_of_inc
set_instead_of_inc
=
op
.
set_instead_of_inc
...
@@ -436,8 +477,14 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
...
@@ -436,8 +477,14 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
x
[
idx
]
+=
val
x
[
idx
]
+=
val
return
x
return
x
cache_key
=
subtensor_op_cache_key
(
op
,
func
=
"numba_funcify_advancedincsubtensor1"
,
broadcast_with_index
=
broadcast_with_index
,
)
if
inplace
:
if
inplace
:
return
advancedincsubtensor1_inplace
return
advancedincsubtensor1_inplace
,
cache_key
else
:
else
:
...
@@ -446,4 +493,4 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
...
@@ -446,4 +493,4 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
x
=
x
.
copy
()
x
=
x
.
copy
()
return
advancedincsubtensor1_inplace
(
x
,
vals
,
idxs
)
return
advancedincsubtensor1_inplace
(
x
,
vals
,
idxs
)
return
advancedincsubtensor1
return
advancedincsubtensor1
,
cache_key
pytensor/link/numba/dispatch/tensor_basic.py
浏览文件 @
6ac5ab28
from
hashlib
import
sha256
from
textwrap
import
indent
from
textwrap
import
indent
import
numpy
as
np
import
numpy
as
np
from
pytensor.link.numba.cache
import
compile_numba_function_src
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
(
from
pytensor.link.numba.dispatch.basic
import
(
create_tuple_string
,
create_tuple_string
,
numba_funcify
,
register_funcify_and_cache_key
,
register_funcify_default_op_cache_key
,
)
)
from
pytensor.link.utils
import
compile_function_src
,
unique_name_generator
from
pytensor.link.utils
import
unique_name_generator
from
pytensor.tensor.basic
import
(
from
pytensor.tensor.basic
import
(
Alloc
,
Alloc
,
AllocEmpty
,
AllocEmpty
,
...
@@ -23,7 +26,7 @@ from pytensor.tensor.basic import (
...
@@ -23,7 +26,7 @@ from pytensor.tensor.basic import (
)
)
@
numba_funcify.register
(
AllocEmpty
)
@
register_funcify_default_op_cache_key
(
AllocEmpty
)
def
numba_funcify_AllocEmpty
(
op
,
node
,
**
kwargs
):
def
numba_funcify_AllocEmpty
(
op
,
node
,
**
kwargs
):
global_env
=
{
global_env
=
{
"np"
:
np
,
"np"
:
np
,
...
@@ -52,14 +55,14 @@ def allocempty({", ".join(shape_var_names)}):
...
@@ -52,14 +55,14 @@ def allocempty({", ".join(shape_var_names)}):
return np.empty(scalar_shape, dtype)
return np.empty(scalar_shape, dtype)
"""
"""
alloc_fn
=
compile_function_src
(
alloc_fn
=
compile_
numba_
function_src
(
alloc_def_src
,
"allocempty"
,
{
**
globals
(),
**
global_env
}
alloc_def_src
,
"allocempty"
,
{
**
globals
(),
**
global_env
}
)
)
return
numba_basic
.
numba_njit
(
alloc_fn
)
return
numba_basic
.
numba_njit
(
alloc_fn
)
@
numba_funcify.register
(
Alloc
)
@
register_funcify_and_cache_key
(
Alloc
)
def
numba_funcify_Alloc
(
op
,
node
,
**
kwargs
):
def
numba_funcify_Alloc
(
op
,
node
,
**
kwargs
):
global_env
=
{
"np"
:
np
}
global_env
=
{
"np"
:
np
}
...
@@ -96,16 +99,23 @@ def alloc(val, {", ".join(shape_var_names)}):
...
@@ -96,16 +99,23 @@ def alloc(val, {", ".join(shape_var_names)}):
res[...] = val
res[...] = val
return res
return res
"""
"""
alloc_fn
=
compile_function_src
(
alloc_def_src
,
"alloc"
,
{
**
globals
(),
**
global_env
})
alloc_fn
=
compile_numba_function_src
(
alloc_def_src
,
"alloc"
,
{
**
globals
(),
**
global_env
},
)
return
numba_basic
.
numba_njit
(
alloc_fn
)
cache_key
=
sha256
(
str
((
type
(
op
),
node
.
inputs
[
0
]
.
type
.
broadcastable
))
.
encode
()
)
.
hexdigest
()
return
numba_basic
.
numba_njit
(
alloc_fn
),
cache_key
@
numba_funcify.register
(
ARange
)
@
register_funcify_default_op_cache_key
(
ARange
)
def
numba_funcify_ARange
(
op
,
**
kwargs
):
def
numba_funcify_ARange
(
op
,
**
kwargs
):
dtype
=
np
.
dtype
(
op
.
dtype
)
dtype
=
np
.
dtype
(
op
.
dtype
)
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
arange
(
start
,
stop
,
step
):
def
arange
(
start
,
stop
,
step
):
return
np
.
arange
(
return
np
.
arange
(
start
.
item
(),
start
.
item
(),
...
@@ -117,7 +127,7 @@ def numba_funcify_ARange(op, **kwargs):
...
@@ -117,7 +127,7 @@ def numba_funcify_ARange(op, **kwargs):
return
arange
return
arange
@
numba_funcify.register
(
Join
)
@
register_funcify_default_op_cache_key
(
Join
)
def
numba_funcify_Join
(
op
,
**
kwargs
):
def
numba_funcify_Join
(
op
,
**
kwargs
):
@numba_basic.numba_njit
@numba_basic.numba_njit
def
join
(
axis
,
*
tensors
):
def
join
(
axis
,
*
tensors
):
...
@@ -126,7 +136,7 @@ def numba_funcify_Join(op, **kwargs):
...
@@ -126,7 +136,7 @@ def numba_funcify_Join(op, **kwargs):
return
join
return
join
@
numba_funcify.register
(
Split
)
@
register_funcify_default_op_cache_key
(
Split
)
def
numba_funcify_Split
(
op
,
**
kwargs
):
def
numba_funcify_Split
(
op
,
**
kwargs
):
@numba_basic.numba_njit
@numba_basic.numba_njit
def
split
(
tensor
,
axis
,
indices
):
def
split
(
tensor
,
axis
,
indices
):
...
@@ -135,14 +145,14 @@ def numba_funcify_Split(op, **kwargs):
...
@@ -135,14 +145,14 @@ def numba_funcify_Split(op, **kwargs):
return
split
return
split
@
numba_funcify.register
(
ExtractDiag
)
@
register_funcify_default_op_cache_key
(
ExtractDiag
)
def
numba_funcify_ExtractDiag
(
op
,
node
,
**
kwargs
):
def
numba_funcify_ExtractDiag
(
op
,
node
,
**
kwargs
):
view
=
op
.
view
view
=
op
.
view
axis1
,
axis2
,
offset
=
op
.
axis1
,
op
.
axis2
,
op
.
offset
axis1
,
axis2
,
offset
=
op
.
axis1
,
op
.
axis2
,
op
.
offset
if
node
.
inputs
[
0
]
.
type
.
ndim
==
2
:
if
node
.
inputs
[
0
]
.
type
.
ndim
==
2
:
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
extract_diag
(
x
):
def
extract_diag
(
x
):
out
=
np
.
diag
(
x
,
k
=
offset
)
out
=
np
.
diag
(
x
,
k
=
offset
)
...
@@ -157,7 +167,7 @@ def numba_funcify_ExtractDiag(op, node, **kwargs):
...
@@ -157,7 +167,7 @@ def numba_funcify_ExtractDiag(op, node, **kwargs):
leading_dims
=
(
slice
(
None
),)
*
axis1
leading_dims
=
(
slice
(
None
),)
*
axis1
middle_dims
=
(
slice
(
None
),)
*
(
axis2
-
axis1
-
1
)
middle_dims
=
(
slice
(
None
),)
*
(
axis2
-
axis1
-
1
)
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
extract_diag
(
x
):
def
extract_diag
(
x
):
if
offset
>=
0
:
if
offset
>=
0
:
diag_len
=
min
(
x
.
shape
[
axis1
],
max
(
0
,
x
.
shape
[
axis2
]
-
offset
))
diag_len
=
min
(
x
.
shape
[
axis1
],
max
(
0
,
x
.
shape
[
axis2
]
-
offset
))
...
@@ -178,11 +188,11 @@ def numba_funcify_ExtractDiag(op, node, **kwargs):
...
@@ -178,11 +188,11 @@ def numba_funcify_ExtractDiag(op, node, **kwargs):
return
extract_diag
return
extract_diag
@
numba_funcify.register
(
Eye
)
@
register_funcify_default_op_cache_key
(
Eye
)
def
numba_funcify_Eye
(
op
,
**
kwargs
):
def
numba_funcify_Eye
(
op
,
**
kwargs
):
dtype
=
np
.
dtype
(
op
.
dtype
)
dtype
=
np
.
dtype
(
op
.
dtype
)
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
eye
(
N
,
M
,
k
):
def
eye
(
N
,
M
,
k
):
return
np
.
eye
(
return
np
.
eye
(
N
.
item
(),
N
.
item
(),
...
@@ -194,7 +204,7 @@ def numba_funcify_Eye(op, **kwargs):
...
@@ -194,7 +204,7 @@ def numba_funcify_Eye(op, **kwargs):
return
eye
return
eye
@
numba_funcify.register
(
MakeVector
)
@
register_funcify_default_op_cache_key
(
MakeVector
)
def
numba_funcify_MakeVector
(
op
,
node
,
**
kwargs
):
def
numba_funcify_MakeVector
(
op
,
node
,
**
kwargs
):
dtype
=
np
.
dtype
(
op
.
dtype
)
dtype
=
np
.
dtype
(
op
.
dtype
)
...
@@ -215,32 +225,34 @@ def makevector({", ".join(input_names)}):
...
@@ -215,32 +225,34 @@ def makevector({", ".join(input_names)}):
return np.array({create_list_string(input_names)}, dtype=dtype)
return np.array({create_list_string(input_names)}, dtype=dtype)
"""
"""
makevector_fn
=
compile_function_src
(
makevector_fn
=
compile_numba_function_src
(
makevector_def_src
,
"makevector"
,
{
**
globals
(),
**
global_env
}
makevector_def_src
,
"makevector"
,
{
**
globals
(),
**
global_env
},
)
)
return
numba_basic
.
numba_njit
(
makevector_fn
)
return
numba_basic
.
numba_njit
(
makevector_fn
)
@
numba_funcify.register
(
TensorFromScalar
)
@
register_funcify_default_op_cache_key
(
TensorFromScalar
)
def
numba_funcify_TensorFromScalar
(
op
,
**
kwargs
):
def
numba_funcify_TensorFromScalar
(
op
,
**
kwargs
):
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
tensor_from_scalar
(
x
):
def
tensor_from_scalar
(
x
):
return
np
.
array
(
x
)
return
np
.
array
(
x
)
return
tensor_from_scalar
return
tensor_from_scalar
@
numba_funcify.register
(
ScalarFromTensor
)
@
register_funcify_default_op_cache_key
(
ScalarFromTensor
)
def
numba_funcify_ScalarFromTensor
(
op
,
**
kwargs
):
def
numba_funcify_ScalarFromTensor
(
op
,
**
kwargs
):
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
def
scalar_from_tensor
(
x
):
def
scalar_from_tensor
(
x
):
return
x
.
item
()
return
x
.
item
()
return
scalar_from_tensor
return
scalar_from_tensor
@
numba_funcify.register
(
Nonzero
)
@
register_funcify_default_op_cache_key
(
Nonzero
)
def
numba_funcify_Nonzero
(
op
,
node
,
**
kwargs
):
def
numba_funcify_Nonzero
(
op
,
node
,
**
kwargs
):
@numba_basic.numba_njit
@numba_basic.numba_njit
def
nonzero
(
a
):
def
nonzero
(
a
):
...
...
pytensor/link/numba/dispatch/vectorize_codegen.py
浏览文件 @
6ac5ab28
...
@@ -4,7 +4,7 @@ import base64
...
@@ -4,7 +4,7 @@ import base64
import
pickle
import
pickle
from
collections.abc
import
Callable
,
Sequence
from
collections.abc
import
Callable
,
Sequence
from
textwrap
import
indent
from
textwrap
import
indent
from
typing
import
Any
,
cast
from
typing
import
Any
import
numba
import
numba
import
numpy
as
np
import
numpy
as
np
...
@@ -15,8 +15,8 @@ from numba.core.base import BaseContext
...
@@ -15,8 +15,8 @@ from numba.core.base import BaseContext
from
numba.core.types.misc
import
NoneType
from
numba.core.types.misc
import
NoneType
from
numba.np
import
arrayobj
from
numba.np
import
arrayobj
from
pytensor.link.numba.cache
import
compile_numba_function_src
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.utils
import
compile_function_src
def
encode_literals
(
literals
:
Sequence
)
->
str
:
def
encode_literals
(
literals
:
Sequence
)
->
str
:
...
@@ -52,10 +52,13 @@ def store_core_outputs({inp_signature}, {out_signature}):
...
@@ -52,10 +52,13 @@ def store_core_outputs({inp_signature}, {out_signature}):
{indent(store_outputs, " " * 4)}
{indent(store_outputs, " " * 4)}
"""
"""
global_env
=
{
"core_op_fn"
:
core_op_fn
}
global_env
=
{
"core_op_fn"
:
core_op_fn
}
func
=
compile_function_src
(
func_src
,
"store_core_outputs"
,
{
**
globals
(),
**
global_env
}
func
=
compile_numba_function_src
(
func_src
,
"store_core_outputs"
,
{
**
globals
(),
**
global_env
},
)
)
return
cast
(
Callable
,
numba_basic
.
numba_njit
(
func
)
)
return
numba_basic
.
numba_njit
(
func
)
_jit_options
=
{
_jit_options
=
{
...
@@ -74,7 +77,7 @@ _jit_options = {
...
@@ -74,7 +77,7 @@ _jit_options = {
@numba.extending.intrinsic
(
jit_options
=
_jit_options
,
prefer_literal
=
True
)
@numba.extending.intrinsic
(
jit_options
=
_jit_options
,
prefer_literal
=
True
)
def
_vectorized
(
def
_vectorized
(
typingctx
,
typingctx
,
scalar
_func
,
core
_func
,
input_bc_patterns
,
input_bc_patterns
,
output_bc_patterns
,
output_bc_patterns
,
output_dtypes
,
output_dtypes
,
...
@@ -85,7 +88,7 @@ def _vectorized(
...
@@ -85,7 +88,7 @@ def _vectorized(
size_type
,
size_type
,
):
):
arg_types
=
[
arg_types
=
[
scalar
_func
,
core
_func
,
input_bc_patterns
,
input_bc_patterns
,
output_bc_patterns
,
output_bc_patterns
,
output_dtypes
,
output_dtypes
,
...
@@ -173,16 +176,6 @@ def _vectorized(
...
@@ -173,16 +176,6 @@ def _vectorized(
)
)
out_types
[
output_idx
]
=
output_type
out_types
[
output_idx
]
=
output_type
core_signature
=
typingctx
.
resolve_function_type
(
scalar_func
,
[
*
constant_inputs_types
,
*
core_input_types
,
*
core_out_types
,
],
{},
)
ret_type
=
types
.
Tuple
(
out_types
)
ret_type
=
types
.
Tuple
(
out_types
)
if
len
(
output_dtypes
)
==
1
:
if
len
(
output_dtypes
)
==
1
:
...
@@ -239,11 +232,21 @@ def _vectorized(
...
@@ -239,11 +232,21 @@ def _vectorized(
output_core_shapes
,
output_core_shapes
,
)
)
core_signature
=
typingctx
.
resolve_function_type
(
core_func
,
[
*
constant_inputs_types
,
*
core_input_types
,
*
core_out_types
,
],
{},
)
make_loop_call
(
make_loop_call
(
typingctx
,
typingctx
,
ctx
,
ctx
,
builder
,
builder
,
scalar
_func
,
core
_func
,
core_signature
,
core_signature
,
iter_shape
,
iter_shape
,
constant_inputs
,
constant_inputs
,
...
@@ -416,8 +419,8 @@ def make_loop_call(
...
@@ -416,8 +419,8 @@ def make_loop_call(
typingctx
,
typingctx
,
context
:
numba
.
core
.
base
.
BaseContext
,
context
:
numba
.
core
.
base
.
BaseContext
,
builder
:
ir
.
IRBuilder
,
builder
:
ir
.
IRBuilder
,
scalar
_func
:
Any
,
core
_func
:
Any
,
scalar
_signature
:
types
.
FunctionType
,
core
_signature
:
types
.
FunctionType
,
iter_shape
:
tuple
[
ir
.
Instruction
,
...
],
iter_shape
:
tuple
[
ir
.
Instruction
,
...
],
constant_inputs
:
tuple
[
ir
.
Instruction
,
...
],
constant_inputs
:
tuple
[
ir
.
Instruction
,
...
],
inputs
:
tuple
[
ir
.
Instruction
,
...
],
inputs
:
tuple
[
ir
.
Instruction
,
...
],
...
@@ -557,10 +560,10 @@ def make_loop_call(
...
@@ -557,10 +560,10 @@ def make_loop_call(
val
=
core_array
.
_getvalue
()
val
=
core_array
.
_getvalue
()
output_slices
.
append
(
val
)
output_slices
.
append
(
val
)
inner_codegen
=
context
.
get_function
(
scalar_func
,
scalar
_signature
)
inner_codegen
=
context
.
get_function
(
core_func
,
core
_signature
)
if
isinstance
(
scalar
_signature
.
args
[
0
],
types
.
StarArgTuple
|
types
.
StarArgUniTuple
):
if
isinstance
(
core
_signature
.
args
[
0
],
types
.
StarArgTuple
|
types
.
StarArgUniTuple
):
input_vals
=
[
context
.
make_tuple
(
builder
,
scalar
_signature
.
args
[
0
],
input_vals
)]
input_vals
=
[
context
.
make_tuple
(
builder
,
core
_signature
.
args
[
0
],
input_vals
)]
inner_codegen
(
builder
,
[
*
constant_inputs
,
*
input_vals
,
*
output_slices
])
inner_codegen
(
builder
,
[
*
constant_inputs
,
*
input_vals
,
*
output_slices
])
...
...
tests/link/numba/signal/test_conv.py
浏览文件 @
6ac5ab28
...
@@ -13,21 +13,32 @@ from tests.tensor.signal.test_conv import convolve1d_grad_benchmarker
...
@@ -13,21 +13,32 @@ from tests.tensor.signal.test_conv import convolve1d_grad_benchmarker
pytestmark
=
pytest
.
mark
.
filterwarnings
(
"error"
)
pytestmark
=
pytest
.
mark
.
filterwarnings
(
"error"
)
@pytest.mark.parametrize
(
"bcast_order"
,
(
1
,
0
))
@pytest.mark.parametrize
(
"mode"
,
[
"full"
,
"valid"
,
"same"
])
@pytest.mark.parametrize
(
"mode"
,
[
"full"
,
"valid"
,
"same"
])
@pytest.mark.parametrize
(
"x_smaller"
,
(
False
,
True
))
def
test_convolve1d
(
mode
,
bcast_order
):
def
test_convolve1d
(
x_smaller
,
mode
):
x
=
dmatrix
(
"x"
)
x
=
dmatrix
(
"x"
)
y
=
dmatrix
(
"y"
)
y
=
dmatrix
(
"y"
)
if
x_smaller
:
# Testing two orders because this revealed a bug in the past
out
=
convolve1d
(
x
[
None
],
y
[:,
None
],
mode
=
mode
)
if
bcast_order
==
0
:
out
=
convolve1d
(
x
[:,
None
],
y
[
None
,
:],
mode
=
mode
)
else
:
else
:
out
=
convolve1d
(
y
[:,
None
],
x
[
None
],
mode
=
mode
)
out
=
convolve1d
(
x
[
None
],
y
[:,
None
],
mode
=
mode
)
rng
=
np
.
random
.
default_rng
()
rng
=
np
.
random
.
default_rng
()
test_x
=
rng
.
normal
(
size
=
(
3
,
5
))
test_x
=
rng
.
normal
(
size
=
(
3
,
5
))
test_y
=
rng
.
normal
(
size
=
(
7
,
11
))
test_y
=
rng
.
normal
(
size
=
(
7
,
11
))
# Blockwise dispatch for numba can't be run on object mode
# Blockwise dispatch for numba can't be run on object mode
compare_numba_and_py
([
x
,
y
],
out
,
[
test_x
,
test_y
],
eval_obj_mode
=
False
)
numba_fn
,
res
=
compare_numba_and_py
(
[
x
,
y
],
out
,
[
test_x
,
test_y
],
eval_obj_mode
=
False
)
# Try other order of inputs, as implementation depends on it
# Result should be the same, just in different order, except for 'same' mode
if
mode
!=
"same"
:
np
.
testing
.
assert_allclose
(
np
.
swapaxes
(
numba_fn
(
test_y
,
test_x
),
0
,
1
),
res
,
)
@pytest.mark.parametrize
(
"mode"
,
(
"full"
,
"valid"
),
ids
=
lambda
x
:
f
"mode={x}"
)
@pytest.mark.parametrize
(
"mode"
,
(
"full"
,
"valid"
),
ids
=
lambda
x
:
f
"mode={x}"
)
...
...
tests/link/numba/test_basic.py
浏览文件 @
6ac5ab28
...
@@ -402,7 +402,9 @@ def test_config_options_fastmath():
...
@@ -402,7 +402,9 @@ def test_config_options_fastmath():
with
config
.
change_flags
(
numba__fastmath
=
True
):
with
config
.
change_flags
(
numba__fastmath
=
True
):
pytensor_numba_fn
=
function
([
x
],
pt
.
sum
(
x
),
mode
=
numba_mode
)
pytensor_numba_fn
=
function
([
x
],
pt
.
sum
(
x
),
mode
=
numba_mode
)
numba_sum_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
"impl_sum"
]
numba_sum_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
"jitable_func"
]
.
py_func
.
__globals__
[
"impl_sum"
]
assert
numba_sum_fn
.
targetoptions
[
"fastmath"
]
==
{
assert
numba_sum_fn
.
targetoptions
[
"fastmath"
]
==
{
"afn"
,
"afn"
,
"arcp"
,
"arcp"
,
...
@@ -413,7 +415,9 @@ def test_config_options_fastmath():
...
@@ -413,7 +415,9 @@ def test_config_options_fastmath():
with
config
.
change_flags
(
numba__fastmath
=
False
):
with
config
.
change_flags
(
numba__fastmath
=
False
):
pytensor_numba_fn
=
function
([
x
],
pt
.
sum
(
x
),
mode
=
numba_mode
)
pytensor_numba_fn
=
function
([
x
],
pt
.
sum
(
x
),
mode
=
numba_mode
)
numba_sum_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
"impl_sum"
]
numba_sum_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
"jitable_func"
]
.
py_func
.
__globals__
[
"impl_sum"
]
assert
numba_sum_fn
.
targetoptions
[
"fastmath"
]
is
False
assert
numba_sum_fn
.
targetoptions
[
"fastmath"
]
is
False
...
@@ -422,9 +426,10 @@ def test_config_options_cached():
...
@@ -422,9 +426,10 @@ def test_config_options_cached():
with
config
.
change_flags
(
numba__cache
=
True
):
with
config
.
change_flags
(
numba__cache
=
True
):
pytensor_numba_fn
=
function
([
x
],
pt
.
sum
(
x
),
mode
=
numba_mode
)
pytensor_numba_fn
=
function
([
x
],
pt
.
sum
(
x
),
mode
=
numba_mode
)
numba_sum_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
"impl_sum"
]
numba_sum_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
# Caching is disabled unless the dispatched function returns an explicit cache key
"jitable_func"
assert
isinstance
(
numba_sum_fn
.
_cache
,
numba
.
core
.
caching
.
NullCache
)
]
.
py_func
.
__globals__
[
"impl_sum"
]
assert
not
isinstance
(
numba_sum_fn
.
_cache
,
numba
.
core
.
caching
.
NullCache
)
with
config
.
change_flags
(
numba__cache
=
False
):
with
config
.
change_flags
(
numba__cache
=
False
):
pytensor_numba_fn
=
function
([
x
],
pt
.
sum
(
x
),
mode
=
numba_mode
)
pytensor_numba_fn
=
function
([
x
],
pt
.
sum
(
x
),
mode
=
numba_mode
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论