Commit 58cb5c30 authored by Ricardo, committed by Brandon T. Willard

Add axis to Softmax and SoftmaxGrad Ops

Parent: c6c85acb
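For context (not part of the diff): based on the test changes below, softmax now accepts an axis argument, where an integer selects the axis to normalize over and None normalizes over all elements. A minimal usage sketch under that assumption:

# Usage sketch (not from this commit); mirrors the parametrized JAX test added below.
import numpy as np
import aesara
import aesara.tensor as aet
from aesara.tensor.nnet import softmax

x = aet.matrix("x")
f_rows = aesara.function([x], softmax(x, axis=1))    # normalize each row
f_cols = aesara.function([x], softmax(x, axis=0))    # normalize each column
f_all = aesara.function([x], softmax(x, axis=None))  # normalize over all elements

data = np.arange(6, dtype=aesara.config.floatX).reshape(2, 3)
assert np.allclose(f_rows(data).sum(axis=1), 1.0)
assert np.allclose(f_cols(data).sum(axis=0), 1.0)
assert np.isclose(f_all(data).sum(), 1.0)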
@@ -198,8 +198,10 @@ def jax_funcify_Identity(op, **kwargs):
 @jax_funcify.register(Softmax)
 def jax_funcify_Softmax(op, **kwargs):
+    axis = op.axis
+
     def softmax(x):
-        return jax.nn.softmax(x)
+        return jax.nn.softmax(x, axis=axis)
 
     return softmax
......
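The JAX backend forwards the Op's axis straight to jax.nn.softmax. A standalone illustration (not code from this commit; it assumes jax.nn.softmax accepts axis=None the way NumPy-style reductions do, normalizing over every element):

import jax
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)
per_column = jax.nn.softmax(x, axis=0)   # each column sums to 1
over_all = jax.nn.softmax(x, axis=None)  # the whole array sums to 1
assert jnp.allclose(per_column.sum(axis=0), 1.0)
assert jnp.isclose(over_all.sum(), 1.0)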
@@ -400,17 +400,24 @@ def numba_funcify_Softmax(op, node, **kwargs):
     x_at = node.inputs[0]
     x_dtype = x_at.type.numpy_dtype
     x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
+    axis = op.axis
 
-    # np.max(x, axis=1)
-    reduce_max = create_axis_reducer(np.maximum, -np.inf, 1, x_at.ndim, x_dtype)
-    # np.sum(x, axis=1)
-    reduce_sum = create_axis_reducer(np.add, 0.0, 1, x_at.ndim, x_dtype)
+    if axis is not None:
+        reduce_max = create_axis_reducer(
+            np.maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
+        )
+        reduce_sum = create_axis_reducer(
+            np.add, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
+        )
+    else:
+        reduce_max = np.max
+        reduce_sum = np.sum
 
     @numba.njit
     def softmax(x):
-        z = np.expand_dims(reduce_max(x), -1)
+        z = reduce_max(x)
         e_x = np.exp(x - z)
-        w = np.expand_dims(reduce_sum(e_x), -1)
+        w = reduce_sum(e_x)
         sm = e_x / w
         return sm
......
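For reference, a NumPy-only sketch of the softmax the Numba dispatcher now builds. It assumes create_axis_reducer(..., keepdims=True) behaves like the corresponding NumPy reduction with keepdims=True; because the reduced dimension is kept, the result broadcasts against x directly, which is why the old np.expand_dims calls are gone:

import numpy as np

def softmax_sketch(x, axis):
    if axis is not None:
        # keepdims=True keeps the reduced dimension so z and w broadcast against x
        z = np.max(x, axis=axis, keepdims=True)
        e_x = np.exp(x - z)
        w = np.sum(e_x, axis=axis, keepdims=True)
    else:
        # axis=None falls back to plain np.max / np.sum over the whole array
        z = np.max(x)
        e_x = np.exp(x - z)
        w = np.sum(e_x)
    return e_x / w

x = np.arange(6.0).reshape(2, 3)
assert np.allclose(softmax_sketch(x, axis=1).sum(axis=1), 1.0)
assert np.isclose(softmax_sketch(x, axis=None).sum(), 1.0)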
@@ -35,9 +35,9 @@ from aesara.tensor.nnet.basic import (
     selu,
     sigmoid_binary_crossentropy,
     softmax,
-    softmax_grad,
+    softmax_grad_legacy,
     softmax_graph,
-    softmax_op,
+    softmax_legacy,
     softmax_simplifier,
     softmax_with_bias,
     softsign,
......
@@ -32,7 +32,7 @@ from aesara.tensor.math import (
     sqrt,
 )
 from aesara.tensor.math import sum as aet_sum
-from aesara.tensor.nnet import batchnorm, conv2d, softmax, softmax_op
+from aesara.tensor.nnet import batchnorm, conv2d, softmax, softmax_legacy
 from aesara.tensor.nnet.abstract_conv import (
     get_conv_gradinputs_shape,
     get_conv_output_shape,
@@ -1456,7 +1456,7 @@ class TestSoftMax(test_nnet.TestSoftMax):
     def test_softmax_f16(self):
         x = matrix("x", "float16")
         x_gpu = tensor4("x_gpu", "float16")
-        f_z = softmax_op
+        f_z = softmax_legacy
         f_gpu = dnn.GpuDnnSoftmax("accurate", "channel")
 
         def cmp(n, m, f, f_gpu):
@@ -1480,7 +1480,7 @@ class TestSoftMax(test_nnet.TestSoftMax):
         x = matrix("x")
         x_gpu = tensor4("x_gpu")
-        f_z = softmax_op
+        f_z = softmax_legacy
         f_gpu = dnn.GpuDnnSoftmax("accurate", "channel")
 
         # Verify the grad operation
......
@@ -210,7 +210,7 @@ def softmax_unittest_template(dtypeInput):
     z = aesara.tensor.nnet.softmax(x)
     f = aesara.function([x], z, mode=mode_without_gpu)
     f_gpu = aesara.function([x], z, mode=mode_wo_cudnn)
-    assert f.maker.fgraph.toposort()[-1].op == aesara.tensor.nnet.softmax_op
+    assert f.maker.fgraph.toposort()[-1].op == aesara.tensor.nnet.softmax_legacy
     assert isinstance(f_gpu.maker.fgraph.toposort()[-2].op, GpuSoftmax)
 
     def cmp(n, m):
@@ -300,7 +300,7 @@ class TestSoftMax:
     def test_softmax(self):
         x = fmatrix("x")
-        z = aesara.tensor.nnet.softmax_op
+        z = aesara.tensor.nnet.softmax_legacy
 
         f, f_gpu = self._test_softmax(x, x, z, z, self._cmp)
@@ -308,7 +308,7 @@ class TestSoftMax:
     def test_softmax_shape_0(self):
         x = fmatrix("x")
-        z = aesara.tensor.nnet.softmax_op
+        z = aesara.tensor.nnet.softmax_legacy
 
         f, f_gpu = self._test_softmax(x, x, z, z, self._cmp)
         # Aesara can handle that case, but cudnn can't
......
@@ -969,11 +969,16 @@ def test_nnet():
     fgraph = FunctionGraph([x], [out])
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
 
-    out = aet_nnet.softmax(x)
+    out = aet_nnet.logsoftmax(x)
     fgraph = FunctionGraph([x], [out])
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
 
-    out = aet_nnet.logsoftmax(x)
+
+@pytest.mark.parametrize("axis", [None, 0, 1])
+def test_softmax(axis):
+    x = matrix("x")
+    x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
+    out = aet_nnet.softmax(x, axis=axis)
     fgraph = FunctionGraph([x], [out])
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
......
@@ -1894,20 +1894,27 @@ def test_Dot(x, y, exc):
 @pytest.mark.parametrize(
-    "x, exc",
+    "x, axis, exc",
     [
         (
             set_test_value(aet.vector(), rng.random(size=(2,)).astype(config.floatX)),
             None,
+            None,
         ),
         (
            set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
             None,
+            None,
         ),
+        (
+            set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
+            0,
+            None,
+        ),
     ],
 )
-def test_Softmax(x, exc):
-    g = nnetb.Softmax()(x)
+def test_Softmax(x, axis, exc):
+    g = nnetb.Softmax(axis=axis)(x)
 
     g_fg = FunctionGraph(outputs=[g])
 
     cm = contextlib.suppress() if exc is None else pytest.warns(exc)
......
@@ -384,8 +384,7 @@ class TestRopLop(RopLopChecker):
         self.check_mat_rop_lop(self.mx.sum(axis=1), (self.mat_in_shape[0],))
 
     def test_softmax(self):
-        # Softmax adds an extra dimnesion !
-        self.check_rop_lop(aesara.tensor.nnet.softmax(self.x)[0], self.in_shape[0])
+        self.check_rop_lop(aesara.tensor.nnet.softmax(self.x), self.in_shape)
 
     def test_alloc(self):
         # Alloc of the sum of x into a vector
......