Commit 57c388a7, authored by Ricardo, committed by Brandon T. Willard

    Add axis to LogSoftmax

Parent: 595ed184
@@ -208,8 +208,10 @@ def jax_funcify_Softmax(op, **kwargs):
 @jax_funcify.register(LogSoftmax)
 def jax_funcify_LogSoftmax(op, **kwargs):
+    axis = op.axis
+
     def log_softmax(x):
-        return jax.nn.log_softmax(x)
+        return jax.nn.log_softmax(x, axis=axis)

     return log_softmax
...
@@ -430,15 +430,22 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
     x_at = node.inputs[0]
     x_dtype = x_at.type.numpy_dtype
     x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
+    axis = op.axis

-    # np.max(x, axis=1)
-    reduce_max = create_axis_reducer(np.maximum, -np.inf, 1, x_at.ndim, x_dtype)
-    # np.sum(x, axis=1, keepdims=True)
-    reduce_sum = create_axis_reducer(np.add, 0.0, 1, x_at.ndim, x_dtype, keepdims=True)
+    if axis is not None:
+        reduce_max = create_axis_reducer(
+            np.maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
+        )
+        reduce_sum = create_axis_reducer(
+            np.add, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
+        )
+    else:
+        reduce_max = np.max
+        reduce_sum = np.sum

     @numba.njit
     def log_softmax(x):
-        xdev = x - np.expand_dims(reduce_max(x), -1)
+        xdev = x - reduce_max(x)
         lsm = xdev - np.log(reduce_sum(np.exp(xdev)))
         return lsm
...
@@ -27,7 +27,6 @@ from aesara.tensor.nnet.basic import (
     graph_merge_softmax_with_crossentropy_softmax,
     h_softmax,
     logsoftmax,
-    logsoftmax_op,
     prepend_0_to_each_row,
     prepend_1_to_each_row,
     prepend_scalar_to_each_row,
...
@@ -969,16 +969,21 @@ def test_nnet():
     fgraph = FunctionGraph([x], [out])
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

-    out = aet_nnet.logsoftmax(x)
+
+@pytest.mark.parametrize("axis", [None, 0, 1])
+def test_softmax(axis):
+    x = matrix("x")
+    x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
+    out = aet_nnet.softmax(x, axis=axis)
     fgraph = FunctionGraph([x], [out])
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

 @pytest.mark.parametrize("axis", [None, 0, 1])
-def test_softmax(axis):
+def test_logsoftmax(axis):
     x = matrix("x")
     x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
-    out = aet_nnet.softmax(x, axis=axis)
+    out = aet_nnet.logsoftmax(x, axis=axis)
     fgraph = FunctionGraph([x], [out])
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
...
@@ -1930,20 +1930,27 @@ def test_Softmax(x, axis, exc):
 @pytest.mark.parametrize(
-    "x, exc",
+    "x, axis, exc",
     [
         (
             set_test_value(aet.vector(), rng.random(size=(2,)).astype(config.floatX)),
             None,
+            None,
         ),
         (
             set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
+            0,
+            None,
+        ),
+        (
+            set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
+            1,
             None,
         ),
     ],
 )
-def test_LogSoftmax(x, exc):
-    g = nnetb.LogSoftmax()(x)
+def test_LogSoftmax(x, axis, exc):
+    g = nnetb.LogSoftmax(axis=axis)(x)
     g_fg = FunctionGraph(outputs=[g])

     cm = contextlib.suppress() if exc is None else pytest.warns(exc)
...
@@ -47,7 +47,6 @@ from aesara.tensor.nnet.basic import (
     elu,
     h_softmax,
     logsoftmax,
-    logsoftmax_op,
     relu,
     selu,
     sigmoid_binary_crossentropy,
...
@@ -205,47 +204,28 @@ class TestSoftmaxWithBias(utt.InferShapeTester):
 class TestLogSoftmax(utt.InferShapeTester):
-    def test_basic(self):
-        def f(a):
-            return logsoftmax_op(a)[:, 0]
-
-        utt.verify_grad(f, [np.random.random((3, 4))])
-
-        def f(a):
-            return logsoftmax_op(a)[:, 1]
-
-        utt.verify_grad(f, [np.random.random((3, 4))])
-
-        def f(a):
-            return logsoftmax_op(a)[:, 2]
-
-        utt.verify_grad(f, [np.random.random((3, 4))])
-
-        def f(a):
-            return logsoftmax_op(a)[:, 3]
-
-        utt.verify_grad(f, [np.random.random((3, 4))])
-
-    def test_matrix(self):
+    @pytest.mark.parametrize("column", [0, 1, 2, 3])
+    @pytest.mark.parametrize("axis", [None, 0, 1])
+    def test_matrix_grad(self, axis, column):
         def f(a):
-            return logsoftmax_op(a)
+            return logsoftmax(a, axis=axis)[:, column]

         utt.verify_grad(f, [np.random.random((3, 4))])

-    def test_vector(self):
+    def test_vector_perform(self):
         x = vector()
-        f = aesara.function([x], logsoftmax_op(x))
+        f = aesara.function([x], logsoftmax(x, axis=None))

         xv = np.random.randn(6).astype(config.floatX)
-        assert np.allclose(f(xv), np.log(np.exp(xv) / np.exp(xv).sum()))
+        assert np.allclose(f(xv), sp.log_softmax(xv))

     def test_vector_grad(self):
         def f(a):
-            return logsoftmax_op(a)
+            return logsoftmax(a, axis=None)

         utt.verify_grad(f, [np.random.random((4))])

-    def test_allclose(self):
+    def test_matrix_perform_and_opt(self):
         m = config.mode
         m = aesara.compile.get_mode(m)
         m.check_isfinite = False
@@ -284,18 +264,15 @@ class TestLogSoftmax(utt.InferShapeTester):
         grad_ = f3(a, b)
         assert not np.any(np.isnan(grad_))

-    def test_isclose(self):
-        def f(a):
-            return logsoftmax_op(a)
-
-    def test_local_softmax_optimization(self):
+    @pytest.mark.parametrize("axis", [None, 0, -1])
+    def test_local_logsoftmax_opt(self, axis):
         # Test the Logsoftmax substitution
         #
         # Check that Log(Softmax(x)) is substituted with Logsoftmax(x). Note that
         # only the forward pass is checked (i.e., doesn't check the gradient)

-        x, y = matrices("xy")
-        sm = softmax(x)
+        x = matrix("x")
+        sm = softmax(x, axis=axis)
         logsm = log(sm)
         f = aesara.function([x], logsm)
         assert isinstance(f.maker.fgraph.outputs[0].owner.op, LogSoftmax)
@@ -351,6 +328,9 @@ class TestLogSoftmax(utt.InferShapeTester):
         assert softmax_grad_legacy in [n.op for n in fgraph.toposort()]

+    def test_valid_axis(self):
+        valid_axis_tester(LogSoftmax)
+
 class TestSoftmaxGrad(utt.InferShapeTester):
     def test_infer_shape(self):
...
Markdown formatting is supported
0%
You are about to add 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment