Commit 57c388a7 authored by Ricardo, committed by Brandon T. Willard

Add axis to LogSoftmax

Parent 595ed184
......@@ -208,8 +208,10 @@ def jax_funcify_Softmax(op, **kwargs):
@jax_funcify.register(LogSoftmax)
def jax_funcify_LogSoftmax(op, **kwargs):
    """Return a JAX implementation of the `LogSoftmax` `Op`.

    The `Op`'s ``axis`` is captured once at dispatch time and forwarded to
    ``jax.nn.log_softmax`` on every call, so the JAX backend reduces over
    the same axis as the Python implementation.
    """
    axis = op.axis

    def log_softmax(x):
        return jax.nn.log_softmax(x, axis=axis)

    return log_softmax
......
......@@ -430,15 +430,22 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
# Body of `numba_funcify_LogSoftmax`; the enclosing `def` and the final
# `return log_softmax` lie outside this hunk.
x_at = node.inputs[0]
x_dtype = x_at.type.numpy_dtype
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
axis = op.axis

if axis is not None:
    # keepdims=True so the reduced arrays broadcast back against `x`
    # without an explicit `np.expand_dims`.
    reduce_max = create_axis_reducer(
        np.maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
    )
    reduce_sum = create_axis_reducer(
        np.add, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
    )
else:
    # `axis=None` means reduce over all elements; NumPy's full reductions
    # return scalars, which broadcast against `x` as-is.
    reduce_max = np.max
    reduce_sum = np.sum

@numba.njit
def log_softmax(x):
    # Subtract the max before exponentiating for numerical stability.
    xdev = x - reduce_max(x)
    lsm = xdev - np.log(reduce_sum(np.exp(xdev)))
    return lsm
......
......@@ -27,7 +27,6 @@ from aesara.tensor.nnet.basic import (
graph_merge_softmax_with_crossentropy_softmax,
h_softmax,
logsoftmax,
logsoftmax_op,
prepend_0_to_each_row,
prepend_1_to_each_row,
prepend_scalar_to_each_row,
......
......@@ -969,16 +969,21 @@ def test_nnet():
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = aet_nnet.logsoftmax(x)
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax(axis):
    # The JAX translation of `Softmax` must agree with the Python
    # implementation for every supported `axis` value.
    x = matrix("x")
    x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
    fgraph = FunctionGraph([x], [aet_nnet.softmax(x, axis=axis)])
    test_inputs = [get_test_value(i) for i in fgraph.inputs]
    compare_jax_and_py(fgraph, test_inputs)
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_logsoftmax(axis):
    """Check the JAX translation of `logsoftmax` against the Python backend
    for every supported `axis` value."""
    x = matrix("x")
    x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
    out = aet_nnet.logsoftmax(x, axis=axis)
    fgraph = FunctionGraph([x], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
......
......@@ -1930,20 +1930,27 @@ def test_Softmax(x, axis, exc):
@pytest.mark.parametrize(
    "x, axis, exc",
    [
        (
            set_test_value(aet.vector(), rng.random(size=(2,)).astype(config.floatX)),
            None,
            None,
        ),
        (
            set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
            0,
            None,
        ),
        (
            set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
            1,
            None,
        ),
    ],
)
def test_LogSoftmax(x, axis, exc):
    """Compile `LogSoftmax` with the Numba backend for each `axis`/input rank
    and compare against the Python implementation.

    NOTE(review): the comparison body continues past this hunk.
    """
    g = nnetb.LogSoftmax(axis=axis)(x)
    g_fg = FunctionGraph(outputs=[g])

    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
......
......@@ -47,7 +47,6 @@ from aesara.tensor.nnet.basic import (
elu,
h_softmax,
logsoftmax,
logsoftmax_op,
relu,
selu,
sigmoid_binary_crossentropy,
......@@ -205,47 +204,28 @@ class TestSoftmaxWithBias(utt.InferShapeTester):
class TestLogSoftmax(utt.InferShapeTester):
def test_basic(self):
def f(a):
return logsoftmax_op(a)[:, 0]
utt.verify_grad(f, [np.random.random((3, 4))])
def f(a):
return logsoftmax_op(a)[:, 1]
utt.verify_grad(f, [np.random.random((3, 4))])
def f(a):
return logsoftmax_op(a)[:, 2]
utt.verify_grad(f, [np.random.random((3, 4))])
def f(a):
return logsoftmax_op(a)[:, 3]
utt.verify_grad(f, [np.random.random((3, 4))])
@pytest.mark.parametrize("column", [0, 1, 2, 3])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_matrix_grad(self, axis, column):
    """`verify_grad` on a single column of `logsoftmax` for each axis."""

    def f(a):
        return logsoftmax(a, axis=axis)[:, column]

    utt.verify_grad(f, [np.random.random((3, 4))])
def test_vector_perform(self):
    """A compiled vector `logsoftmax` must match SciPy's reference result."""
    x = vector()
    f = aesara.function([x], logsoftmax(x, axis=None))

    xv = np.random.randn(6).astype(config.floatX)
    assert np.allclose(f(xv), sp.log_softmax(xv))
def test_vector_grad(self):
    """`verify_grad` on `logsoftmax` applied to a full vector."""

    def f(a):
        return logsoftmax(a, axis=None)

    utt.verify_grad(f, [np.random.random((4))])
def test_allclose(self):
def test_matrix_perform_and_opt(self):
m = config.mode
m = aesara.compile.get_mode(m)
m.check_isfinite = False
......@@ -284,18 +264,15 @@ class TestLogSoftmax(utt.InferShapeTester):
grad_ = f3(a, b)
assert not np.any(np.isnan(grad_))
@pytest.mark.parametrize("axis", [None, 0, -1])
def test_local_logsoftmax_opt(self, axis):
    # Test the Logsoftmax substitution
    #
    # Check that Log(Softmax(x)) is substituted with Logsoftmax(x). Note that
    # only the forward pass is checked (i.e., doesn't check the gradient)

    x = matrix("x")
    sm = softmax(x, axis=axis)
    logsm = log(sm)
    f = aesara.function([x], logsm)
    assert isinstance(f.maker.fgraph.outputs[0].owner.op, LogSoftmax)
......@@ -351,6 +328,9 @@ class TestLogSoftmax(utt.InferShapeTester):
assert softmax_grad_legacy in [n.op for n in fgraph.toposort()]
def test_valid_axis(self):
    # Shared helper that checks `LogSoftmax` rejects invalid `axis` values.
    valid_axis_tester(LogSoftmax)
class TestSoftmaxGrad(utt.InferShapeTester):
def test_infer_shape(self):
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment