提交 f06146ae authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Add SoftmaxGrad jax dispatch

上级 bafd9638
......@@ -47,7 +47,7 @@ from aesara.tensor.extra_ops import (
)
from aesara.tensor.math import Dot, MaxAndArgmax
from aesara.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull
from aesara.tensor.nnet.basic import LogSoftmax, Softmax
from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from aesara.tensor.slinalg import Cholesky, Solve
......@@ -206,6 +206,17 @@ def jax_funcify_Softmax(op, **kwargs):
return softmax
@jax_funcify.register(SoftmaxGrad)
def jax_funcify_SoftmaxGrad(op, **kwargs):
    """Build the JAX implementation of a ``SoftmaxGrad`` op.

    Parameters
    ----------
    op
        The ``SoftmaxGrad`` instance being converted; only ``op.axis`` is
        read here.
    **kwargs
        Accepted for dispatch compatibility; unused in this function.

    Returns
    -------
    callable
        ``softmax_grad(dy, sm)``, computing
        ``dy * sm - sum(dy * sm, axis, keepdims=True) * sm``.
    """
    axis = op.axis

    def softmax_grad(dy, sm):
        # The elementwise product is reused for both terms of the result.
        dy_times_sm = dy * sm
        # keepdims=True keeps the reduced axis so the sum broadcasts
        # against `sm` again.
        return dy_times_sm - jnp.sum(dy_times_sm, axis=axis, keepdims=True) * sm

    return softmax_grad
@jax_funcify.register(LogSoftmax)
def jax_funcify_LogSoftmax(op, **kwargs):
axis = op.axis
......
......@@ -34,6 +34,7 @@ from aesara.tensor.math import clip, cosh, erf, erfc, erfinv, gammaln, log
from aesara.tensor.math import max as aet_max
from aesara.tensor.math import maximum, prod, psi, sigmoid, softplus
from aesara.tensor.math import sum as aet_sum
from aesara.tensor.nnet.basic import SoftmaxGrad
from aesara.tensor.random.basic import RandomVariable, normal
from aesara.tensor.random.utils import RandomStream
from aesara.tensor.shape import Shape, Shape_i, SpecifyShape, reshape
......@@ -988,6 +989,17 @@ def test_logsoftmax(axis):
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax_grad(axis):
    """Run ``SoftmaxGrad`` through ``compare_jax_and_py`` for each axis.

    NOTE(review): ``compare_jax_and_py`` presumably checks that the
    JAX-compiled graph matches the Python backend -- confirm against the
    helper's definition.
    """
    dy = matrix("dy")
    dy.tag.test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
    sm = matrix("sm")
    # The values 0..5 are not a normalized softmax output (rows do not sum
    # to 1); they serve only as numeric input for the comparison.
    sm.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
    out = SoftmaxGrad(axis=axis)(dy, sm)
    fgraph = FunctionGraph([dy, sm], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论