Commit 58cb5c30 authored by Ricardo, committed by Brandon T. Willard

Add axis to Softmax and SoftmaxGrad Ops

Parent c6c85acb
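
This change lets Softmax (and its gradient) normalize along a caller-specified axis instead of always over the last dimension. A minimal usage sketch of the new keyword, outside this diff (the variable names and the example input are illustrative; the `softmax(x, axis=...)` call matches the new tests below):

    import numpy as np
    import aesara
    import aesara.tensor as aet
    from aesara.tensor.nnet import softmax

    x = aet.matrix("x")
    # Normalize each column (axis=0); axis=None would normalize over all elements.
    y = softmax(x, axis=0)
    f = aesara.function([x], y)

    out = f(np.arange(6.0).reshape(2, 3).astype(aesara.config.floatX))
    print(out.sum(axis=0))  # each column sums to ~1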
@@ -198,8 +198,10 @@ def jax_funcify_Identity(op, **kwargs):

 @jax_funcify.register(Softmax)
 def jax_funcify_Softmax(op, **kwargs):
+    axis = op.axis
+
     def softmax(x):
-        return jax.nn.softmax(x)
+        return jax.nn.softmax(x, axis=axis)

     return softmax
...
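The generated JAX function now simply forwards the op's axis to jax.nn.softmax. A standalone sketch of the behavior this relies on, in plain JAX and independent of Aesara:

    import jax
    import jax.numpy as jnp

    x = jnp.arange(6.0).reshape(2, 3)

    by_rows = jax.nn.softmax(x, axis=1)  # each row sums to 1
    by_cols = jax.nn.softmax(x, axis=0)  # each column sums to 1

    assert jnp.allclose(by_rows.sum(axis=1), 1.0)
    assert jnp.allclose(by_cols.sum(axis=0), 1.0)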
@@ -400,17 +400,24 @@ def numba_funcify_Softmax(op, node, **kwargs):

     x_at = node.inputs[0]
     x_dtype = x_at.type.numpy_dtype
     x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
+    axis = op.axis

-    # np.max(x, axis=1)
-    reduce_max = create_axis_reducer(np.maximum, -np.inf, 1, x_at.ndim, x_dtype)
-    # np.sum(x, axis=1)
-    reduce_sum = create_axis_reducer(np.add, 0.0, 1, x_at.ndim, x_dtype)
+    if axis is not None:
+        reduce_max = create_axis_reducer(
+            np.maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
+        )
+        reduce_sum = create_axis_reducer(
+            np.add, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
+        )
+    else:
+        reduce_max = np.max
+        reduce_sum = np.sum

     @numba.njit
     def softmax(x):
-        z = np.expand_dims(reduce_max(x), -1)
+        z = reduce_max(x)
         e_x = np.exp(x - z)
-        w = np.expand_dims(reduce_sum(e_x), -1)
+        w = reduce_sum(e_x)
         sm = e_x / w
         return sm
...
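The Numba path now requests keepdims reducers from create_axis_reducer, so the kernel no longer needs np.expand_dims; when axis is None it falls back to plain np.max/np.sum, whose scalar results broadcast directly. A rough NumPy reference of what the generated kernel computes (the function name is illustrative, not part of the diff):

    import numpy as np

    def softmax_ref(x, axis=None):
        # Numerically stable softmax: subtract the max before exponentiating.
        # keepdims=True preserves the reduced axis so z and w broadcast against x.
        z = np.max(x, axis=axis, keepdims=True)
        e_x = np.exp(x - z)
        w = np.sum(e_x, axis=axis, keepdims=True)
        return e_x / w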
@@ -35,9 +35,9 @@ from aesara.tensor.nnet.basic import (
     selu,
     sigmoid_binary_crossentropy,
     softmax,
-    softmax_grad,
+    softmax_grad_legacy,
     softmax_graph,
-    softmax_op,
+    softmax_legacy,
     softmax_simplifier,
     softmax_with_bias,
     softsign,
...
@@ -32,7 +32,7 @@ from aesara.tensor.math import (
     sqrt,
 )
 from aesara.tensor.math import sum as aet_sum
-from aesara.tensor.nnet import batchnorm, conv2d, softmax, softmax_op
+from aesara.tensor.nnet import batchnorm, conv2d, softmax, softmax_legacy
 from aesara.tensor.nnet.abstract_conv import (
     get_conv_gradinputs_shape,
     get_conv_output_shape,
@@ -1456,7 +1456,7 @@ class TestSoftMax(test_nnet.TestSoftMax):
     def test_softmax_f16(self):
         x = matrix("x", "float16")
         x_gpu = tensor4("x_gpu", "float16")
-        f_z = softmax_op
+        f_z = softmax_legacy
         f_gpu = dnn.GpuDnnSoftmax("accurate", "channel")

         def cmp(n, m, f, f_gpu):
@@ -1480,7 +1480,7 @@ class TestSoftMax(test_nnet.TestSoftMax):
         x = matrix("x")
         x_gpu = tensor4("x_gpu")
-        f_z = softmax_op
+        f_z = softmax_legacy
         f_gpu = dnn.GpuDnnSoftmax("accurate", "channel")

         # Verify the grad operation
...
@@ -210,7 +210,7 @@ def softmax_unittest_template(dtypeInput):
     z = aesara.tensor.nnet.softmax(x)
     f = aesara.function([x], z, mode=mode_without_gpu)
     f_gpu = aesara.function([x], z, mode=mode_wo_cudnn)
-    assert f.maker.fgraph.toposort()[-1].op == aesara.tensor.nnet.softmax_op
+    assert f.maker.fgraph.toposort()[-1].op == aesara.tensor.nnet.softmax_legacy
     assert isinstance(f_gpu.maker.fgraph.toposort()[-2].op, GpuSoftmax)

     def cmp(n, m):
@@ -300,7 +300,7 @@ class TestSoftMax:
     def test_softmax(self):
         x = fmatrix("x")
-        z = aesara.tensor.nnet.softmax_op
+        z = aesara.tensor.nnet.softmax_legacy
         f, f_gpu = self._test_softmax(x, x, z, z, self._cmp)
@@ -308,7 +308,7 @@ class TestSoftMax:
     def test_softmax_shape_0(self):
         x = fmatrix("x")
-        z = aesara.tensor.nnet.softmax_op
+        z = aesara.tensor.nnet.softmax_legacy
         f, f_gpu = self._test_softmax(x, x, z, z, self._cmp)

         # Aesara can handle that case, but cudnn can't
...
@@ -969,11 +969,16 @@ def test_nnet():
     fgraph = FunctionGraph([x], [out])
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

-    out = aet_nnet.softmax(x)
+    out = aet_nnet.logsoftmax(x)
     fgraph = FunctionGraph([x], [out])
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

-    out = aet_nnet.logsoftmax(x)
+
+@pytest.mark.parametrize("axis", [None, 0, 1])
+def test_softmax(axis):
+    x = matrix("x")
+    x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
+    out = aet_nnet.softmax(x, axis=axis)
     fgraph = FunctionGraph([x], [out])
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
...
@@ -1894,20 +1894,27 @@ def test_Dot(x, y, exc):

 @pytest.mark.parametrize(
-    "x, exc",
+    "x, axis, exc",
     [
         (
             set_test_value(aet.vector(), rng.random(size=(2,)).astype(config.floatX)),
             None,
+            None,
+        ),
+        (
+            set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
+            None,
+            None,
         ),
         (
             set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
+            0,
             None,
         ),
     ],
 )
-def test_Softmax(x, exc):
-    g = nnetb.Softmax()(x)
+def test_Softmax(x, axis, exc):
+    g = nnetb.Softmax(axis=axis)(x)
     g_fg = FunctionGraph(outputs=[g])

     cm = contextlib.suppress() if exc is None else pytest.warns(exc)
...
@@ -384,8 +384,7 @@ class TestRopLop(RopLopChecker):
         self.check_mat_rop_lop(self.mx.sum(axis=1), (self.mat_in_shape[0],))

     def test_softmax(self):
-        # Softmax adds an extra dimnesion !
-        self.check_rop_lop(aesara.tensor.nnet.softmax(self.x)[0], self.in_shape[0])
+        self.check_rop_lop(aesara.tensor.nnet.softmax(self.x), self.in_shape)

     def test_alloc(self):
         # Alloc of the sum of x into a vector
...