提交 4efbd193 作者: Ricardo Vieira 提交者: Ricardo Vieira

Refactor SoftmaxGrad numba patch

上级 da66c2ef
......@@ -402,7 +402,9 @@ def {careduce_fn_name}({input_name}):
return careduce_fn
def jit_compile_reducer(node, fn, *, reduce_to_scalar=False, **kwds):
def jit_compile_reducer(
node, fn, *, reduce_to_scalar=False, infer_signature=True, **kwds
):
"""Compile Python source for reduction loops using additional optimizations.
Parameters
......@@ -411,6 +413,10 @@ def jit_compile_reducer(node, fn, *, reduce_to_scalar=False, **kwds):
A node from which the signature can be derived.
fn
The Python function object to compile.
reduce_to_scalar: bool, default False
Whether to reduce output to a scalar (instead of 0d array)
infer_signature: bool, default True
Whether to try and infer the function signature from the Apply node.
kwds
Extra keywords to be added to the :func:`numba.njit` function.
......@@ -419,13 +425,17 @@ def jit_compile_reducer(node, fn, *, reduce_to_scalar=False, **kwds):
A :func:`numba.njit`-compiled function.
"""
signature = create_numba_signature(node, reduce_to_scalar=reduce_to_scalar)
if infer_signature:
signature = create_numba_signature(node, reduce_to_scalar=reduce_to_scalar)
args = (signature,)
else:
args = ()
# Eagerly compile the function using increased optimizations. This should
# help improve nested loop reductions.
with use_optimized_cheap_pass():
res = numba_basic.numba_njit(
signature,
*args,
boundscheck=False,
fastmath=config.numba__fastmath,
**kwds,
......@@ -926,11 +936,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
return dx
# The signature inferred by jit_compile_reducer is wrong when dy is a constant (readonly=True)
# softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn)
softmax_grad = numba_njit(
boundscheck=False,
fastmath=config.numba__fastmath,
)(softmax_grad_py_fn)
softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn, infer_signature=False)
return softmax_grad
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论