Commit 4a1010ef authored by Ricardo, committed by Brandon T. Willard

Add SoftmaxGrad numba dispatch

Parent f06146ae
@@ -21,7 +21,7 @@ from aesara.link.utils import (
 )
 from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from aesara.tensor.math import MaxAndArgmax
-from aesara.tensor.nnet.basic import LogSoftmax, Softmax
+from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad
 from aesara.tensor.type import tensor
@@ -424,6 +424,31 @@ def numba_funcify_Softmax(op, node, **kwargs):
     return softmax
@numba_funcify.register(SoftmaxGrad)
def numba_funcify_SoftmaxGrad(op, node, **kwargs):
    """Build a Numba-jitted implementation of ``SoftmaxGrad``.

    The gradient of softmax is ``dy * sm - sum(dy * sm, axis) * sm``, where
    the sum is taken over ``op.axis`` (or over all elements when the axis
    is ``None``).
    """
    sm_var = node.inputs[1]
    # Convert the Aesara dtype of the softmax input into a Numba type so
    # the generated reducer is specialized correctly.
    nb_dtype = numba.np.numpy_support.from_dtype(sm_var.type.numpy_dtype)

    if op.axis is None:
        # Full reduction: plain np.sum works inside an njit function.
        reduce_sum = np.sum
    else:
        # Axis-wise reduction with keepdims so the result broadcasts
        # against ``sm`` in the subtraction below.
        reduce_sum = create_axis_reducer(
            np.add, 0.0, op.axis, sm_var.ndim, nb_dtype, keepdims=True
        )

    @numba.njit
    def softmax_grad(dy, sm):
        g = dy * sm
        return g - reduce_sum(g) * sm

    return softmax_grad
@numba_funcify.register(LogSoftmax)
def numba_funcify_LogSoftmax(op, node, **kwargs):
......
@@ -1893,6 +1893,51 @@ def test_Dot(x, y, exc):
 )
@pytest.mark.parametrize(
    "dy, sm, axis, exc",
    [
        # Same dy/sm fixtures for every case; only the reduction axis varies.
        (
            set_test_value(
                aet.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
            ),
            set_test_value(
                aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)
            ),
            test_axis,
            None,
        )
        for test_axis in (None, 0, 1)
    ],
)
def test_SoftmaxGrad(dy, sm, axis, exc):
    """Check the Numba ``SoftmaxGrad`` against the Python implementation."""
    out = nnetb.SoftmaxGrad(axis=axis)(dy, sm)
    fgraph = FunctionGraph(outputs=[out])

    # No case currently expects a warning, but keep the machinery so new
    # cases can declare one via ``exc``.
    ctx = contextlib.suppress() if exc is None else pytest.warns(exc)
    with ctx:
        compare_numba_and_py(
            fgraph,
            [
                inp.tag.test_value
                for inp in fgraph.inputs
                if not isinstance(inp, (SharedVariable, Constant))
            ],
        )
@pytest.mark.parametrize(
    "x, axis, exc",
    [
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment