Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
33998b20
提交
33998b20
authored
1月 07, 2022
作者:
kc611
提交者:
Brandon T. Willard
1月 20, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Added more optimizations to the Numba cheap pass-manager
This only applies to reduction `Op`s (e.g. `CAReduce`).
上级
66760618
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
337 行增加
和
178 行删除
+337
-178
basic.py
aesara/link/numba/dispatch/basic.py
+51
-6
elemwise.py
aesara/link/numba/dispatch/elemwise.py
+222
-147
test_numba.py
tests/link/test_numba.py
+64
-25
没有找到文件。
aesara/link/numba/dispatch/basic.py
浏览文件 @
33998b20
import
operator
import
warnings
from
contextlib
import
contextmanager
from
functools
import
singledispatch
import
numba
...
...
@@ -57,14 +58,31 @@ def numba_vectorize(*args, **kwargs):
def
get_numba_type
(
aesara_type
:
Type
,
layout
:
str
=
"A"
,
force_scalar
:
bool
=
False
aesara_type
:
Type
,
layout
:
str
=
"A"
,
force_scalar
:
bool
=
False
,
reduce_to_scalar
:
bool
=
False
,
)
->
numba
.
types
.
Type
:
"""Create a Numba type object for a ``Type``."""
r"""Create a Numba type object for a :class:`Type`.
Parameters
----------
aesara_type
The :class:`Type` to convert.
layout
The :class:`numpy.ndarray` layout to use.
force_scalar
Ignore dimension information and return the corresponding Numba scalar types.
reduce_to_scalar
Return Numba scalars for zero dimensional :class:`TensorType`\s.
"""
if
isinstance
(
aesara_type
,
TensorType
):
dtype
=
aesara_type
.
numpy_dtype
numba_dtype
=
numba
.
from_dtype
(
dtype
)
if
force_scalar
:
if
force_scalar
or
(
reduce_to_scalar
and
getattr
(
aesara_type
,
"ndim"
,
None
)
==
0
):
return
numba_dtype
return
numba
.
types
.
Array
(
numba_dtype
,
aesara_type
.
ndim
,
layout
)
elif
isinstance
(
aesara_type
,
Scalar
):
...
...
@@ -75,15 +93,25 @@ def get_numba_type(
raise
NotImplementedError
(
f
"Numba type not implemented for {aesara_type}"
)
def
create_numba_signature
(
node
:
Apply
,
force_scalar
:
bool
=
False
)
->
numba
.
types
.
Type
:
def
create_numba_signature
(
node
:
Apply
,
force_scalar
:
bool
=
False
,
reduce_to_scalar
:
bool
=
False
)
->
numba
.
types
.
Type
:
"""Create a Numba type for the signature of an ``Apply`` node."""
input_types
=
[]
for
inp
in
node
.
inputs
:
input_types
.
append
(
get_numba_type
(
inp
.
type
,
force_scalar
=
force_scalar
))
input_types
.
append
(
get_numba_type
(
inp
.
type
,
force_scalar
=
force_scalar
,
reduce_to_scalar
=
reduce_to_scalar
)
)
output_types
=
[]
for
out
in
node
.
outputs
:
output_types
.
append
(
get_numba_type
(
out
.
type
,
force_scalar
=
force_scalar
))
output_types
.
append
(
get_numba_type
(
out
.
type
,
force_scalar
=
force_scalar
,
reduce_to_scalar
=
reduce_to_scalar
)
)
if
len
(
output_types
)
>
1
:
return
numba
.
types
.
Tuple
(
output_types
)(
*
input_types
)
...
...
@@ -263,6 +291,23 @@ def create_arg_string(x):
return
args
@contextmanager
def
use_optimized_cheap_pass
(
*
args
,
**
kwargs
):
"""Temporarily replace the cheap optimization pass with a better one."""
from
numba.core.registry
import
cpu_target
context
=
cpu_target
.
target_context
.
_internal_codegen
old_pm
=
context
.
_mpm_cheap
new_pm
=
context
.
_module_pass_manager
(
loop_vectorize
=
True
,
slp_vectorize
=
True
,
opt
=
3
,
cost
=
"cheap"
)
context
.
_mpm_cheap
=
new_pm
try
:
yield
finally
:
context
.
_mpm_cheap
=
old_pm
@singledispatch
def
numba_typify
(
data
,
dtype
=
None
,
**
kwargs
):
return
data
...
...
aesara/link/numba/dispatch/elemwise.py
浏览文件 @
33998b20
...
...
@@ -9,12 +9,14 @@ import numpy as np
from
numba.cpython.unsafe.tuple
import
tuple_setitem
from
aesara
import
config
from
aesara.graph.basic
import
Apply
from
aesara.graph.op
import
Op
from
aesara.link.numba.dispatch
import
basic
as
numba_basic
from
aesara.link.numba.dispatch.basic
import
(
create_numba_signature
,
create_tuple_creator
,
numba_funcify
,
use_optimized_cheap_pass
,
)
from
aesara.link.utils
import
(
compile_function_src
,
...
...
@@ -27,99 +29,20 @@ from aesara.scalar.basic import (
XOR
,
Add
,
IntDiv
,
Mean
,
Mul
,
ScalarMaximum
,
ScalarMinimum
,
Sub
,
TrueDiv
,
)
from
aesara.scalar.basic
import
add
as
add_as
from
aesara.scalar.basic
import
scalar_maximum
from
aesara.tensor.elemwise
import
CAReduce
,
DimShuffle
,
Elemwise
from
aesara.tensor.math
import
MaxAndArgmax
from
aesara.tensor.math
import
MaxAndArgmax
,
MulWithoutZeros
from
aesara.tensor.nnet.basic
import
LogSoftmax
,
Softmax
,
SoftmaxGrad
def
create_vectorize_func
(
op
,
node
,
use_signature
=
False
,
identity
=
None
,
**
kwargs
):
scalar_op_fn
=
numba_funcify
(
op
.
scalar_op
,
node
=
node
,
inline
=
"always"
,
**
kwargs
)
if
len
(
node
.
outputs
)
>
1
:
raise
NotImplementedError
(
"Multi-output Elemwise Ops are not supported by the Numba backend"
)
if
use_signature
:
signature
=
[
create_numba_signature
(
node
,
force_scalar
=
True
)]
else
:
signature
=
[]
target
=
(
getattr
(
node
.
tag
,
"numba__vectorize_target"
,
None
)
or
config
.
numba__vectorize_target
)
numba_vectorized_fn
=
numba_basic
.
numba_vectorize
(
signature
,
identity
=
identity
,
target
=
target
,
fastmath
=
config
.
numba__fastmath
)
py_scalar_func
=
getattr
(
scalar_op_fn
,
"py_func"
,
scalar_op_fn
)
elemwise_fn
=
numba_vectorized_fn
(
scalar_op_fn
)
elemwise_fn
.
py_scalar_func
=
py_scalar_func
return
elemwise_fn
@numba_funcify.register
(
Elemwise
)
def
numba_funcify_Elemwise
(
op
,
node
,
**
kwargs
):
elemwise_fn
=
create_vectorize_func
(
op
,
node
,
use_signature
=
False
)
elemwise_fn_name
=
elemwise_fn
.
__name__
if
op
.
inplace_pattern
:
input_idx
=
op
.
inplace_pattern
[
0
]
sign_obj
=
inspect
.
signature
(
elemwise_fn
.
py_scalar_func
)
input_names
=
list
(
sign_obj
.
parameters
.
keys
())
unique_names
=
unique_name_generator
([
elemwise_fn_name
,
"np"
],
suffix_sep
=
"_"
)
input_names
=
[
unique_names
(
i
,
force_unique
=
True
)
for
i
in
input_names
]
updated_input_name
=
input_names
[
input_idx
]
inplace_global_env
=
{
elemwise_fn_name
:
elemwise_fn
,
"np"
:
np
}
inplace_elemwise_fn_name
=
f
"{elemwise_fn_name}_inplace"
input_signature_str
=
", "
.
join
(
input_names
)
if
node
.
inputs
[
input_idx
]
.
ndim
>
0
:
inplace_elemwise_src
=
f
"""
def {inplace_elemwise_fn_name}({input_signature_str}):
return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name})
"""
else
:
# We can't perform in-place updates on Numba scalars, so we need to
# convert them to NumPy scalars.
# TODO: We should really prevent the rewrites from creating
# in-place updates on scalars when the Numba mode is selected (or
# in general?).
inplace_elemwise_src
=
f
"""
def {inplace_elemwise_fn_name}({input_signature_str}):
{updated_input_name}_scalar = np.asarray({updated_input_name})
return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name}_scalar).item()
"""
inplace_elemwise_fn
=
compile_function_src
(
inplace_elemwise_src
,
inplace_elemwise_fn_name
,
{
**
globals
(),
**
inplace_global_env
},
)
return
numba_basic
.
numba_njit
(
inline
=
"always"
,
fastmath
=
config
.
numba__fastmath
)(
inplace_elemwise_fn
)
return
elemwise_fn
@singledispatch
def
scalar_in_place_fn
(
op
:
Op
,
idx
:
str
,
res
:
str
,
arr
:
str
):
"""Return code for an in-place update on an array using a binary scalar :class:`Op`.
...
...
@@ -135,7 +58,7 @@ def scalar_in_place_fn(op: Op, idx: str, res: str, arr: str):
arr
The symbol name for the second input.
"""
r
eturn
f
"{res}[{idx}] = {op.nfunc_spec[0]}({res}[{idx}], arr)"
r
aise
NotImplementedError
()
@scalar_in_place_fn.register
(
Add
)
...
...
@@ -143,14 +66,24 @@ def scalar_in_place_fn_Add(op, idx, res, arr):
return
f
"{res}[{idx}] += {arr}"
@scalar_in_place_fn.register
(
Sub
)
def
scalar_in_place_fn_Sub
(
op
,
idx
,
res
,
arr
):
return
f
"{res}[{idx}] -= {arr}"
@scalar_in_place_fn.register
(
Mean
)
def
scalar_in_place_fn_Mean
(
op
,
idx
,
res
,
arr
):
return
f
"{res}[{idx}] += ({arr} - {res}[{idx}]) / (i + 1)"
@scalar_in_place_fn.register
(
Mul
)
def
scalar_in_place_fn_Mul
(
op
,
idx
,
res
,
arr
):
return
f
"{res}[{idx}] *= {arr}"
@scalar_in_place_fn.register
(
Sub
)
def
scalar_in_place_fn_
Sub
(
op
,
idx
,
res
,
arr
):
return
f
"{res}[{idx}]
-= {arr}
"
@scalar_in_place_fn.register
(
MulWithoutZeros
)
def
scalar_in_place_fn_
MulWithoutZeros
(
op
,
idx
,
res
,
arr
):
return
f
"{res}[{idx}]
= {arr} if {res}[{idx}] == 0 else ({res}[{idx}] if {arr} == 0 else {res}[{idx}] * {arr})
"
@scalar_in_place_fn.register
(
AND
)
...
...
@@ -186,6 +119,44 @@ if {res}[{idx}] < {arr}:
"""
@scalar_in_place_fn.register
(
ScalarMinimum
)
def
scalar_in_place_fn_ScalarMinimum
(
op
,
idx
,
res
,
arr
):
return
f
"""
if {res}[{idx}] > {arr}:
{res}[{idx}] = {arr}
"""
def
create_vectorize_func
(
op
,
node
,
use_signature
=
False
,
identity
=
None
,
**
kwargs
):
scalar_op_fn
=
numba_funcify
(
op
.
scalar_op
,
node
=
node
,
inline
=
"always"
,
**
kwargs
)
if
len
(
node
.
outputs
)
>
1
:
raise
NotImplementedError
(
"Multi-output Elemwise Ops are not supported by the Numba backend"
)
if
use_signature
:
signature
=
[
create_numba_signature
(
node
,
force_scalar
=
True
)]
else
:
signature
=
[]
target
=
(
getattr
(
node
.
tag
,
"numba__vectorize_target"
,
None
)
or
config
.
numba__vectorize_target
)
numba_vectorized_fn
=
numba_basic
.
numba_vectorize
(
signature
,
identity
=
identity
,
target
=
target
,
fastmath
=
config
.
numba__fastmath
)
py_scalar_func
=
getattr
(
scalar_op_fn
,
"py_func"
,
scalar_op_fn
)
elemwise_fn
=
numba_vectorized_fn
(
scalar_op_fn
)
elemwise_fn
.
py_scalar_func
=
py_scalar_func
return
elemwise_fn
def
create_axis_reducer
(
scalar_op
:
Op
,
identity
:
Union
[
np
.
ndarray
,
Number
],
...
...
@@ -194,7 +165,7 @@ def create_axis_reducer(
dtype
:
numba
.
types
.
Type
,
keepdims
:
bool
=
False
,
)
->
numba
.
core
.
dispatcher
.
Dispatcher
:
r"""Create
a Numba JITed function that performs a NumPy
reduction on a given axis.
r"""Create
Python function that performs a NumPy-like
reduction on a given axis.
The functions generated by this function take the following form:
...
...
@@ -232,35 +203,15 @@ def create_axis_reducer(
The data type of the result.
keepdims:
Determines whether or not the reduced dimension is retained.
"""
reduce_elemwise_fn_name
=
"careduce_axis"
if
ndim
>
1
:
res_shape_tuple_ctor
=
create_tuple_creator
(
lambda
i
,
shape
:
shape
[
i
]
if
i
<
axis
else
shape
[
i
+
1
],
ndim
-
1
)
if
keepdims
:
set_out_dims
=
numba_basic
.
numba_njit
(
lambda
x
:
np
.
expand_dims
(
x
,
axis
),
inline
=
"always"
)
else
:
set_out_dims
=
numba_basic
.
numba_njit
(
lambda
x
:
x
,
inline
=
"always"
)
else
:
Returns
=======
A Python function that can be JITed.
@numba_basic.numba_njit
def
res_shape_tuple_ctor
(
args
):
return
1
"""
if
keepdims
:
set_out_dims
=
numba_basic
.
numba_njit
(
lambda
x
:
numba_basic
.
direct_cast
(
x
,
dtype
),
inline
=
"always"
)
else
:
set_out_dims
=
numba_basic
.
numba_njit
(
lambda
x
:
numba_basic
.
direct_cast
(
x
[
0
],
dtype
),
inline
=
"always"
)
reduce_elemwise_fn_name
=
"careduce_axis"
identity
=
str
(
identity
)
if
identity
==
"inf"
:
...
...
@@ -268,7 +219,18 @@ def create_axis_reducer(
elif
identity
==
"-inf"
:
identity
=
"-np.inf"
global_env
=
{
"np"
:
np
,
"numba_basic"
:
numba_basic
,
"out_dtype"
:
dtype
,
}
if
ndim
>
1
:
res_shape_tuple_ctor
=
create_tuple_creator
(
lambda
i
,
shape
:
shape
[
i
]
if
i
<
axis
else
shape
[
i
+
1
],
ndim
-
1
)
global_env
[
"res_shape_tuple_ctor"
]
=
res_shape_tuple_ctor
res_indices
=
[]
arr_indices
=
[]
count
=
0
...
...
@@ -289,48 +251,45 @@ def create_axis_reducer(
)
inplace_update_statement
=
indent
(
inplace_update_statement
,
" "
*
4
*
3
)
return_expr
=
f
"np.expand_dims(res, {axis})"
if
keepdims
else
"res"
reduce_elemwise_def_src
=
f
"""
def {reduce_elemwise_fn_name}(x):
res_shape = res_shape_tuple_ctor(x.shape)
res = np.full(res_shape, numba_basic.to_scalar({identity}))
x_shape = np.shape(x)
res_shape = res_shape_tuple_ctor(x_shape)
res = np.full(res_shape, numba_basic.to_scalar({identity}), dtype=out_dtype)
axis_shape = x.shape[{axis}]
for idx_arr in np.ndindex(res_shape):
for i in range(axis_shape):
{inplace_update_statement}
{inplace_update_statement}
return
set_out_dims(res)
return
{return_expr}
"""
else
:
inplace_update_statement
=
scalar_in_place_fn
(
scalar_op
,
"0"
,
"res"
,
"x[i]"
)
inplace_update_statement
=
indent
(
inplace_update_statement
,
" "
*
4
*
3
)
inplace_update_statement
=
indent
(
inplace_update_statement
,
" "
*
4
*
2
)
return_expr
=
"res"
if
keepdims
else
"res.item()"
reduce_elemwise_def_src
=
f
"""
def {reduce_elemwise_fn_name}(x):
res_shape = res_shape_tuple_ctor(x.shape)
res = np.full(res_shape, numba_basic.to_scalar({identity}))
res = np.full(1, numba_basic.to_scalar({identity}), dtype=out_dtype)
axis_shape = x.shape[{axis}]
for i in range(axis_shape):
{inplace_update_statement}
{inplace_update_statement}
return
set_out_dims(res)
return
{return_expr}
"""
global_env
=
{
"np"
:
np
,
"res_shape_tuple_ctor"
:
res_shape_tuple_ctor
,
"numba_basic"
:
numba_basic
,
"set_out_dims"
:
set_out_dims
,
}
reduce_elemwise_fn_py
=
compile_function_src
(
reduce_elemwise_def_src
,
reduce_elemwise_fn_name
,
global_env
reduce_elemwise_def_src
,
reduce_elemwise_fn_name
,
{
**
globals
(),
**
global_env
}
)
return
numba_basic
.
numba_njit
(
boundscheck
=
False
)(
reduce_elemwise_fn_py
)
return
reduce_elemwise_fn_py
def
create_multiaxis_reducer
(
...
...
@@ -366,6 +325,10 @@ def create_multiaxis_reducer(
dtype:
The data type of the result.
Returns
=======
A Python function that can be JITed.
"""
if
len
(
axes
)
==
1
:
return
create_axis_reducer
(
scalar_op
,
identity
,
axes
[
0
],
ndim
,
dtype
)
...
...
@@ -378,9 +341,13 @@ def create_multiaxis_reducer(
for
i
,
axis
in
enumerate
(
to_reduce
):
careducer_axes_fn_name
=
f
"careduce_axes_fn_{i}"
global_env
[
careducer_axes_fn_name
]
=
create_axis_reducer
(
scalar_op
,
identity
,
axis
,
ndim
,
dtype
)
reducer_py_fn
=
create_axis_reducer
(
scalar_op
,
identity
,
axis
,
ndim
,
dtype
)
reducer_fn
=
numba_basic
.
numba_njit
(
boundscheck
=
False
,
fastmath
=
config
.
numba__fastmath
)(
reducer_py_fn
)
global_env
[
careducer_axes_fn_name
]
=
reducer_fn
ndim
-=
1
last_var_name
=
var_name
var_name
=
f
"axis_{i}_res"
...
...
@@ -398,7 +365,40 @@ def {careduce_fn_name}({input_name}):
careduce_fn
=
compile_function_src
(
careduce_def_src
,
careduce_fn_name
,
{
**
globals
(),
**
global_env
}
)
return
numba_basic
.
numba_njit
(
fastmath
=
config
.
numba__fastmath
)(
careduce_fn
)
return
careduce_fn
def
jit_compile_reducer
(
node
,
fn
,
**
kwds
):
"""Compile Python source for reduction loops using additional optimizations.
Parameters
==========
node
An node from which the signature can be derived.
fn
The Python function object to compile.
kwds
Extra keywords to be added to the :func:`numba.njit` function.
Returns
=======
A :func:`numba.njit`-compiled function.
"""
signature
=
create_numba_signature
(
node
,
reduce_to_scalar
=
True
)
# Eagerly compile the function using increased optimizations. This should
# help improve nested loop reductions.
with
use_optimized_cheap_pass
():
res
=
numba_basic
.
numba_njit
(
signature
,
boundscheck
=
False
,
fastmath
=
config
.
numba__fastmath
,
**
kwds
,
)(
fn
)
return
res
def
create_axis_apply_fn
(
fn
,
axis
,
ndim
,
dtype
):
...
...
@@ -417,6 +417,57 @@ def create_axis_apply_fn(fn, axis, ndim, dtype):
return
axis_apply_fn
@numba_funcify.register
(
Elemwise
)
def
numba_funcify_Elemwise
(
op
,
node
,
**
kwargs
):
elemwise_fn
=
create_vectorize_func
(
op
,
node
,
use_signature
=
False
)
elemwise_fn_name
=
elemwise_fn
.
__name__
if
op
.
inplace_pattern
:
input_idx
=
op
.
inplace_pattern
[
0
]
sign_obj
=
inspect
.
signature
(
elemwise_fn
.
py_scalar_func
)
input_names
=
list
(
sign_obj
.
parameters
.
keys
())
unique_names
=
unique_name_generator
([
elemwise_fn_name
,
"np"
],
suffix_sep
=
"_"
)
input_names
=
[
unique_names
(
i
,
force_unique
=
True
)
for
i
in
input_names
]
updated_input_name
=
input_names
[
input_idx
]
inplace_global_env
=
{
elemwise_fn_name
:
elemwise_fn
,
"np"
:
np
}
inplace_elemwise_fn_name
=
f
"{elemwise_fn_name}_inplace"
input_signature_str
=
", "
.
join
(
input_names
)
if
node
.
inputs
[
input_idx
]
.
ndim
>
0
:
inplace_elemwise_src
=
f
"""
def {inplace_elemwise_fn_name}({input_signature_str}):
return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name})
"""
else
:
# We can't perform in-place updates on Numba scalars, so we need to
# convert them to NumPy scalars.
# TODO: We should really prevent the rewrites from creating
# in-place updates on scalars when the Numba mode is selected (or
# in general?).
inplace_elemwise_src
=
f
"""
def {inplace_elemwise_fn_name}({input_signature_str}):
{updated_input_name}_scalar = np.asarray({updated_input_name})
return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name}_scalar).item()
"""
inplace_elemwise_fn
=
compile_function_src
(
inplace_elemwise_src
,
inplace_elemwise_fn_name
,
{
**
globals
(),
**
inplace_global_env
},
)
return
numba_basic
.
numba_njit
(
inline
=
"always"
,
fastmath
=
config
.
numba__fastmath
)(
inplace_elemwise_fn
)
return
elemwise_fn
@numba_funcify.register
(
CAReduce
)
def
numba_funcify_CAReduce
(
op
,
node
,
**
kwargs
):
axes
=
op
.
axis
...
...
@@ -434,15 +485,16 @@ def numba_funcify_CAReduce(op, node, **kwargs):
input_name
=
get_name_for_object
(
node
.
inputs
[
0
])
ndim
=
node
.
inputs
[
0
]
.
ndim
careduce_fn
=
create_multiaxis_reducer
(
careduce_
py_
fn
=
create_multiaxis_reducer
(
op
.
scalar_op
,
scalar_op_identity
,
axes
,
ndim
,
np
_acc_dtype
,
np
.
dtype
(
node
.
outputs
[
0
]
.
type
.
dtype
)
,
input_name
=
input_name
,
)
careduce_fn
=
jit_compile_reducer
(
node
,
careduce_py_fn
)
return
careduce_fn
...
...
@@ -533,24 +585,31 @@ def numba_funcify_Softmax(op, node, **kwargs):
axis
=
op
.
axis
if
axis
is
not
None
:
reduce_max
=
create_axis_reducer
(
reduce_max
_py
=
create_axis_reducer
(
scalar_maximum
,
-
np
.
inf
,
axis
,
x_at
.
ndim
,
x_dtype
,
keepdims
=
True
)
reduce_sum
=
create_axis_reducer
(
reduce_sum
_py
=
create_axis_reducer
(
add_as
,
0.0
,
axis
,
x_at
.
ndim
,
x_dtype
,
keepdims
=
True
)
jit_fn
=
numba_basic
.
numba_njit
(
boundscheck
=
False
,
fastmath
=
config
.
numba__fastmath
)
reduce_max
=
jit_fn
(
reduce_max_py
)
reduce_sum
=
jit_fn
(
reduce_sum_py
)
else
:
reduce_max
=
np
.
max
reduce_sum
=
np
.
sum
@numba_basic.numba_njit
def
softmax
(
x
):
def
softmax_py_fn
(
x
):
z
=
reduce_max
(
x
)
e_x
=
np
.
exp
(
x
-
z
)
w
=
reduce_sum
(
e_x
)
sm
=
e_x
/
w
return
sm
softmax
=
jit_compile_reducer
(
node
,
softmax_py_fn
)
return
softmax
...
...
@@ -563,19 +622,25 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
axis
=
op
.
axis
if
axis
is
not
None
:
reduce_sum
=
create_axis_reducer
(
reduce_sum
_py
=
create_axis_reducer
(
add_as
,
0.0
,
axis
,
sm_at
.
ndim
,
sm_dtype
,
keepdims
=
True
)
jit_fn
=
numba_basic
.
numba_njit
(
boundscheck
=
False
,
fastmath
=
config
.
numba__fastmath
)
reduce_sum
=
jit_fn
(
reduce_sum_py
)
else
:
reduce_sum
=
np
.
sum
@numba_basic.numba_njit
def
softmax_grad
(
dy
,
sm
):
def
softmax_grad_py_fn
(
dy
,
sm
):
dy_times_sm
=
dy
*
sm
sum_dy_times_sm
=
reduce_sum
(
dy_times_sm
)
dx
=
dy_times_sm
-
sum_dy_times_sm
*
sm
return
dx
softmax_grad
=
jit_compile_reducer
(
node
,
softmax_grad_py_fn
)
return
softmax_grad
...
...
@@ -588,22 +653,28 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
axis
=
op
.
axis
if
axis
is
not
None
:
reduce_max
=
create_axis_reducer
(
reduce_max
_py
=
create_axis_reducer
(
scalar_maximum
,
-
np
.
inf
,
axis
,
x_at
.
ndim
,
x_dtype
,
keepdims
=
True
)
reduce_sum
=
create_axis_reducer
(
reduce_sum
_py
=
create_axis_reducer
(
add_as
,
0.0
,
axis
,
x_at
.
ndim
,
x_dtype
,
keepdims
=
True
)
jit_fn
=
numba_basic
.
numba_njit
(
boundscheck
=
False
,
fastmath
=
config
.
numba__fastmath
)
reduce_max
=
jit_fn
(
reduce_max_py
)
reduce_sum
=
jit_fn
(
reduce_sum_py
)
else
:
reduce_max
=
np
.
max
reduce_sum
=
np
.
sum
@numba_basic.numba_njit
def
log_softmax
(
x
):
def
log_softmax_py_fn
(
x
):
xdev
=
x
-
reduce_max
(
x
)
lsm
=
xdev
-
np
.
log
(
reduce_sum
(
np
.
exp
(
xdev
)))
return
lsm
log_softmax
=
jit_compile_reducer
(
node
,
log_softmax_py_fn
)
return
log_softmax
...
...
@@ -629,9 +700,13 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
# work-around
keep_axes
=
tuple
(
i
for
i
in
range
(
x_ndim
)
if
i
not
in
axes
)
reduce_max
=
create_multiaxis_reducer
(
reduce_max
_py_fn
=
create_multiaxis_reducer
(
scalar_maximum
,
-
np
.
inf
,
axes
,
x_ndim
,
x_dtype
)
reduce_max
=
jit_compile_reducer
(
Apply
(
node
.
op
,
node
.
inputs
,
[
node
.
outputs
[
0
]
.
clone
()]),
reduce_max_py_fn
)
reduced_x_ndim
=
x_ndim
-
len
(
axes
)
+
1
argmax_axis
=
create_axis_apply_fn
(
np
.
argmax
,
reduced_x_ndim
-
1
,
reduced_x_ndim
,
np
.
int64
...
...
tests/link/test_numba.py
浏览文件 @
33998b20
...
...
@@ -37,6 +37,7 @@ from aesara.tensor import elemwise as at_elemwise
from
aesara.tensor
import
extra_ops
,
nlinalg
,
slinalg
from
aesara.tensor
import
subtensor
as
at_subtensor
from
aesara.tensor.elemwise
import
Elemwise
from
aesara.tensor.math
import
All
,
Any
,
Max
,
Mean
,
Min
,
Prod
,
ProdWithoutZeros
,
Sum
from
aesara.tensor.shape
import
Reshape
,
Shape
,
Shape_i
,
SpecifyShape
...
...
@@ -1049,94 +1050,132 @@ def test_ARange(start, stop, step, dtype):
@pytest.mark.parametrize
(
"careduce_fn, axis, v
, keepdims
"
,
"careduce_fn, axis, v"
,
[
(
at
.
sum
,
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Sum
(
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
)(
x
),
0
,
set_test_value
(
at
.
vector
(),
np
.
arange
(
3
,
dtype
=
config
.
floatX
)),
False
,
),
(
at
.
all
,
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
All
(
axis
)(
x
),
0
,
set_test_value
(
at
.
vector
(),
np
.
arange
(
3
,
dtype
=
config
.
floatX
)),
),
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Any
(
axis
)(
x
),
0
,
set_test_value
(
at
.
vector
(),
np
.
arange
(
3
,
dtype
=
config
.
floatX
)),
),
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Mean
(
axis
)(
x
),
0
,
set_test_value
(
at
.
vector
(),
np
.
arange
(
3
,
dtype
=
config
.
floatX
)),
False
,
),
(
at
.
sum
,
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Mean
(
axis
)(
x
)
,
0
,
set_test_value
(
at
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
False
,
),
(
at
.
sum
,
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Sum
(
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
)(
x
),
0
,
set_test_value
(
at
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
),
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Sum
(
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
)(
x
),
(
0
,
1
),
set_test_value
(
at
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
False
,
),
(
at
.
sum
,
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Sum
(
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
)(
x
),
(
1
,
0
),
set_test_value
(
at
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
False
,
),
(
at
.
sum
,
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Sum
(
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
)(
x
),
None
,
set_test_value
(
at
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
False
,
),
(
at
.
sum
,
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Sum
(
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
)(
x
),
1
,
set_test_value
(
at
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
False
,
),
(
at
.
prod
,
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Prod
(
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
)(
x
),
0
,
set_test_value
(
at
.
vector
(),
np
.
arange
(
3
,
dtype
=
config
.
floatX
)),
False
,
),
(
at
.
prod
,
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
ProdWithoutZeros
(
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
)(
x
),
0
,
set_test_value
(
at
.
vector
(),
np
.
arange
(
3
,
dtype
=
config
.
floatX
)),
),
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Prod
(
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
)(
x
),
0
,
set_test_value
(
at
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
False
,
),
(
at
.
prod
,
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Prod
(
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
)(
x
),
1
,
set_test_value
(
at
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
False
,
),
(
at
.
max
,
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Max
(
axis
)(
x
),
None
,
set_test_value
(
at
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
),
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Min
(
axis
)(
x
),
None
,
set_test_value
(
at
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
True
,
),
],
)
def
test_CAReduce
(
careduce_fn
,
axis
,
v
,
keepdims
):
g
=
careduce_fn
(
v
,
axis
=
axis
,
keepdims
=
keepdims
)
def
test_CAReduce
(
careduce_fn
,
axis
,
v
):
g
=
careduce_fn
(
v
,
axis
=
axis
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论