提交 4a1010ef authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Add SoftmaxGrad numba dispatch

上级 f06146ae
......@@ -21,7 +21,7 @@ from aesara.link.utils import (
)
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import MaxAndArgmax
from aesara.tensor.nnet.basic import LogSoftmax, Softmax
from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad
from aesara.tensor.type import tensor
......@@ -424,6 +424,31 @@ def numba_funcify_Softmax(op, node, **kwargs):
return softmax
@numba_funcify.register(SoftmaxGrad)
def numba_funcify_SoftmaxGrad(op, node, **kwargs):
sm_at = node.inputs[1]
sm_dtype = sm_at.type.numpy_dtype
sm_dtype = numba.np.numpy_support.from_dtype(sm_dtype)
axis = op.axis
if axis is not None:
reduce_sum = create_axis_reducer(
np.add, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True
)
else:
reduce_sum = np.sum
@numba.njit
def softmax_grad(dy, sm):
dy_times_sm = dy * sm
sum_dy_times_sm = reduce_sum(dy_times_sm)
dx = dy_times_sm - sum_dy_times_sm * sm
return dx
return softmax_grad
@numba_funcify.register(LogSoftmax)
def numba_funcify_LogSoftmax(op, node, **kwargs):
......
......@@ -1893,6 +1893,51 @@ def test_Dot(x, y, exc):
)
@pytest.mark.parametrize(
"dy, sm, axis, exc",
[
(
set_test_value(
aet.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
),
set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
None,
None,
),
(
set_test_value(
aet.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
),
set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
0,
None,
),
(
set_test_value(
aet.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
),
set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
1,
None,
),
],
)
def test_SoftmaxGrad(dy, sm, axis, exc):
g = nnetb.SoftmaxGrad(axis=axis)(dy, sm)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"x, axis, exc",
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论