Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
98be9c5f
提交
98be9c5f
authored
7月 27, 2022
作者:
Adrian Seyboldt
提交者:
Adrian Seyboldt
12月 02, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Replace numba_scipy
上级
1c507090
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
370 行增加
和
73 行删除
+370
-73
configdefaults.py
pytensor/configdefaults.py
+0
-6
cython_support.py
pytensor/link/numba/dispatch/cython_support.py
+211
-0
scalar.py
pytensor/link/numba/dispatch/scalar.py
+67
-67
test_cython_support.py
tests/link/numba/test_cython_support.py
+92
-0
没有找到文件。
pytensor/configdefaults.py
浏览文件 @
98be9c5f
...
...
@@ -1252,12 +1252,6 @@ def add_numba_configvars():
BoolParam
(
True
),
in_c_key
=
False
,
)
config
.
add
(
"numba_scipy"
,
(
"Enable usage of the numba_scipy package for special functions"
,),
BoolParam
(
True
),
in_c_key
=
False
,
)
def
_default_compiledirname
():
...
...
pytensor/link/numba/dispatch/cython_support.py
0 → 100644
浏览文件 @
98be9c5f
import
ctypes
import
importlib
import
re
from
dataclasses
import
dataclass
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Mapping
,
Optional
,
Tuple
,
cast
import
numba
import
numpy
as
np
from
numpy.typing
import
DTypeLike
from
scipy
import
LowLevelCallable
_C_TO_NUMPY
:
Dict
[
str
,
DTypeLike
]
=
{
"bool"
:
np
.
bool_
,
"signed char"
:
np
.
byte
,
"unsigned char"
:
np
.
ubyte
,
"short"
:
np
.
short
,
"unsigned short"
:
np
.
ushort
,
"int"
:
np
.
intc
,
"unsigned int"
:
np
.
uintc
,
"long"
:
np
.
int_
,
"unsigned long"
:
np
.
uint
,
"long long"
:
np
.
longlong
,
"float"
:
np
.
single
,
"double"
:
np
.
double
,
"long double"
:
np
.
longdouble
,
"float complex"
:
np
.
csingle
,
"double complex"
:
np
.
cdouble
,
}
@dataclass
class
Signature
:
res_dtype
:
DTypeLike
res_c_type
:
str
arg_dtypes
:
List
[
DTypeLike
]
arg_c_types
:
List
[
str
]
arg_names
:
List
[
Optional
[
str
]]
@property
def
arg_numba_types
(
self
)
->
List
[
DTypeLike
]:
return
[
numba
.
from_dtype
(
dtype
)
for
dtype
in
self
.
arg_dtypes
]
def
can_cast_args
(
self
,
args
:
List
[
DTypeLike
])
->
bool
:
ok
=
True
count
=
0
for
name
,
dtype
in
zip
(
self
.
arg_names
,
self
.
arg_dtypes
):
if
name
==
"__pyx_skip_dispatch"
:
continue
if
len
(
args
)
<=
count
:
raise
ValueError
(
"Incorrect number of arguments"
)
ok
&=
np
.
can_cast
(
args
[
count
],
dtype
)
count
+=
1
if
count
!=
len
(
args
):
return
False
return
ok
def
provides
(
self
,
restype
:
DTypeLike
,
arg_dtypes
:
List
[
DTypeLike
])
->
bool
:
args_ok
=
self
.
can_cast_args
(
arg_dtypes
)
if
np
.
issubdtype
(
restype
,
np
.
inexact
):
result_ok
=
np
.
can_cast
(
self
.
res_dtype
,
restype
,
casting
=
"same_kind"
)
# We do not want to provide less accuracy than advertised
result_ok
&=
np
.
dtype
(
self
.
res_dtype
)
.
itemsize
>=
np
.
dtype
(
restype
)
.
itemsize
else
:
result_ok
=
np
.
can_cast
(
self
.
res_dtype
,
restype
)
return
args_ok
and
result_ok
@staticmethod
def
from_c_types
(
signature
:
bytes
)
->
"Signature"
:
# Match strings like "double(int, double)"
# and extract the return type and the joined arguments
expr
=
re
.
compile
(
rb
"
\
s*(?P<restype>[
\
w ]*
\
w+)
\
s*
\
((?P<args>[
\
w
\
s,]*)
\
)"
)
re_match
=
re
.
fullmatch
(
expr
,
signature
)
if
re_match
is
None
:
raise
ValueError
(
f
"Invalid signature: {signature.decode()}"
)
groups
=
re_match
.
groupdict
()
res_c_type
=
groups
[
"restype"
]
.
decode
()
res_dtype
:
DTypeLike
=
_C_TO_NUMPY
[
res_c_type
]
raw_args
=
groups
[
"args"
]
decl_expr
=
re
.
compile
(
rb
"
\
s*(?P<type>((long )|(unsigned )|(signed )|(double )|)"
rb
"((double)|(float)|(int)|(short)|(char)|(long)|(bool)|(complex)))"
rb
"(
\
s(?P<name>[
\
w_]*))?
\
s*"
)
arg_dtypes
=
[]
arg_names
:
List
[
Optional
[
str
]]
=
[]
arg_c_types
=
[]
for
raw_arg
in
raw_args
.
split
(
b
","
):
re_match
=
re
.
fullmatch
(
decl_expr
,
raw_arg
)
if
re_match
is
None
:
raise
ValueError
(
f
"Invalid signature: {signature.decode()}"
)
groups
=
re_match
.
groupdict
()
arg_c_type
=
groups
[
"type"
]
.
decode
()
try
:
arg_dtype
=
_C_TO_NUMPY
[
arg_c_type
]
except
KeyError
:
raise
ValueError
(
f
"Unknown C type: {arg_c_type}"
)
arg_c_types
.
append
(
arg_c_type
)
arg_dtypes
.
append
(
arg_dtype
)
name
=
groups
[
"name"
]
if
not
name
:
arg_names
.
append
(
None
)
else
:
arg_names
.
append
(
name
.
decode
())
return
Signature
(
res_dtype
,
res_c_type
,
arg_dtypes
,
arg_c_types
,
arg_names
)
def
_available_impls
(
func
:
Callable
)
->
List
[
Tuple
[
Signature
,
Any
]]:
"""Find all available implementations for a fused cython function."""
impls
=
[]
mod
=
importlib
.
import_module
(
func
.
__module__
)
signatures
=
getattr
(
func
,
"__signatures__"
,
None
)
if
signatures
is
not
None
:
# Cython function with __signatures__ should be fused and thus
# indexable
func_map
=
cast
(
Mapping
,
func
)
candidates
=
[
func_map
[
key
]
for
key
in
signatures
]
else
:
candidates
=
[
func
]
for
candidate
in
candidates
:
name
=
candidate
.
__name__
capsule
=
mod
.
__pyx_capi__
[
name
]
llc
=
LowLevelCallable
(
capsule
)
try
:
signature
=
Signature
.
from_c_types
(
llc
.
signature
.
encode
())
except
KeyError
:
continue
impls
.
append
((
signature
,
capsule
))
return
impls
class
_CythonWrapper
(
numba
.
types
.
WrapperAddressProtocol
):
def
__init__
(
self
,
pyfunc
,
signature
,
capsule
):
self
.
_keep_alive
=
capsule
get_name
=
ctypes
.
pythonapi
.
PyCapsule_GetName
get_name
.
restype
=
ctypes
.
c_char_p
get_name
.
argtypes
=
(
ctypes
.
py_object
,)
raw_signature
=
get_name
(
capsule
)
get_pointer
=
ctypes
.
pythonapi
.
PyCapsule_GetPointer
get_pointer
.
restype
=
ctypes
.
c_void_p
get_pointer
.
argtypes
=
(
ctypes
.
py_object
,
ctypes
.
c_char_p
)
self
.
_func_ptr
=
get_pointer
(
capsule
,
raw_signature
)
self
.
_signature
=
signature
self
.
_pyfunc
=
pyfunc
def
signature
(
self
):
return
numba
.
from_dtype
(
self
.
_signature
.
res_dtype
)(
*
self
.
_signature
.
arg_numba_types
)
def
__wrapper_address__
(
self
):
return
self
.
_func_ptr
def
__call__
(
self
,
*
args
,
**
kwargs
):
args
=
[
dtype
(
arg
)
for
arg
,
dtype
in
zip
(
args
,
self
.
_signature
.
arg_dtypes
)]
if
self
.
has_pyx_skip_dispatch
():
output
=
self
.
_pyfunc
(
*
args
[:
-
1
],
**
kwargs
)
else
:
output
=
self
.
_pyfunc
(
*
args
,
**
kwargs
)
return
self
.
_signature
.
res_dtype
(
output
)
def
has_pyx_skip_dispatch
(
self
):
if
not
self
.
_signature
.
arg_names
:
return
False
if
any
(
name
==
"__pyx_skip_dispatch"
for
name
in
self
.
_signature
.
arg_names
[:
-
1
]
):
raise
ValueError
(
"skip_dispatch parameter must be last"
)
return
self
.
_signature
.
arg_names
[
-
1
]
==
"__pyx_skip_dispatch"
def
numpy_arg_dtypes
(
self
):
return
self
.
_signature
.
arg_dtypes
def
numpy_output_dtype
(
self
):
return
self
.
_signature
.
res_dtype
def
wrap_cython_function
(
func
,
restype
,
arg_types
):
impls
=
_available_impls
(
func
)
compatible
=
[]
for
sig
,
capsule
in
impls
:
if
sig
.
provides
(
restype
,
arg_types
):
compatible
.
append
((
sig
,
capsule
))
def
sort_key
(
args
):
sig
,
_
=
args
# Prefer functions with less inputs bytes
argsize
=
sum
(
np
.
dtype
(
dtype
)
.
itemsize
for
dtype
in
sig
.
arg_dtypes
)
# Prefer functions with more exact (integer) arguments
num_inexact
=
sum
(
np
.
issubdtype
(
dtype
,
np
.
inexact
)
for
dtype
in
sig
.
arg_dtypes
)
return
(
num_inexact
,
argsize
)
compatible
.
sort
(
key
=
sort_key
)
if
not
compatible
:
raise
NotImplementedError
(
f
"Could not find a compatible impl of {func}"
)
sig
,
capsule
=
compatible
[
0
]
return
_CythonWrapper
(
func
,
sig
,
capsule
)
pytensor/link/numba/dispatch/scalar.py
浏览文件 @
98be9c5f
import
math
import
warnings
from
functools
import
reduce
from
typing
import
List
import
numpy
as
np
import
scipy
import
scipy.special
from
pytensor
import
config
from
pytensor.compile.ops
import
ViewOp
...
...
@@ -16,6 +12,7 @@ from pytensor.link.numba.dispatch.basic import (
generate_fallback_impl
,
numba_funcify
,
)
from
pytensor.link.numba.dispatch.cython_support
import
wrap_cython_function
from
pytensor.link.utils
import
(
compile_function_src
,
get_name_for_object
,
...
...
@@ -41,86 +38,83 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
# TODO: Do we need to cache these functions so that we don't end up
# compiling the same Numba function over and over again?
scalar_func_name
=
op
.
nfunc_spec
[
0
]
scalar_func
=
None
if
scalar_func_name
.
startswith
(
"scipy."
):
func_package
=
scipy
scalar_func_name
=
scalar_func_name
.
split
(
"."
,
1
)[
-
1
]
use_numba_scipy
=
config
.
numba_scipy
if
use_numba_scipy
:
try
:
import
numba_scipy
# noqa: F401
except
ImportError
:
use_numba_scipy
=
False
if
not
use_numba_scipy
:
warnings
.
warn
(
"Native numba versions of scipy functions might be "
"avalable if numba-scipy is installed."
,
UserWarning
,
scalar_func_path
=
op
.
nfunc_spec
[
0
]
scalar_func_numba
=
None
*
module_path
,
scalar_func_name
=
scalar_func_path
.
split
(
"."
)
if
not
module_path
:
# Assume it is numpy, and numba has an implementation
scalar_func_numba
=
getattr
(
np
,
scalar_func_name
)
input_dtypes
=
[
np
.
dtype
(
input
.
type
.
dtype
)
for
input
in
node
.
inputs
]
output_dtypes
=
[
np
.
dtype
(
output
.
type
.
dtype
)
for
output
in
node
.
outputs
]
if
len
(
output_dtypes
)
!=
1
:
raise
ValueError
(
"ScalarOps with more than one output are not supported"
)
output_dtype
=
output_dtypes
[
0
]
input_inner_dtypes
=
None
output_inner_dtype
=
None
# Cython functions might have an additonal argument
has_pyx_skip_dispatch
=
False
if
scalar_func_path
.
startswith
(
"scipy.special"
):
import
scipy.special.cython_special
cython_func
=
getattr
(
scipy
.
special
.
cython_special
,
scalar_func_name
,
None
)
if
cython_func
is
not
None
:
# try:
scalar_func_numba
=
wrap_cython_function
(
cython_func
,
output_dtype
,
input_dtypes
)
scalar_func
=
generate_fallback_impl
(
op
,
node
,
**
kwargs
)
else
:
func_package
=
np
has_pyx_skip_dispatch
=
scalar_func_numba
.
has_pyx_skip_dispatch
input_inner_dtypes
=
scalar_func_numba
.
numpy_arg_dtypes
()
output_inner_dtype
=
scalar_func_numba
.
numpy_output_dtype
()
# except NotImplementedError:
# pass
if
scalar_func
is
not
None
:
pass
elif
"."
in
scalar_func_name
:
scalar_func
=
reduce
(
getattr
,
[
scipy
]
+
scalar_func_name
.
split
(
"."
))
else
:
scalar_func
=
getattr
(
func_package
,
scalar_func_name
)
if
scalar_func_numba
is
None
:
scalar_func_numba
=
generate_fallback_impl
(
op
,
node
,
**
kwargs
)
scalar_op_fn_name
=
get_name_for_object
(
scalar_func
)
scalar_op_fn_name
=
get_name_for_object
(
scalar_func
_numba
)
unique_names
=
unique_name_generator
(
[
scalar_op_fn_name
,
"scalar_func"
],
suffix_sep
=
"_"
[
scalar_op_fn_name
,
"scalar_func
_numba
"
],
suffix_sep
=
"_"
)
global_env
=
{
"scalar_func
"
:
scalar_func
}
global_env
=
{
"scalar_func
_numba"
:
scalar_func_numba
}
input_tmp_dtypes
=
None
if
func_package
==
scipy
and
hasattr
(
scalar_func
,
"types"
):
# The `numba-scipy` bindings don't provide implementations for all
# inputs types, so we need to convert the inputs to floats and back.
inp_dtype_kinds
=
tuple
(
np
.
dtype
(
inp
.
type
.
dtype
)
.
kind
for
inp
in
node
.
inputs
)
accepted_inp_kinds
=
tuple
(
sig_type
.
split
(
"->"
)[
0
]
for
sig_type
in
scalar_func
.
types
)
if
not
any
(
all
(
dk
==
ik
for
dk
,
ik
in
zip
(
inp_dtype_kinds
,
ok_kinds
))
for
ok_kinds
in
accepted_inp_kinds
):
# They're usually ordered from lower-to-higher precision, so
# we pick the last acceptable input types
#
# XXX: We should pick the first acceptable float/int types in
# reverse, excluding all the incompatible ones (e.g. `"0"`).
# The assumption is that this is only used by `numba-scipy`-exposed
# functions, although it's possible for this to be triggered by
# something else from the `scipy` package
input_tmp_dtypes
=
tuple
(
np
.
dtype
(
k
)
for
k
in
accepted_inp_kinds
[
-
1
])
if
input_tmp_dtypes
is
None
:
if
input_inner_dtypes
is
None
and
output_inner_dtype
is
None
:
unique_names
=
unique_name_generator
(
[
scalar_op_fn_name
,
"scalar_func"
],
suffix_sep
=
"_"
[
scalar_op_fn_name
,
"scalar_func
_numba
"
],
suffix_sep
=
"_"
)
input_names
=
", "
.
join
(
[
unique_names
(
v
,
force_unique
=
True
)
for
v
in
node
.
inputs
]
)
scalar_op_src
=
f
"""
if
not
has_pyx_skip_dispatch
:
scalar_op_src
=
f
"""
def {scalar_op_fn_name}({input_names}):
return scalar_func_numba({input_names})
"""
else
:
scalar_op_src
=
f
"""
def {scalar_op_fn_name}({input_names}):
return scalar_func({input_names})
"""
return scalar_func_numba({input_names}, np.intc(1))
"""
else
:
global_env
[
"direct_cast"
]
=
numba_basic
.
direct_cast
global_env
[
"output_dtype"
]
=
np
.
dtype
(
node
.
outputs
[
0
]
.
type
.
dtype
)
global_env
[
"output_dtype"
]
=
np
.
dtype
(
output_inner_
dtype
)
input_tmp_dtype_names
=
{
f
"inp_tmp_dtype_{i}"
:
i_dtype
for
i
,
i_dtype
in
enumerate
(
input_tmp_dtypes
)
f
"inp_tmp_dtype_{i}"
:
i_dtype
for
i
,
i_dtype
in
enumerate
(
input_inner_dtypes
)
}
global_env
.
update
(
input_tmp_dtype_names
)
unique_names
=
unique_name_generator
(
[
scalar_op_fn_name
,
"scalar_func"
]
+
list
(
global_env
.
keys
()),
suffix_sep
=
"_"
[
scalar_op_fn_name
,
"scalar_func_numba"
]
+
list
(
global_env
.
keys
()),
suffix_sep
=
"_"
,
)
input_names
=
[
unique_names
(
v
,
force_unique
=
True
)
for
v
in
node
.
inputs
]
...
...
@@ -132,10 +126,16 @@ def {scalar_op_fn_name}({input_names}):
)
]
)
scalar_op_src
=
f
"""
if
not
has_pyx_skip_dispatch
:
scalar_op_src
=
f
"""
def {scalar_op_fn_name}({', '.join(input_names)}):
return direct_cast(scalar_func_numba({converted_call_args}), output_dtype)
"""
else
:
scalar_op_src
=
f
"""
def {scalar_op_fn_name}({', '.join(input_names)}):
return direct_cast(scalar_func
({converted_call_args}
), output_dtype)
"""
return direct_cast(scalar_func
_numba({converted_call_args}, np.intc(1)
), output_dtype)
"""
scalar_op_fn
=
compile_function_src
(
scalar_op_src
,
scalar_op_fn_name
,
{
**
globals
(),
**
global_env
}
...
...
tests/link/numba/test_cython_support.py
0 → 100644
浏览文件 @
98be9c5f
import
numpy
as
np
import
pytest
import
scipy.special.cython_special
from
numba.types
import
float32
,
float64
,
int32
,
int64
from
aesara.link.numba.dispatch.cython_support
import
Signature
,
wrap_cython_function
@pytest.mark.parametrize
(
"sig, expected_result, expected_args"
,
[
(
b
"double(double)"
,
np
.
float64
,
[
np
.
float64
]),
(
b
"float(unsigned int)"
,
np
.
float32
,
[
np
.
uintc
]),
(
b
"unsigned char(unsigned short foo)"
,
np
.
ubyte
,
[
np
.
ushort
]),
(
b
"unsigned char(unsigned short foo, double bar)"
,
np
.
ubyte
,
[
np
.
ushort
,
np
.
float64
],
),
],
)
def
test_parse_signature
(
sig
,
expected_result
,
expected_args
):
actual
=
Signature
.
from_c_types
(
sig
)
assert
actual
.
res_dtype
==
expected_result
assert
actual
.
arg_dtypes
==
expected_args
@pytest.mark.parametrize
(
"have, want, should_provide"
,
[
(
b
"double(int)"
,
b
"float(int)"
,
True
),
(
b
"float(int)"
,
b
"double(int)"
,
False
),
(
b
"double(unsigned short)"
,
b
"double(unsigned char)"
,
True
),
(
b
"double(unsigned char)"
,
b
"double(short)"
,
False
),
(
b
"short(double)"
,
b
"int(double)"
,
True
),
(
b
"int(double)"
,
b
"short(double)"
,
False
),
(
b
"float(double, int)"
,
b
"float(double, short)"
,
True
),
],
)
def
test_signature_provides
(
have
,
want
,
should_provide
):
have
=
Signature
.
from_c_types
(
have
)
want
=
Signature
.
from_c_types
(
want
)
provides
=
have
.
provides
(
want
.
res_dtype
,
want
.
arg_dtypes
)
assert
provides
==
should_provide
@pytest.mark.parametrize
(
"func, output, inputs, expected"
,
[
(
scipy
.
special
.
cython_special
.
agm
,
np
.
float64
,
[
np
.
float64
,
np
.
float64
],
float64
(
float64
,
float64
,
int32
),
),
(
scipy
.
special
.
cython_special
.
erfc
,
np
.
float64
,
[
np
.
float64
],
float64
(
float64
,
int32
),
),
(
scipy
.
special
.
cython_special
.
expit
,
np
.
float32
,
[
np
.
float32
],
float32
(
float32
,
int32
),
),
(
scipy
.
special
.
cython_special
.
expit
,
np
.
float64
,
[
np
.
float64
],
float64
(
float64
,
int32
),
),
(
# expn doesn't have a float32 implementation
scipy
.
special
.
cython_special
.
expn
,
np
.
float32
,
[
np
.
float32
,
np
.
float32
],
float64
(
float64
,
float64
,
int32
),
),
(
# We choose the integer implementation if possible
scipy
.
special
.
cython_special
.
expn
,
np
.
float32
,
[
np
.
int64
,
np
.
float32
],
float64
(
int64
,
float64
,
int32
),
),
],
)
def
test_choose_signature
(
func
,
output
,
inputs
,
expected
):
wrapper
=
wrap_cython_function
(
func
,
output
,
inputs
)
assert
wrapper
.
signature
()
==
expected
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论