Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d4696e6b
提交
d4696e6b
authored
9月 17, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
9月 17, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Split aesara.link.numba.dispatch into distinct modules
上级
0ae63d1b
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
11 个修改的文件
包含
1192 行增加
和
12 行删除
+1192
-12
dispatch.py
aesara/link/numba/dispatch.py
+0
-0
__init__.py
aesara/link/numba/dispatch/__init__.py
+12
-0
basic.py
aesara/link/numba/dispatch/basic.py
+0
-0
elemwise.py
aesara/link/numba/dispatch/elemwise.py
+0
-0
extra_ops.py
aesara/link/numba/dispatch/extra_ops.py
+354
-0
nlinalg.py
aesara/link/numba/dispatch/nlinalg.py
+195
-0
random.py
aesara/link/numba/dispatch/random.py
+235
-0
scalar.py
aesara/link/numba/dispatch/scalar.py
+172
-0
tensor_basic.py
aesara/link/numba/dispatch/tensor_basic.py
+211
-0
linker.py
aesara/link/numba/linker.py
+2
-2
test_numba.py
tests/link/test_numba.py
+11
-10
没有找到文件。
aesara/link/numba/dispatch.py
deleted
100644 → 0
浏览文件 @
0ae63d1b
差异被折叠。
点击展开。
aesara/link/numba/dispatch/__init__.py
0 → 100644
浏览文件 @
d4696e6b
# isort: off
from
aesara.link.numba.dispatch.basic
import
numba_funcify
,
numba_typify
# Load dispatch specializations
import
aesara.link.numba.dispatch.scalar
import
aesara.link.numba.dispatch.tensor_basic
import
aesara.link.numba.dispatch.extra_ops
import
aesara.link.numba.dispatch.nlinalg
import
aesara.link.numba.dispatch.random
import
aesara.link.numba.dispatch.elemwise
# isort: on
aesara/link/numba/dispatch/basic.py
0 → 100644
浏览文件 @
d4696e6b
差异被折叠。
点击展开。
aesara/link/numba/dispatch/elemwise.py
0 → 100644
浏览文件 @
d4696e6b
差异被折叠。
点击展开。
aesara/link/numba/dispatch/extra_ops.py
0 → 100644
浏览文件 @
d4696e6b
import
warnings
import
numba
import
numpy
as
np
from
numpy.core.multiarray
import
normalize_axis_index
from
aesara.link.numba.dispatch
import
basic
as
numba_basic
from
aesara.link.numba.dispatch.basic
import
get_numba_type
,
numba_funcify
from
aesara.tensor.extra_ops
import
(
Bartlett
,
CumOp
,
DiffOp
,
FillDiagonal
,
FillDiagonalOffset
,
RavelMultiIndex
,
Repeat
,
SearchsortedOp
,
Unique
,
UnravelIndex
,
)
@numba_funcify.register
(
Bartlett
)
def
numba_funcify_Bartlett
(
op
,
**
kwargs
):
@numba.njit
(
inline
=
"always"
)
def
bartlett
(
x
):
return
np
.
bartlett
(
numba_basic
.
to_scalar
(
x
))
return
bartlett
@numba_funcify.register
(
CumOp
)
def
numba_funcify_CumOp
(
op
,
node
,
**
kwargs
):
axis
=
op
.
axis
mode
=
op
.
mode
ndim
=
node
.
outputs
[
0
]
.
ndim
reaxis_first
=
(
axis
,)
+
tuple
(
i
for
i
in
range
(
ndim
)
if
i
!=
axis
)
if
mode
==
"add"
:
np_func
=
np
.
add
identity
=
0
else
:
np_func
=
np
.
multiply
identity
=
1
@numba.njit
(
boundscheck
=
False
)
def
cumop
(
x
):
out_dtype
=
x
.
dtype
if
x
.
shape
[
axis
]
<
2
:
return
x
.
astype
(
out_dtype
)
x_axis_first
=
x
.
transpose
(
reaxis_first
)
res
=
np
.
empty
(
x_axis_first
.
shape
,
dtype
=
out_dtype
)
for
m
in
range
(
x
.
shape
[
axis
]):
if
m
==
0
:
np_func
(
identity
,
x_axis_first
[
m
],
res
[
m
])
else
:
np_func
(
res
[
m
-
1
],
x_axis_first
[
m
],
res
[
m
])
return
res
.
transpose
(
reaxis_first
)
return
cumop
@numba_funcify.register
(
DiffOp
)
def
numba_funcify_DiffOp
(
op
,
node
,
**
kwargs
):
n
=
op
.
n
axis
=
op
.
axis
ndim
=
node
.
inputs
[
0
]
.
ndim
dtype
=
node
.
outputs
[
0
]
.
dtype
axis
=
normalize_axis_index
(
axis
,
ndim
)
slice1
=
[
slice
(
None
)]
*
ndim
slice2
=
[
slice
(
None
)]
*
ndim
slice1
[
axis
]
=
slice
(
1
,
None
)
slice2
[
axis
]
=
slice
(
None
,
-
1
)
slice1
=
tuple
(
slice1
)
slice2
=
tuple
(
slice2
)
op
=
np
.
not_equal
if
dtype
==
"bool"
else
np
.
subtract
@numba.njit
(
boundscheck
=
False
)
def
diffop
(
x
):
res
=
x
.
copy
()
for
_
in
range
(
n
):
res
=
op
(
res
[
slice1
],
res
[
slice2
])
return
res
return
diffop
@numba_funcify.register
(
FillDiagonal
)
def
numba_funcify_FillDiagonal
(
op
,
**
kwargs
):
@numba.njit
def
filldiagonal
(
a
,
val
):
np
.
fill_diagonal
(
a
,
val
)
return
a
return
filldiagonal
@numba_funcify.register
(
FillDiagonalOffset
)
def
numba_funcify_FillDiagonalOffset
(
op
,
node
,
**
kwargs
):
@numba.njit
def
filldiagonaloffset
(
a
,
val
,
offset
):
height
,
width
=
a
.
shape
if
offset
>=
0
:
start
=
numba_basic
.
to_scalar
(
offset
)
num_of_step
=
min
(
min
(
width
,
height
),
width
-
offset
)
else
:
start
=
-
numba_basic
.
to_scalar
(
offset
)
*
a
.
shape
[
1
]
num_of_step
=
min
(
min
(
width
,
height
),
height
+
offset
)
step
=
a
.
shape
[
1
]
+
1
end
=
start
+
step
*
num_of_step
b
=
a
.
ravel
()
b
[
start
:
end
:
step
]
=
val
# TODO: This isn't implemented in Numba
# a.flat[start:end:step] = val
# return a
return
b
.
reshape
(
a
.
shape
)
return
filldiagonaloffset
@numba_funcify.register
(
RavelMultiIndex
)
def
numba_funcify_RavelMultiIndex
(
op
,
node
,
**
kwargs
):
mode
=
op
.
mode
order
=
op
.
order
if
order
!=
"C"
:
raise
NotImplementedError
(
"Numba does not implement `order` in `numpy.ravel_multi_index`"
)
if
mode
==
"raise"
:
@numba.njit
def
mode_fn
(
*
args
):
raise
ValueError
(
"invalid entry in coordinates array"
)
elif
mode
==
"wrap"
:
@numba.njit
(
inline
=
"always"
)
def
mode_fn
(
new_arr
,
i
,
j
,
v
,
d
):
new_arr
[
i
,
j
]
=
v
%
d
elif
mode
==
"clip"
:
@numba.njit
(
inline
=
"always"
)
def
mode_fn
(
new_arr
,
i
,
j
,
v
,
d
):
new_arr
[
i
,
j
]
=
min
(
max
(
v
,
0
),
d
-
1
)
if
node
.
inputs
[
0
]
.
ndim
==
0
:
@numba.njit
def
ravelmultiindex
(
*
inp
):
shape
=
inp
[
-
1
]
arr
=
np
.
stack
(
inp
[:
-
1
])
new_arr
=
arr
.
T
.
astype
(
np
.
float64
)
.
copy
()
for
i
,
b
in
enumerate
(
new_arr
):
if
b
<
0
or
b
>=
shape
[
i
]:
mode_fn
(
new_arr
,
i
,
0
,
b
,
shape
[
i
])
a
=
np
.
ones
(
len
(
shape
),
dtype
=
np
.
float64
)
a
[:
len
(
shape
)
-
1
]
=
np
.
cumprod
(
shape
[
-
1
:
0
:
-
1
])[::
-
1
]
return
np
.
array
(
a
.
dot
(
new_arr
.
T
),
dtype
=
np
.
int64
)
else
:
@numba.njit
def
ravelmultiindex
(
*
inp
):
shape
=
inp
[
-
1
]
arr
=
np
.
stack
(
inp
[:
-
1
])
new_arr
=
arr
.
T
.
astype
(
np
.
float64
)
.
copy
()
for
i
,
b
in
enumerate
(
new_arr
):
for
j
,
(
d
,
v
)
in
enumerate
(
zip
(
shape
,
b
)):
if
v
<
0
or
v
>=
d
:
mode_fn
(
new_arr
,
i
,
j
,
v
,
d
)
a
=
np
.
ones
(
len
(
shape
),
dtype
=
np
.
float64
)
a
[:
len
(
shape
)
-
1
]
=
np
.
cumprod
(
shape
[
-
1
:
0
:
-
1
])[::
-
1
]
return
a
.
dot
(
new_arr
.
T
)
.
astype
(
np
.
int64
)
return
ravelmultiindex
@numba_funcify.register
(
Repeat
)
def
numba_funcify_Repeat
(
op
,
node
,
**
kwargs
):
axis
=
op
.
axis
use_python
=
False
if
axis
is
not
None
:
use_python
=
True
if
use_python
:
warnings
.
warn
(
(
"Numba will use object mode to allow the "
"`axis` argument to `numpy.repeat`."
),
UserWarning
,
)
ret_sig
=
get_numba_type
(
node
.
outputs
[
0
]
.
type
)
@numba.njit
def
repeatop
(
x
,
repeats
):
with
numba
.
objmode
(
ret
=
ret_sig
):
ret
=
np
.
repeat
(
x
,
repeats
,
axis
)
return
ret
else
:
repeats_ndim
=
node
.
inputs
[
1
]
.
ndim
if
repeats_ndim
==
0
:
@numba.njit
(
inline
=
"always"
)
def
repeatop
(
x
,
repeats
):
return
np
.
repeat
(
x
,
repeats
.
item
())
else
:
@numba.njit
(
inline
=
"always"
)
def
repeatop
(
x
,
repeats
):
return
np
.
repeat
(
x
,
repeats
)
return
repeatop
@numba_funcify.register
(
Unique
)
def
numba_funcify_Unique
(
op
,
node
,
**
kwargs
):
axis
=
op
.
axis
use_python
=
False
if
axis
is
not
None
:
use_python
=
True
return_index
=
op
.
return_index
return_inverse
=
op
.
return_inverse
return_counts
=
op
.
return_counts
returns_multi
=
return_index
or
return_inverse
or
return_counts
use_python
|=
returns_multi
if
not
use_python
:
@numba.njit
(
inline
=
"always"
)
def
unique
(
x
):
return
np
.
unique
(
x
)
else
:
warnings
.
warn
(
(
"Numba will use object mode to allow the "
"`axis` and/or `return_*` arguments to `numpy.unique`."
),
UserWarning
,
)
if
returns_multi
:
ret_sig
=
numba
.
types
.
Tuple
([
get_numba_type
(
o
.
type
)
for
o
in
node
.
outputs
])
else
:
ret_sig
=
get_numba_type
(
node
.
outputs
[
0
]
.
type
)
@numba.njit
def
unique
(
x
):
with
numba
.
objmode
(
ret
=
ret_sig
):
ret
=
np
.
unique
(
x
,
return_index
,
return_inverse
,
return_counts
,
axis
)
return
ret
return
unique
@numba_funcify.register
(
UnravelIndex
)
def
numba_funcify_UnravelIndex
(
op
,
node
,
**
kwargs
):
order
=
op
.
order
if
order
!=
"C"
:
raise
NotImplementedError
(
"Numba does not support the `order` argument in `numpy.unravel_index`"
)
if
len
(
node
.
outputs
)
==
1
:
@numba.njit
(
inline
=
"always"
)
def
maybe_expand_dim
(
arr
):
return
arr
else
:
@numba.njit
(
inline
=
"always"
)
def
maybe_expand_dim
(
arr
):
return
np
.
expand_dims
(
arr
,
1
)
@numba.njit
def
unravelindex
(
arr
,
shape
):
a
=
np
.
ones
(
len
(
shape
),
dtype
=
np
.
int64
)
a
[
1
:]
=
shape
[:
0
:
-
1
]
a
=
np
.
cumprod
(
a
)[::
-
1
]
# Aesara actually returns a `tuple` of these values, instead of an
# `ndarray`; however, this `ndarray` result should be able to be
# unpacked into a `tuple`, so this discrepancy shouldn't really matter
return
((
maybe_expand_dim
(
arr
)
//
a
)
%
shape
)
.
T
return
unravelindex
@numba_funcify.register
(
SearchsortedOp
)
def
numba_funcify_Searchsorted
(
op
,
node
,
**
kwargs
):
side
=
op
.
side
use_python
=
False
if
len
(
node
.
inputs
)
==
3
:
use_python
=
True
if
use_python
:
warnings
.
warn
(
(
"Numba will use object mode to allow the "
"`sorter` argument to `numpy.searchsorted`."
),
UserWarning
,
)
ret_sig
=
get_numba_type
(
node
.
outputs
[
0
]
.
type
)
@numba.njit
def
searchsorted
(
a
,
v
,
sorter
):
with
numba
.
objmode
(
ret
=
ret_sig
):
ret
=
np
.
searchsorted
(
a
,
v
,
side
,
sorter
)
return
ret
else
:
@numba.njit
(
inline
=
"always"
)
def
searchsorted
(
a
,
v
):
return
np
.
searchsorted
(
a
,
v
,
side
)
return
searchsorted
aesara/link/numba/dispatch/nlinalg.py
0 → 100644
浏览文件 @
d4696e6b
import
warnings
import
numba
import
numpy
as
np
from
aesara.link.numba.dispatch
import
basic
as
numba_basic
from
aesara.link.numba.dispatch.basic
import
(
get_numba_type
,
int_to_float_fn
,
numba_funcify
,
)
from
aesara.tensor.nlinalg
import
(
SVD
,
Det
,
Eig
,
Eigh
,
Inv
,
MatrixInverse
,
MatrixPinv
,
QRFull
,
)
@numba_funcify.register
(
SVD
)
def
numba_funcify_SVD
(
op
,
node
,
**
kwargs
):
full_matrices
=
op
.
full_matrices
compute_uv
=
op
.
compute_uv
if
not
compute_uv
:
warnings
.
warn
(
(
"Numba will use object mode to allow the "
"`compute_uv` argument to `numpy.linalg.svd`."
),
UserWarning
,
)
ret_sig
=
get_numba_type
(
node
.
outputs
[
0
]
.
type
)
@numba.njit
def
svd
(
x
):
with
numba
.
objmode
(
ret
=
ret_sig
):
ret
=
np
.
linalg
.
svd
(
x
,
full_matrices
,
compute_uv
)
return
ret
else
:
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
inputs_cast
=
int_to_float_fn
(
node
.
inputs
,
out_dtype
)
@numba.njit
(
inline
=
"always"
)
def
svd
(
x
):
return
np
.
linalg
.
svd
(
inputs_cast
(
x
),
full_matrices
)
return
svd
@numba_funcify.register
(
Det
)
def
numba_funcify_Det
(
op
,
node
,
**
kwargs
):
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
inputs_cast
=
int_to_float_fn
(
node
.
inputs
,
out_dtype
)
@numba.njit
(
inline
=
"always"
)
def
det
(
x
):
return
numba_basic
.
direct_cast
(
np
.
linalg
.
det
(
inputs_cast
(
x
)),
out_dtype
)
return
det
@numba_funcify.register
(
Eig
)
def
numba_funcify_Eig
(
op
,
node
,
**
kwargs
):
out_dtype_1
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
out_dtype_2
=
node
.
outputs
[
1
]
.
type
.
numpy_dtype
inputs_cast
=
int_to_float_fn
(
node
.
inputs
,
out_dtype_1
)
@numba.njit
def
eig
(
x
):
out
=
np
.
linalg
.
eig
(
inputs_cast
(
x
))
return
(
out
[
0
]
.
astype
(
out_dtype_1
),
out
[
1
]
.
astype
(
out_dtype_2
))
return
eig
@numba_funcify.register
(
Eigh
)
def
numba_funcify_Eigh
(
op
,
node
,
**
kwargs
):
uplo
=
op
.
UPLO
if
uplo
!=
"L"
:
warnings
.
warn
(
(
"Numba will use object mode to allow the "
"`UPLO` argument to `numpy.linalg.eigh`."
),
UserWarning
,
)
out_dtypes
=
tuple
(
o
.
type
.
numpy_dtype
for
o
in
node
.
outputs
)
ret_sig
=
numba
.
types
.
Tuple
(
[
get_numba_type
(
node
.
outputs
[
0
]
.
type
),
get_numba_type
(
node
.
outputs
[
1
]
.
type
)]
)
@numba.njit
def
eigh
(
x
):
with
numba
.
objmode
(
ret
=
ret_sig
):
out
=
np
.
linalg
.
eigh
(
x
,
UPLO
=
uplo
)
ret
=
(
out
[
0
]
.
astype
(
out_dtypes
[
0
]),
out
[
1
]
.
astype
(
out_dtypes
[
1
]))
return
ret
else
:
@numba.njit
(
inline
=
"always"
)
def
eigh
(
x
):
return
np
.
linalg
.
eigh
(
x
)
return
eigh
@numba_funcify.register
(
Inv
)
def
numba_funcify_Inv
(
op
,
node
,
**
kwargs
):
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
inputs_cast
=
int_to_float_fn
(
node
.
inputs
,
out_dtype
)
@numba.njit
(
inline
=
"always"
)
def
inv
(
x
):
return
np
.
linalg
.
inv
(
inputs_cast
(
x
))
.
astype
(
out_dtype
)
return
inv
@numba_funcify.register
(
MatrixInverse
)
def
numba_funcify_MatrixInverse
(
op
,
node
,
**
kwargs
):
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
inputs_cast
=
int_to_float_fn
(
node
.
inputs
,
out_dtype
)
@numba.njit
(
inline
=
"always"
)
def
matrix_inverse
(
x
):
return
np
.
linalg
.
inv
(
inputs_cast
(
x
))
.
astype
(
out_dtype
)
return
matrix_inverse
@numba_funcify.register
(
MatrixPinv
)
def
numba_funcify_MatrixPinv
(
op
,
node
,
**
kwargs
):
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
inputs_cast
=
int_to_float_fn
(
node
.
inputs
,
out_dtype
)
@numba.njit
(
inline
=
"always"
)
def
matrixpinv
(
x
):
return
np
.
linalg
.
pinv
(
inputs_cast
(
x
))
.
astype
(
out_dtype
)
return
matrixpinv
@numba_funcify.register
(
QRFull
)
def
numba_funcify_QRFull
(
op
,
node
,
**
kwargs
):
mode
=
op
.
mode
if
mode
!=
"reduced"
:
warnings
.
warn
(
(
"Numba will use object mode to allow the "
"`mode` argument to `numpy.linalg.qr`."
),
UserWarning
,
)
if
len
(
node
.
outputs
)
>
1
:
ret_sig
=
numba
.
types
.
Tuple
([
get_numba_type
(
o
.
type
)
for
o
in
node
.
outputs
])
else
:
ret_sig
=
get_numba_type
(
node
.
outputs
[
0
]
.
type
)
@numba.njit
def
qr_full
(
x
):
with
numba
.
objmode
(
ret
=
ret_sig
):
ret
=
np
.
linalg
.
qr
(
x
,
mode
=
mode
)
return
ret
else
:
out_dtype
=
node
.
outputs
[
0
]
.
type
.
numpy_dtype
inputs_cast
=
int_to_float_fn
(
node
.
inputs
,
out_dtype
)
@numba.njit
(
inline
=
"always"
)
def
qr_full
(
x
):
return
np
.
linalg
.
qr
(
inputs_cast
(
x
))
return
qr_full
aesara/link/numba/dispatch/random.py
0 → 100644
浏览文件 @
d4696e6b
from
textwrap
import
dedent
,
indent
from
typing
import
Any
,
Callable
,
Dict
,
Optional
import
numba
import
numba.np.unsafe.ndarray
as
numba_ndarray
import
numpy
as
np
from
numba
import
_helperlib
from
numpy.random
import
RandomState
import
aesara.tensor.random.basic
as
aer
from
aesara.graph.basic
import
Apply
from
aesara.graph.op
import
Op
from
aesara.link.numba.dispatch
import
basic
as
numba_basic
from
aesara.link.numba.dispatch.basic
import
numba_funcify
,
numba_typify
from
aesara.link.utils
import
(
compile_function_src
,
get_name_for_object
,
unique_name_generator
,
)
from
aesara.tensor.basic
import
get_vector_length
from
aesara.tensor.random.type
import
RandomStateType
from
aesara.tensor.random.var
import
RandomStateSharedVariable
@numba_typify.register
(
RandomState
)
def
numba_typify_RandomState
(
state
,
**
kwargs
):
ints
,
index
=
state
.
get_state
()[
1
:
3
]
ptr
=
_helperlib
.
rnd_get_np_state_ptr
()
_helperlib
.
rnd_set_state
(
ptr
,
(
index
,
[
int
(
x
)
for
x
in
ints
]))
return
ints
def
make_numba_random_fn
(
node
,
np_random_func
):
"""Create Numba implementations for existing Numba-supported ``np.random`` functions.
The functions generated here add parameter broadcasting and the ``size``
argument to the Numba-supported scalar ``np.random`` functions.
"""
tuple_size
=
get_vector_length
(
node
.
inputs
[
1
])
size_dims
=
tuple_size
-
max
(
i
.
ndim
for
i
in
node
.
inputs
[
3
:])
# Make a broadcast-capable version of the Numba supported scalar sampling
# function
bcast_fn_name
=
f
"aesara_random_{get_name_for_object(np_random_func)}"
sized_fn_name
=
"sized_random_variable"
unique_names
=
unique_name_generator
(
[
bcast_fn_name
,
sized_fn_name
,
"np"
,
"np_random_func"
,
"numba_vectorize"
,
"to_fixed_tuple"
,
"tuple_size"
,
"size_dims"
,
"rng"
,
"size"
,
"dtype"
,
],
suffix_sep
=
"_"
,
)
bcast_fn_input_names
=
", "
.
join
(
[
unique_names
(
i
,
force_unique
=
True
)
for
i
in
node
.
inputs
[
3
:]]
)
bcast_fn_global_env
=
{
"np_random_func"
:
np_random_func
,
"numba_vectorize"
:
numba
.
vectorize
,
}
bcast_fn_src
=
f
"""
@numba_vectorize
def {bcast_fn_name}({bcast_fn_input_names}):
return np_random_func({bcast_fn_input_names})
"""
bcast_fn
=
compile_function_src
(
bcast_fn_src
,
bcast_fn_name
,
bcast_fn_global_env
)
random_fn_input_names
=
", "
.
join
(
[
"rng"
,
"size"
,
"dtype"
]
+
[
unique_names
(
i
)
for
i
in
node
.
inputs
[
3
:]]
)
# Now, create a Numba JITable function that implements the `size` parameter
out_dtype
=
node
.
outputs
[
1
]
.
type
.
numpy_dtype
random_fn_global_env
=
{
bcast_fn_name
:
bcast_fn
,
"out_dtype"
:
out_dtype
,
}
if
tuple_size
>
0
:
random_fn_body
=
dedent
(
f
"""
size = to_fixed_tuple(size, tuple_size)
data = np.empty(size, dtype=out_dtype)
for i in np.ndindex(size[:size_dims]):
data[i] = {bcast_fn_name}({bcast_fn_input_names})
"""
)
random_fn_global_env
.
update
(
{
"np"
:
np
,
"to_fixed_tuple"
:
numba_ndarray
.
to_fixed_tuple
,
"tuple_size"
:
tuple_size
,
"size_dims"
:
size_dims
,
}
)
else
:
random_fn_body
=
f
"""data = {bcast_fn_name}({bcast_fn_input_names})"""
sized_fn_src
=
dedent
(
f
"""
def {sized_fn_name}({random_fn_input_names}):
{indent(random_fn_body, " " * 4)}
return (rng, data)
"""
)
random_fn
=
compile_function_src
(
sized_fn_src
,
sized_fn_name
,
random_fn_global_env
)
random_fn
=
numba
.
njit
(
random_fn
)
return
random_fn
@numba_funcify.register
(
aer
.
UniformRV
)
@numba_funcify.register
(
aer
.
TriangularRV
)
@numba_funcify.register
(
aer
.
BetaRV
)
@numba_funcify.register
(
aer
.
NormalRV
)
@numba_funcify.register
(
aer
.
LogNormalRV
)
@numba_funcify.register
(
aer
.
GammaRV
)
@numba_funcify.register
(
aer
.
ChiSquareRV
)
@numba_funcify.register
(
aer
.
ParetoRV
)
@numba_funcify.register
(
aer
.
GumbelRV
)
@numba_funcify.register
(
aer
.
ExponentialRV
)
@numba_funcify.register
(
aer
.
WeibullRV
)
@numba_funcify.register
(
aer
.
LogisticRV
)
@numba_funcify.register
(
aer
.
VonMisesRV
)
@numba_funcify.register
(
aer
.
PoissonRV
)
@numba_funcify.register
(
aer
.
GeometricRV
)
@numba_funcify.register
(
aer
.
HyperGeometricRV
)
@numba_funcify.register
(
aer
.
CauchyRV
)
@numba_funcify.register
(
aer
.
WaldRV
)
@numba_funcify.register
(
aer
.
LaplaceRV
)
@numba_funcify.register
(
aer
.
BinomialRV
)
@numba_funcify.register
(
aer
.
NegBinomialRV
)
@numba_funcify.register
(
aer
.
MultinomialRV
)
@numba_funcify.register
(
aer
.
RandIntRV
)
# only the first two arguments are supported
@numba_funcify.register
(
aer
.
ChoiceRV
)
# the `p` argument is not supported
@numba_funcify.register
(
aer
.
PermutationRV
)
def
numba_funcify_RandomVariable
(
op
,
node
,
**
kwargs
):
name
=
op
.
name
np_random_func
=
getattr
(
np
.
random
,
name
)
if
not
isinstance
(
node
.
inputs
[
0
],
(
RandomStateType
,
RandomStateSharedVariable
)):
raise
TypeError
(
"Numba does not support NumPy `Generator`s"
)
return
make_numba_random_fn
(
node
,
np_random_func
)
def
create_numba_random_fn
(
op
:
Op
,
node
:
Apply
,
scalar_fn
:
Callable
[[
str
],
str
],
global_env
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
Callable
:
"""Create a vectorized function from a callable that generates the ``str`` function body.
TODO: This could/should be generalized for other simple function
construction cases that need unique-ified symbol names.
"""
np_random_fn_name
=
f
"aesara_random_{get_name_for_object(op.name)}"
if
global_env
:
np_global_env
=
global_env
.
copy
()
else
:
np_global_env
=
{}
np_global_env
[
"np"
]
=
np
np_global_env
[
"numba_vectorize"
]
=
numba
.
vectorize
unique_names
=
unique_name_generator
(
[
np_random_fn_name
,
]
+
list
(
np_global_env
.
keys
())
+
[
"rng"
,
"size"
,
"dtype"
,
],
suffix_sep
=
"_"
,
)
np_names
=
[
unique_names
(
i
,
force_unique
=
True
)
for
i
in
node
.
inputs
[
3
:]]
np_input_names
=
", "
.
join
(
np_names
)
np_random_fn_src
=
f
"""
@numba_vectorize
def {np_random_fn_name}({np_input_names}):
{scalar_fn(*np_names)}
"""
np_random_fn
=
compile_function_src
(
np_random_fn_src
,
np_random_fn_name
,
np_global_env
)
return
make_numba_random_fn
(
node
,
np_random_fn
)
@numba_funcify.register
(
aer
.
HalfNormalRV
)
def
numba_funcify_HalfNormalRV
(
op
,
node
,
**
kwargs
):
def
body_fn
(
a
,
b
):
return
f
" return {a} + {b} * abs(np.random.normal(0, 1))"
return
create_numba_random_fn
(
op
,
node
,
body_fn
)
@numba_funcify.register
(
aer
.
BernoulliRV
)
def
numba_funcify_BernoulliRV
(
op
,
node
,
**
kwargs
):
out_dtype
=
node
.
outputs
[
1
]
.
type
.
numpy_dtype
def
body_fn
(
a
):
return
f
"""
if {a} < np.random.uniform(0, 1):
return direct_cast(0, out_dtype)
else:
return direct_cast(1, out_dtype)
"""
return
create_numba_random_fn
(
op
,
node
,
body_fn
,
{
"out_dtype"
:
out_dtype
,
"direct_cast"
:
numba_basic
.
direct_cast
},
)
aesara/link/numba/dispatch/scalar.py
0 → 100644
浏览文件 @
d4696e6b
from
functools
import
reduce
from
typing
import
List
import
numba
import
numpy
as
np
import
scipy
import
scipy.special
from
aesara.compile.ops
import
ViewOp
from
aesara.graph.basic
import
Variable
from
aesara.link.numba.dispatch
import
basic
as
numba_basic
from
aesara.link.numba.dispatch.basic
import
create_numba_signature
,
numba_funcify
from
aesara.link.utils
import
(
compile_function_src
,
get_name_for_object
,
unique_name_generator
,
)
from
aesara.scalar.basic
import
(
Add
,
Cast
,
Clip
,
Composite
,
Identity
,
Mul
,
ScalarOp
,
Second
,
Switch
,
)
@numba_funcify.register
(
ScalarOp
)
def
numba_funcify_ScalarOp
(
op
,
node
,
**
kwargs
):
# TODO: Do we need to cache these functions so that we don't end up
# compiling the same Numba function over and over again?
scalar_func_name
=
op
.
nfunc_spec
[
0
]
if
scalar_func_name
.
startswith
(
"scipy."
):
func_package
=
scipy
scalar_func_name
=
scalar_func_name
.
split
(
"."
,
1
)[
-
1
]
else
:
func_package
=
np
if
"."
in
scalar_func_name
:
scalar_func
=
reduce
(
getattr
,
[
scipy
]
+
scalar_func_name
.
split
(
"."
))
else
:
scalar_func
=
getattr
(
func_package
,
scalar_func_name
)
scalar_op_fn_name
=
get_name_for_object
(
scalar_func
)
unique_names
=
unique_name_generator
(
[
scalar_op_fn_name
,
"scalar_func"
],
suffix_sep
=
"_"
)
input_names
=
", "
.
join
([
unique_names
(
v
,
force_unique
=
True
)
for
v
in
node
.
inputs
])
global_env
=
{
"scalar_func"
:
scalar_func
}
scalar_op_src
=
f
"""
def {scalar_op_fn_name}({input_names}):
return scalar_func({input_names})
"""
scalar_op_fn
=
compile_function_src
(
scalar_op_src
,
scalar_op_fn_name
,
global_env
)
signature
=
create_numba_signature
(
node
,
force_scalar
=
True
)
return
numba
.
njit
(
signature
,
inline
=
"always"
)(
scalar_op_fn
)
@numba_funcify.register
(
Switch
)
def
numba_funcify_Switch
(
op
,
node
,
**
kwargs
):
@numba.njit
(
inline
=
"always"
)
def
switch
(
condition
,
x
,
y
):
if
condition
:
return
x
else
:
return
y
return
switch
def
binary_to_nary_func
(
inputs
:
List
[
Variable
],
binary_op_name
:
str
,
binary_op
:
str
):
"""Create a Numba-compatible N-ary function from a binary function."""
unique_names
=
unique_name_generator
([
"binary_op_name"
],
suffix_sep
=
"_"
)
input_names
=
[
unique_names
(
v
,
force_unique
=
True
)
for
v
in
inputs
]
input_signature
=
", "
.
join
(
input_names
)
output_expr
=
binary_op
.
join
(
input_names
)
nary_src
=
f
"""
def {binary_op_name}({input_signature}):
return {output_expr}
"""
nary_fn
=
compile_function_src
(
nary_src
,
binary_op_name
)
return
nary_fn
@numba_funcify.register
(
Add
)
def
numba_funcify_Add
(
op
,
node
,
**
kwargs
):
signature
=
create_numba_signature
(
node
,
force_scalar
=
True
)
nary_add_fn
=
binary_to_nary_func
(
node
.
inputs
,
"add"
,
"+"
)
return
numba
.
njit
(
signature
,
inline
=
"always"
)(
nary_add_fn
)
@numba_funcify.register
(
Mul
)
def
numba_funcify_Mul
(
op
,
node
,
**
kwargs
):
signature
=
create_numba_signature
(
node
,
force_scalar
=
True
)
nary_mul_fn
=
binary_to_nary_func
(
node
.
inputs
,
"mul"
,
"*"
)
return
numba
.
njit
(
signature
,
inline
=
"always"
)(
nary_mul_fn
)
@numba_funcify.register
(
Cast
)
def
numba_funcify_Cast
(
op
,
node
,
**
kwargs
):
dtype
=
np
.
dtype
(
op
.
o_type
.
dtype
)
@numba.njit
(
inline
=
"always"
)
def
cast
(
x
):
return
numba_basic
.
direct_cast
(
x
,
dtype
)
return
cast
@numba_funcify.register
(
Identity
)
@numba_funcify.register
(
ViewOp
)
def
numba_funcify_ViewOp
(
op
,
**
kwargs
):
@numba.njit
(
inline
=
"always"
)
def
viewop
(
x
):
return
x
return
viewop
@numba_funcify.register
(
Clip
)
def
numba_funcify_Clip
(
op
,
**
kwargs
):
@numba.njit
def
clip
(
_x
,
_min
,
_max
):
x
=
numba_basic
.
to_scalar
(
_x
)
_min_scalar
=
numba_basic
.
to_scalar
(
_min
)
_max_scalar
=
numba_basic
.
to_scalar
(
_max
)
if
x
<
_min_scalar
:
return
_min_scalar
elif
x
>
_max_scalar
:
return
_max_scalar
else
:
return
x
return
clip
@numba_funcify.register
(
Composite
)
def
numba_funcify_Composite
(
op
,
node
,
**
kwargs
):
signature
=
create_numba_signature
(
node
,
force_scalar
=
True
)
composite_fn
=
numba
.
njit
(
signature
)(
numba_funcify
(
op
.
fgraph
,
squeeze_output
=
True
,
**
kwargs
)
)
return
composite_fn
@numba_funcify.register
(
Second
)
def
numba_funcify_Second
(
op
,
node
,
**
kwargs
):
@numba.njit
(
inline
=
"always"
)
def
second
(
x
,
y
):
return
y
return
second
aesara/link/numba/dispatch/tensor_basic.py
0 → 100644
浏览文件 @
d4696e6b
from
textwrap
import
indent
import
numba
import
numpy
as
np
from
aesara.link.numba.dispatch
import
basic
as
numba_basic
from
aesara.link.numba.dispatch.basic
import
create_tuple_string
,
numba_funcify
from
aesara.link.utils
import
compile_function_src
,
unique_name_generator
from
aesara.tensor.basic
import
(
Alloc
,
AllocDiag
,
AllocEmpty
,
ARange
,
ExtractDiag
,
Eye
,
Join
,
MakeVector
,
Rebroadcast
,
ScalarFromTensor
,
TensorFromScalar
,
)
@numba_funcify.register
(
AllocEmpty
)
def
numba_funcify_AllocEmpty
(
op
,
node
,
**
kwargs
):
global_env
=
{
"np"
:
np
,
"to_scalar"
:
numba_basic
.
to_scalar
,
"dtype"
:
np
.
dtype
(
op
.
dtype
),
}
unique_names
=
unique_name_generator
(
[
"np"
,
"to_scalar"
,
"dtype"
,
"allocempty"
,
"scalar_shape"
],
suffix_sep
=
"_"
)
shape_var_names
=
[
unique_names
(
v
,
force_unique
=
True
)
for
v
in
node
.
inputs
]
shape_var_item_names
=
[
f
"{name}_item"
for
name
in
shape_var_names
]
shapes_to_items_src
=
indent
(
"
\n
"
.
join
(
[
f
"{item_name} = to_scalar({shape_name})"
for
item_name
,
shape_name
in
zip
(
shape_var_item_names
,
shape_var_names
)
]
),
" "
*
4
,
)
alloc_def_src
=
f
"""
def allocempty({", ".join(shape_var_names)}):
{shapes_to_items_src}
scalar_shape = {create_tuple_string(shape_var_item_names)}
return np.empty(scalar_shape, dtype)
"""
alloc_fn
=
compile_function_src
(
alloc_def_src
,
"allocempty"
,
global_env
)
return
numba
.
njit
(
alloc_fn
)
@numba_funcify.register
(
Alloc
)
def
numba_funcify_Alloc
(
op
,
node
,
**
kwargs
):
global_env
=
{
"np"
:
np
,
"to_scalar"
:
numba_basic
.
to_scalar
}
unique_names
=
unique_name_generator
(
[
"np"
,
"to_scalar"
,
"alloc"
,
"val_np"
,
"val"
,
"scalar_shape"
,
"res"
],
suffix_sep
=
"_"
,
)
shape_var_names
=
[
unique_names
(
v
,
force_unique
=
True
)
for
v
in
node
.
inputs
[
1
:]]
shape_var_item_names
=
[
f
"{name}_item"
for
name
in
shape_var_names
]
shapes_to_items_src
=
indent
(
"
\n
"
.
join
(
[
f
"{item_name} = to_scalar({shape_name})"
for
item_name
,
shape_name
in
zip
(
shape_var_item_names
,
shape_var_names
)
]
),
" "
*
4
,
)
alloc_def_src
=
f
"""
def alloc(val, {", ".join(shape_var_names)}):
val_np = np.asarray(val)
{shapes_to_items_src}
scalar_shape = {create_tuple_string(shape_var_item_names)}
res = np.empty(scalar_shape, dtype=val_np.dtype)
res[...] = val_np
return res
"""
alloc_fn
=
compile_function_src
(
alloc_def_src
,
"alloc"
,
global_env
)
return
numba
.
njit
(
alloc_fn
)
@numba_funcify.register
(
AllocDiag
)
def
numba_funcify_AllocDiag
(
op
,
**
kwargs
):
offset
=
op
.
offset
@numba.njit
(
inline
=
"always"
)
def
allocdiag
(
v
):
return
np
.
diag
(
v
,
k
=
offset
)
return
allocdiag
@numba_funcify.register
(
ARange
)
def
numba_funcify_ARange
(
op
,
**
kwargs
):
dtype
=
np
.
dtype
(
op
.
dtype
)
@numba.njit
(
inline
=
"always"
)
def
arange
(
start
,
stop
,
step
):
return
np
.
arange
(
numba_basic
.
to_scalar
(
start
),
numba_basic
.
to_scalar
(
stop
),
numba_basic
.
to_scalar
(
step
),
dtype
=
dtype
,
)
return
arange
@numba_funcify.register
(
Join
)
def
numba_funcify_Join
(
op
,
**
kwargs
):
view
=
op
.
view
if
view
!=
-
1
:
# TODO: Where (and why) is this `Join.view` even being used? From a
# quick search, the answer appears to be "nowhere", so we should
# probably just remove it.
raise
NotImplementedError
(
"The `view` parameter to `Join` is not supported"
)
@numba.njit
def
join
(
axis
,
*
tensors
):
return
np
.
concatenate
(
tensors
,
numba_basic
.
to_scalar
(
axis
))
return
join
@numba_funcify.register
(
ExtractDiag
)
def
numba_funcify_ExtractDiag
(
op
,
**
kwargs
):
offset
=
op
.
offset
# axis1 = op.axis1
# axis2 = op.axis2
@numba.njit
(
inline
=
"always"
)
def
extract_diag
(
x
):
return
np
.
diag
(
x
,
k
=
offset
)
return
extract_diag
@numba_funcify.register
(
Eye
)
def
numba_funcify_Eye
(
op
,
**
kwargs
):
dtype
=
np
.
dtype
(
op
.
dtype
)
@numba.njit
(
inline
=
"always"
)
def
eye
(
N
,
M
,
k
):
return
np
.
eye
(
numba_basic
.
to_scalar
(
N
),
numba_basic
.
to_scalar
(
M
),
numba_basic
.
to_scalar
(
k
),
dtype
=
dtype
,
)
return
eye
@numba_funcify.register
(
MakeVector
)
def
numba_funcify_MakeVector
(
op
,
**
kwargs
):
dtype
=
np
.
dtype
(
op
.
dtype
)
@numba.njit
def
makevector
(
*
args
):
return
np
.
array
([
a
.
item
()
for
a
in
args
],
dtype
=
dtype
)
return
makevector
@numba_funcify.register
(
Rebroadcast
)
def
numba_funcify_Rebroadcast
(
op
,
**
kwargs
):
op_axis
=
tuple
(
op
.
axis
.
items
())
@numba.njit
def
rebroadcast
(
x
):
for
axis
,
value
in
numba
.
literal_unroll
(
op_axis
):
if
value
and
x
.
shape
[
axis
]
!=
1
:
raise
ValueError
(
(
"Dimension in Rebroadcast's input was supposed to be 1"
)
)
return
x
return
rebroadcast
@numba_funcify.register
(
TensorFromScalar
)
def
numba_funcify_TensorFromScalar
(
op
,
**
kwargs
):
@numba.njit
(
inline
=
"always"
)
def
tensor_from_scalar
(
x
):
return
np
.
array
(
x
)
return
tensor_from_scalar
@numba_funcify.register
(
ScalarFromTensor
)
def
numba_funcify_ScalarFromTensor
(
op
,
**
kwargs
):
@numba.njit
(
inline
=
"always"
)
def
scalar_from_tensor
(
x
):
return
x
.
item
()
return
scalar_from_tensor
aesara/link/numba/linker.py
浏览文件 @
d4696e6b
from
numpy.random
import
RandomState
from
aesara.link.basic
import
JITLinker
from
aesara.link.basic
import
JITLinker
...
@@ -18,6 +16,8 @@ class NumbaLinker(JITLinker):
...
@@ -18,6 +16,8 @@ class NumbaLinker(JITLinker):
return
jitted_fn
return
jitted_fn
def
create_thunk_inputs
(
self
,
storage_map
):
def
create_thunk_inputs
(
self
,
storage_map
):
from
numpy.random
import
RandomState
from
aesara.link.numba.dispatch
import
numba_typify
from
aesara.link.numba.dispatch
import
numba_typify
thunk_inputs
=
[]
thunk_inputs
=
[]
...
...
tests/link/test_numba.py
浏览文件 @
d4696e6b
...
@@ -25,7 +25,7 @@ from aesara.graph.fg import FunctionGraph
...
@@ -25,7 +25,7 @@ from aesara.graph.fg import FunctionGraph
from
aesara.graph.op
import
Op
from
aesara.graph.op
import
Op
from
aesara.graph.optdb
import
OptimizationQuery
from
aesara.graph.optdb
import
OptimizationQuery
from
aesara.graph.type
import
Type
from
aesara.graph.type
import
Type
from
aesara.link.numba.dispatch
import
create_numba_signature
,
get_numba_type
from
aesara.link.numba.dispatch
import
basic
as
numba_basic
from
aesara.link.numba.linker
import
NumbaLinker
from
aesara.link.numba.linker
import
NumbaLinker
from
aesara.scalar.basic
import
Composite
from
aesara.scalar.basic
import
Composite
from
aesara.tensor
import
blas
from
aesara.tensor
import
blas
...
@@ -147,20 +147,21 @@ def eval_python_only(fn_inputs, fgraph, inputs):
...
@@ -147,20 +147,21 @@ def eval_python_only(fn_inputs, fgraph, inputs):
else
:
else
:
return
wrap
return
wrap
with
mock
.
patch
(
"
aesara.link.numba.dispatch.
numba.njit"
,
njit_noop
),
mock
.
patch
(
with
mock
.
patch
(
"numba.njit"
,
njit_noop
),
mock
.
patch
(
"
aesara.link.numba.dispatch.
numba.vectorize"
,
"numba.vectorize"
,
vectorize_noop
,
vectorize_noop
,
),
mock
.
patch
(
),
mock
.
patch
(
"aesara.link.numba.dispatch.tuple_setitem"
,
py_tuple_setitem
"aesara.link.numba.dispatch.elemwise.tuple_setitem"
,
py_tuple_setitem
,
),
mock
.
patch
(
),
mock
.
patch
(
"aesara.link.numba.dispatch.direct_cast"
,
lambda
x
,
dtype
:
x
"aesara.link.numba.dispatch.
basic.
direct_cast"
,
lambda
x
,
dtype
:
x
),
mock
.
patch
(
),
mock
.
patch
(
"aesara.link.numba.dispatch.numba.np.numpy_support.from_dtype"
,
"aesara.link.numba.dispatch.
basic.
numba.np.numpy_support.from_dtype"
,
lambda
dtype
:
dtype
,
lambda
dtype
:
dtype
,
),
mock
.
patch
(
),
mock
.
patch
(
"aesara.link.numba.dispatch.to_scalar"
,
py_to_scalar
"aesara.link.numba.dispatch.
basic.
to_scalar"
,
py_to_scalar
),
mock
.
patch
(
),
mock
.
patch
(
"
aesara.link.numba.dispatch
.to_fixed_tuple"
,
"
numba.np.unsafe.ndarray
.to_fixed_tuple"
,
lambda
x
,
n
:
tuple
(
x
),
lambda
x
,
n
:
tuple
(
x
),
):
):
aesara_numba_fn
=
function
(
aesara_numba_fn
=
function
(
...
@@ -247,7 +248,7 @@ def test_get_numba_type(v, expected, force_scalar, not_implemented):
...
@@ -247,7 +248,7 @@ def test_get_numba_type(v, expected, force_scalar, not_implemented):
else
pytest
.
raises
(
NotImplementedError
)
else
pytest
.
raises
(
NotImplementedError
)
)
)
with
cm
:
with
cm
:
res
=
get_numba_type
(
v
,
force_scalar
=
force_scalar
)
res
=
numba_basic
.
get_numba_type
(
v
,
force_scalar
=
force_scalar
)
assert
res
==
expected
assert
res
==
expected
...
@@ -289,7 +290,7 @@ def test_get_numba_type(v, expected, force_scalar, not_implemented):
...
@@ -289,7 +290,7 @@ def test_get_numba_type(v, expected, force_scalar, not_implemented):
],
],
)
)
def
test_create_numba_signature
(
v
,
expected
,
force_scalar
):
def
test_create_numba_signature
(
v
,
expected
,
force_scalar
):
res
=
create_numba_signature
(
v
,
force_scalar
=
force_scalar
)
res
=
numba_basic
.
create_numba_signature
(
v
,
force_scalar
=
force_scalar
)
assert
res
==
expected
assert
res
==
expected
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论