提交 c0614c17 authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Brandon T. Willard

Move `Sotfmax`, `SoftmaxGrad` `LogSoftmax` to `aesara.tensor.math`

上级 f1cc8937
...@@ -3,7 +3,7 @@ import jax.numpy as jnp ...@@ -3,7 +3,7 @@ import jax.numpy as jnp
from aesara.link.jax.dispatch.basic import jax_funcify, jnp_safe_copy from aesara.link.jax.dispatch.basic import jax_funcify, jnp_safe_copy
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad from aesara.tensor.math import LogSoftmax, Softmax, SoftmaxGrad
@jax_funcify.register(Elemwise) @jax_funcify.register(Elemwise)
......
...@@ -38,8 +38,13 @@ from aesara.scalar.basic import ( ...@@ -38,8 +38,13 @@ from aesara.scalar.basic import (
from aesara.scalar.basic import add as add_as from aesara.scalar.basic import add as add_as
from aesara.scalar.basic import scalar_maximum from aesara.scalar.basic import scalar_maximum
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import MaxAndArgmax, MulWithoutZeros from aesara.tensor.math import (
from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad LogSoftmax,
MaxAndArgmax,
MulWithoutZeros,
Softmax,
SoftmaxGrad,
)
@singledispatch @singledispatch
......
差异被折叠。
...@@ -24,7 +24,7 @@ class Mlp: ...@@ -24,7 +24,7 @@ class Mlp:
wy = shared(self.rng.normal(0, 1, (nhiddens, noutputs))) wy = shared(self.rng.normal(0, 1, (nhiddens, noutputs)))
by = shared(np.zeros(noutputs), borrow=True) by = shared(np.zeros(noutputs), borrow=True)
y = at.nnet.softmax(at.dot(h, wy) + by) y = at.softmax(at.dot(h, wy) + by)
self.inputs = [x] self.inputs = [x]
self.outputs = [y] self.outputs = [y]
......
...@@ -6,10 +6,10 @@ from aesara.graph.fg import FunctionGraph ...@@ -6,10 +6,10 @@ from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value from aesara.graph.op import get_test_value
from aesara.tensor import elemwise as at_elemwise from aesara.tensor import elemwise as at_elemwise
from aesara.tensor import nnet as at_nnet from aesara.tensor import nnet as at_nnet
from aesara.tensor.math import SoftmaxGrad
from aesara.tensor.math import all as at_all from aesara.tensor.math import all as at_all
from aesara.tensor.math import prod from aesara.tensor.math import prod
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.nnet.basic import SoftmaxGrad
from aesara.tensor.type import matrix, tensor, vector from aesara.tensor.type import matrix, tensor, vector
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
......
...@@ -6,14 +6,25 @@ import pytest ...@@ -6,14 +6,25 @@ import pytest
import aesara.tensor as at import aesara.tensor as at
import aesara.tensor.inplace as ati import aesara.tensor.inplace as ati
import aesara.tensor.math as aem import aesara.tensor.math as aem
import aesara.tensor.nnet.basic as nnetb
from aesara import config from aesara import config
from aesara.compile.ops import deep_copy_op from aesara.compile.ops import deep_copy_op
from aesara.compile.sharedvalue import SharedVariable from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.basic import Constant from aesara.graph.basic import Constant
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.tensor import elemwise as at_elemwise from aesara.tensor import elemwise as at_elemwise
from aesara.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum from aesara.tensor.math import (
All,
Any,
LogSoftmax,
Max,
Mean,
Min,
Prod,
ProdWithoutZeros,
Softmax,
SoftmaxGrad,
Sum,
)
from tests.link.numba.test_basic import ( from tests.link.numba.test_basic import (
compare_numba_and_py, compare_numba_and_py,
my_multi_out, my_multi_out,
...@@ -377,7 +388,7 @@ def test_scalar_Elemwise_Clip(): ...@@ -377,7 +388,7 @@ def test_scalar_Elemwise_Clip():
], ],
) )
def test_SoftmaxGrad(dy, sm, axis, exc): def test_SoftmaxGrad(dy, sm, axis, exc):
g = nnetb.SoftmaxGrad(axis=axis)(dy, sm) g = SoftmaxGrad(axis=axis)(dy, sm)
g_fg = FunctionGraph(outputs=[g]) g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
...@@ -413,7 +424,7 @@ def test_SoftmaxGrad(dy, sm, axis, exc): ...@@ -413,7 +424,7 @@ def test_SoftmaxGrad(dy, sm, axis, exc):
], ],
) )
def test_Softmax(x, axis, exc): def test_Softmax(x, axis, exc):
g = nnetb.Softmax(axis=axis)(x) g = Softmax(axis=axis)(x)
g_fg = FunctionGraph(outputs=[g]) g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
...@@ -449,7 +460,7 @@ def test_Softmax(x, axis, exc): ...@@ -449,7 +460,7 @@ def test_Softmax(x, axis, exc):
], ],
) )
def test_LogSoftmax(x, axis, exc): def test_LogSoftmax(x, axis, exc):
g = nnetb.LogSoftmax(axis=axis)(x) g = LogSoftmax(axis=axis)(x)
g_fg = FunctionGraph(outputs=[g]) g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
......
...@@ -40,7 +40,7 @@ from aesara.scan.basic import scan ...@@ -40,7 +40,7 @@ from aesara.scan.basic import scan
from aesara.scan.op import Scan from aesara.scan.op import Scan
from aesara.scan.utils import until from aesara.scan.utils import until
from aesara.tensor.math import all as at_all from aesara.tensor.math import all as at_all
from aesara.tensor.math import dot, mean, sigmoid from aesara.tensor.math import dot, exp, mean, sigmoid
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tanh from aesara.tensor.math import tanh
from aesara.tensor.nnet import categorical_crossentropy from aesara.tensor.nnet import categorical_crossentropy
...@@ -69,7 +69,6 @@ from aesara.tensor.type import ( ...@@ -69,7 +69,6 @@ from aesara.tensor.type import (
vector, vector,
) )
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.tensor.nnet.test_basic import softmax_graph
if config.mode == "FAST_COMPILE": if config.mode == "FAST_COMPILE":
...@@ -85,6 +84,10 @@ else: ...@@ -85,6 +84,10 @@ else:
type_eps = {"float64": 1e-7, "float32": 3e-3} type_eps = {"float64": 1e-7, "float32": 3e-3}
def softmax_graph(c):
return exp(c) / exp(c).sum(axis=-1, keepdims=True)
class multiple_outputs_numeric_grad: class multiple_outputs_numeric_grad:
"""WRITEME""" """WRITEME"""
......
...@@ -24,13 +24,12 @@ from aesara.tensor.math import ( ...@@ -24,13 +24,12 @@ from aesara.tensor.math import (
sigmoid, sigmoid,
) )
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tanh, true_div from aesara.tensor.math import tanh
from aesara.tensor.nnet.basic import ( from aesara.tensor.nnet.basic import (
CrossentropyCategorical1Hot, CrossentropyCategorical1Hot,
CrossentropyCategorical1HotGrad, CrossentropyCategorical1HotGrad,
CrossentropySoftmax1HotWithBiasDx, CrossentropySoftmax1HotWithBiasDx,
CrossentropySoftmaxArgmax1HotWithBias, CrossentropySoftmaxArgmax1HotWithBias,
LogSoftmax,
Prepend_scalar_constant_to_each_row, Prepend_scalar_constant_to_each_row,
Prepend_scalar_to_each_row, Prepend_scalar_to_each_row,
Softmax, Softmax,
...@@ -46,7 +45,6 @@ from aesara.tensor.nnet.basic import ( ...@@ -46,7 +45,6 @@ from aesara.tensor.nnet.basic import (
crossentropy_softmax_argmax_1hot_with_bias, crossentropy_softmax_argmax_1hot_with_bias,
elu, elu,
h_softmax, h_softmax,
logsoftmax,
relu, relu,
selu, selu,
sigmoid_binary_crossentropy, sigmoid_binary_crossentropy,
...@@ -65,7 +63,6 @@ from aesara.tensor.type import ( ...@@ -65,7 +63,6 @@ from aesara.tensor.type import (
fvector, fvector,
ivector, ivector,
lvector, lvector,
matrices,
matrix, matrix,
scalar, scalar,
tensor3, tensor3,
...@@ -104,52 +101,6 @@ def valid_axis_tester(Op): ...@@ -104,52 +101,6 @@ def valid_axis_tester(Op):
Op(-4)(*x) Op(-4)(*x)
class TestSoftmax(utt.InferShapeTester):
@pytest.mark.parametrize("axis", [None, 0, 1, 2, 3, -1, -2])
def test_perform(self, axis):
x = tensor4("x")
rng = np.random.default_rng(utt.fetch_seed())
xv = rng.standard_normal((2, 3, 4, 5)).astype(config.floatX)
f = aesara.function([x], softmax(x, axis=axis))
assert np.allclose(f(xv), sp.softmax(xv, axis=axis))
@pytest.mark.parametrize("column", [0, 1, 2, 3])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_grad(self, axis, column):
def f(a):
return softmax(a, axis=axis)[:, column]
rng = np.random.default_rng(utt.fetch_seed())
utt.verify_grad(f, [rng.random((3, 4, 2))])
def test_infer_shape(self):
admat = matrix()
rng = np.random.default_rng(utt.fetch_seed())
admat_val = rng.random((3, 4)).astype(config.floatX)
self._compile_and_check(
[admat], [Softmax(axis=-1)(admat)], [admat_val], Softmax
)
def test_vector_perform(self):
x = vector()
f = aesara.function([x], softmax(x, axis=None))
rng = np.random.default_rng(utt.fetch_seed())
xv = rng.standard_normal((6,)).astype(config.floatX)
assert np.allclose(f(xv), sp.softmax(xv))
def test_vector_grad(self):
def f(a):
return softmax(a, axis=None)
rng = np.random.default_rng(utt.fetch_seed())
utt.verify_grad(f, [rng.random((4))])
def test_valid_axis(self):
valid_axis_tester(Softmax)
class TestSoftmaxWithBias(utt.InferShapeTester): class TestSoftmaxWithBias(utt.InferShapeTester):
def test_basic(self): def test_basic(self):
def f(a, b): def f(a, b):
...@@ -217,160 +168,6 @@ class TestSoftmaxWithBias(utt.InferShapeTester): ...@@ -217,160 +168,6 @@ class TestSoftmaxWithBias(utt.InferShapeTester):
) )
class TestLogSoftmax(utt.InferShapeTester):
@pytest.mark.parametrize("column", [0, 1, 2, 3])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_matrix_grad(self, axis, column):
def f(a):
return logsoftmax(a, axis=axis)[:, column]
rng = np.random.default_rng(utt.fetch_seed())
utt.verify_grad(f, [rng.random((3, 4))])
def test_vector_perform(self):
x = vector()
f = aesara.function([x], logsoftmax(x, axis=None))
rng = np.random.default_rng(utt.fetch_seed())
xv = rng.standard_normal((6,)).astype(config.floatX)
assert np.allclose(f(xv), sp.log_softmax(xv))
def test_vector_grad(self):
def f(a):
return logsoftmax(a, axis=None)
rng = np.random.default_rng(utt.fetch_seed())
utt.verify_grad(f, [rng.random((4,))])
def test_matrix_perform_and_rewrite(self):
m = config.mode
m = aesara.compile.get_mode(m)
m.check_isfinite = False
x, y = matrices("xy")
# regular softmax and crossentropy
sm = softmax(x)
cm = categorical_crossentropy(sm, y)
# numerically stable log-softmax with crossentropy
logsm = logsoftmax(x)
sm2 = exp(logsm) # just used to show equivalence with sm
cm2 = -at_sum(y * logsm, axis=1)
grad_node = grad(cm2.mean(), x)
rng = np.random.default_rng(utt.fetch_seed())
a = np.exp(10 * rng.random((5, 10)).astype(config.floatX))
b = np.eye(5, 10).astype(config.floatX)
# show equivalence of softmax and exponentiated numerically stable
# log-softmax
f1 = aesara.function([x], [sm, sm2])
sm_, sm2_ = f1(a)
utt.assert_allclose(sm_, sm2_)
# now show that the two versions result in the same crossentropy cost
# this indicates that the forward function does provide some numerical
# stability
f2 = aesara.function([x, y], [cm, cm2], mode=m)
cm_, cm2_ = f2(a, b)
utt.assert_allclose(cm_, cm2_)
# now, show that in the standard softmax case the gradients blow up
# while in the log-softmax case they don't
f3 = aesara.function([x, y], [grad_node])
grad_ = f3(a, b)
assert not np.any(np.isnan(grad_))
@pytest.mark.parametrize("axis", [None, 0, -1])
def test_local_logsoftmax_rewrite(self, axis):
"""Test the `Logsoftmax` substitution.
Check that ``Log(Softmax(x))`` is substituted with ``Logsoftmax(x)``. Note that
only the forward pass is checked (i.e., doesn't check the gradient)
"""
x = matrix("x")
sm = softmax(x, axis=axis)
logsm = log(sm)
f = aesara.function([x], logsm)
assert isinstance(f.maker.fgraph.outputs[0].owner.op, LogSoftmax)
assert check_stack_trace(f, ops_to_check=LogSoftmax)
@pytest.mark.parametrize("axis", [None, 0, -1])
def test_local_logsoftmax_grad_rewrite(self, axis):
"""Test the `Logsoftmax`'s grad substitution.
Check that ``Log(Softmax(x))``'s grad is substituted with ``Logsoftmax(x)``'s
grad and that the new operation does not explode for big inputs.
Note that only the grad is checked.
"""
m = config.mode
m = aesara.compile.get_mode(m)
m.check_isfinite = False
# some inputs that are large to make the gradient explode in the non
# rewritten case
rng = np.random.default_rng(utt.fetch_seed())
a = np.exp(10 * rng.random((5, 10)).astype(config.floatX))
def myfunc(x):
sm = softmax(x, axis=axis)
logsm = log(sm)
return logsm
# We set step to 0.1 because for big values we need a big epsilon
utt.verify_grad(myfunc, [a], eps=0.1, mode=m)
sa = aesara.shared(a)
f = aesara.function([], myfunc(sa))
assert check_stack_trace(f, ops_to_check="all")
def test_logsoftmax_grad_true_div_elemwise(self):
"""
Checks that the gradient of an expression similar to a ``log(softmax)`` but
with a different elemwise operation than true_div is not rewritten.
"""
x = matrix("x")
y = log(softmax(x))
g = grad(y.sum(), x)
softmax_grad_node = g.owner
assert softmax_grad_node.op == softmax_grad_legacy
true_div_node = softmax_grad_node.inputs[0].owner
assert true_div_node.op == true_div
# We replace the elemwise true_div op by an elemwise add.
new_g = softmax_grad_legacy(
add(*true_div_node.inputs), softmax_grad_node.inputs[1]
)
fgraph = FunctionGraph([x], [new_g])
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert softmax_grad_legacy in [n.op for n in fgraph.toposort()]
def test_valid_axis(self):
valid_axis_tester(LogSoftmax)
class TestSoftmaxGrad(utt.InferShapeTester):
def test_infer_shape(self):
admat = matrix()
bdmat = matrix()
rng = np.random.default_rng(utt.fetch_seed())
admat_val = rng.random((3, 4)).astype(config.floatX)
bdmat_val = rng.random((3, 4)).astype(config.floatX)
self._compile_and_check(
[admat, bdmat],
[SoftmaxGrad(axis=-1)(admat, bdmat)],
[admat_val, bdmat_val],
SoftmaxGrad,
)
def test_valid_axis(self):
valid_axis_tester(SoftmaxGrad)
class TestCrossEntropySoftmax1Hot: class TestCrossEntropySoftmax1Hot:
def test_basic(self): def test_basic(self):
y_idx = [0, 1, 3] y_idx = [0, 1, 3]
...@@ -1202,27 +999,6 @@ def test_grad_softmax_grad(): ...@@ -1202,27 +999,6 @@ def test_grad_softmax_grad():
utt.verify_grad(f, [rng.random((3, 4))]) utt.verify_grad(f, [rng.random((3, 4))])
def test_stabilize_log_softmax():
mode = aesara.compile.mode.get_default_mode()
mode = mode.including("local_log_softmax", "specialize")
x = matrix()
y = softmax(x)
z = log(y)
f = aesara.function([x], z, mode=mode)
assert check_stack_trace(f, ops_to_check="all")
# Check that the softmax has been rewritten
for node in f.maker.fgraph.toposort():
assert not isinstance(node.op, y.owner.op.__class__)
# Call the function so debug mode can verify the rewritten version matches
# the un-rewritten version
rng = np.random.default_rng(utt.fetch_seed())
f(np.cast[config.floatX](rng.random((2, 3))))
def test_relu(): def test_relu():
x = matrix("x") x = matrix("x")
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
......
...@@ -13,7 +13,7 @@ from aesara import pprint, shared ...@@ -13,7 +13,7 @@ from aesara import pprint, shared
from aesara.compile import optdb from aesara.compile import optdb
from aesara.compile.debugmode import DebugMode from aesara.compile.debugmode import DebugMode
from aesara.compile.function import function from aesara.compile.function import function
from aesara.compile.mode import Mode, get_default_mode, get_mode from aesara.compile.mode import OPT_FAST_RUN, Mode, get_default_mode, get_mode
from aesara.compile.ops import DeepCopyOp, deep_copy_op from aesara.compile.ops import DeepCopyOp, deep_copy_op
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, equal_computations from aesara.graph.basic import Apply, Constant, equal_computations
...@@ -33,7 +33,15 @@ from aesara.tensor.basic import Alloc, join, switch ...@@ -33,7 +33,15 @@ from aesara.tensor.basic import Alloc, join, switch
from aesara.tensor.blas import Dot22, Gemv from aesara.tensor.blas import Dot22, Gemv
from aesara.tensor.blas_c import CGemv from aesara.tensor.blas_c import CGemv
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import Dot, MaxAndArgmax, Prod, Sum, _conj from aesara.tensor.math import (
Dot,
LogSoftmax,
MaxAndArgmax,
Prod,
SoftmaxGrad,
Sum,
_conj,
)
from aesara.tensor.math import abs as at_abs from aesara.tensor.math import abs as at_abs
from aesara.tensor.math import add from aesara.tensor.math import add
from aesara.tensor.math import all as at_all from aesara.tensor.math import all as at_all
...@@ -76,7 +84,17 @@ from aesara.tensor.math import minimum, mul, neg, neq ...@@ -76,7 +84,17 @@ from aesara.tensor.math import minimum, mul, neg, neq
from aesara.tensor.math import pow as at_pow from aesara.tensor.math import pow as at_pow
from aesara.tensor.math import prod, rad2deg, reciprocal from aesara.tensor.math import prod, rad2deg, reciprocal
from aesara.tensor.math import round as at_round from aesara.tensor.math import round as at_round
from aesara.tensor.math import sgn, sigmoid, sin, sinh, softplus, sqr, sqrt, sub from aesara.tensor.math import (
sgn,
sigmoid,
sin,
sinh,
softmax,
softplus,
sqr,
sqrt,
sub,
)
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tan, tanh, true_div, xor from aesara.tensor.math import tan, tanh, true_div, xor
from aesara.tensor.rewriting.elemwise import local_dimshuffle_lift from aesara.tensor.rewriting.elemwise import local_dimshuffle_lift
...@@ -4578,6 +4596,76 @@ class TestSigmoidUtils: ...@@ -4578,6 +4596,76 @@ class TestSigmoidUtils:
assert is_1pexp(1 + 2 * exp_op(x), False) is None assert is_1pexp(1 + 2 * exp_op(x), False) is None
class TestLogSoftmaxRewrites:
@pytest.mark.parametrize("axis", [None, 0, -1])
def test_local_logsoftmax_rewrite(self, axis):
"""Test the `Logsoftmax` substitution.
Check that ``Log(Softmax(x))`` is substituted with ``Logsoftmax(x)``. Note that
only the forward pass is checked (i.e., doesn't check the gradient)
"""
x = matrix("x")
sm = softmax(x, axis=axis)
logsm = log(sm)
f = function([x], logsm)
assert isinstance(f.maker.fgraph.outputs[0].owner.op, LogSoftmax)
assert check_stack_trace(f, ops_to_check=LogSoftmax)
@pytest.mark.parametrize("axis", [None, 0, -1])
def test_local_logsoftmax_grad_rewrite(self, axis):
"""Test the `Logsoftmax`'s grad substitution.
Check that ``Log(Softmax(x))``'s grad is substituted with ``Logsoftmax(x)``'s
grad and that the new operation does not explode for big inputs.
Note that only the grad is checked.
"""
m = config.mode
m = get_mode(m)
m.check_isfinite = False
# some inputs that are large to make the gradient explode in the non
# rewritten case
rng = np.random.default_rng(utt.fetch_seed())
a = np.exp(10 * rng.random((5, 10)).astype(config.floatX))
def myfunc(x):
sm = softmax(x, axis=axis)
logsm = log(sm)
return logsm
# We set step to 0.1 because for big values we need a big epsilon
utt.verify_grad(myfunc, [a], eps=0.1, mode=m)
sa = shared(a)
f = function([], myfunc(sa))
assert check_stack_trace(f, ops_to_check="all")
def test_logsoftmax_grad_true_div_elemwise(self):
"""
Checks that the gradient of an expression similar to a ``log(softmax)`` but
with a different elemwise operation than true_div is not rewritten.
"""
x = matrix("x")
y = log(softmax(x))
g = aesara.tensor.grad(y.sum(), x)
softmax_grad_node = g.owner
assert softmax_grad_node.op == SoftmaxGrad(axis=-1)
true_div_node = softmax_grad_node.inputs[0].owner
assert true_div_node.op == true_div
# We replace the elemwise true_div op by an elemwise add.
new_g = SoftmaxGrad(axis=-1)(
add(*true_div_node.inputs), softmax_grad_node.inputs[1]
)
fgraph = FunctionGraph([x], [new_g])
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert SoftmaxGrad(axis=-1) in [n.op for n in fgraph.toposort()]
def test_log1mexp_stabilization(): def test_log1mexp_stabilization():
mode = Mode("py").including("stabilize") mode = Mode("py").including("stabilize")
...@@ -4639,3 +4727,24 @@ def test_deprecations(): ...@@ -4639,3 +4727,24 @@ def test_deprecations():
"""Make sure we can import from deprecated modules.""" """Make sure we can import from deprecated modules."""
with pytest.deprecated_call(): with pytest.deprecated_call():
from aesara.tensor.math_opt import AlgebraicCanonizer # noqa: F401 F811 from aesara.tensor.math_opt import AlgebraicCanonizer # noqa: F401 F811
def test_log_softmax_stabilization():
mode = aesara.compile.mode.get_default_mode()
mode = mode.including("local_log_softmax", "specialize")
x = matrix()
y = softmax(x)
z = log(y)
f = aesara.function([x], z, mode=mode)
assert check_stack_trace(f, ops_to_check="all")
# Check that the softmax has been rewritten
for node in f.maker.fgraph.toposort():
assert not isinstance(node.op, y.owner.op.__class__)
# Call the function so debug mode can verify the rewritten version matches
# the un-rewritten version
rng = np.random.default_rng(utt.fetch_seed())
f(np.cast[config.floatX](rng.random((2, 3))))
...@@ -8,7 +8,9 @@ from itertools import product ...@@ -8,7 +8,9 @@ from itertools import product
import numpy as np import numpy as np
import pytest import pytest
from numpy.testing import assert_array_equal from numpy.testing import assert_array_equal
from scipy.special import log_softmax as scipy_log_softmax
from scipy.special import logsumexp as scipy_logsumexp from scipy.special import logsumexp as scipy_logsumexp
from scipy.special import softmax as scipy_softmax
import aesara.scalar as aes import aesara.scalar as aes
from aesara.compile.debugmode import DebugMode from aesara.compile.debugmode import DebugMode
...@@ -34,11 +36,14 @@ from aesara.tensor.elemwise import CAReduce, Elemwise ...@@ -34,11 +36,14 @@ from aesara.tensor.elemwise import CAReduce, Elemwise
from aesara.tensor.math import ( from aesara.tensor.math import (
Argmax, Argmax,
Dot, Dot,
LogSoftmax,
MatMul, MatMul,
MaxAndArgmax, MaxAndArgmax,
Mean, Mean,
Prod, Prod,
ProdWithoutZeros, ProdWithoutZeros,
Softmax,
SoftmaxGrad,
Sum, Sum,
_allclose, _allclose,
_dot, _dot,
...@@ -79,6 +84,7 @@ from aesara.tensor.math import ( ...@@ -79,6 +84,7 @@ from aesara.tensor.math import (
log1p, log1p,
log2, log2,
log10, log10,
log_softmax,
logaddexp, logaddexp,
logsumexp, logsumexp,
matmul, matmul,
...@@ -104,6 +110,7 @@ from aesara.tensor.math import ( ...@@ -104,6 +110,7 @@ from aesara.tensor.math import (
sin, sin,
sinh, sinh,
smallest, smallest,
softmax,
sqr, sqr,
sqrt, sqrt,
sub, sub,
...@@ -3521,3 +3528,129 @@ class TestMatMul(utt.InferShapeTester): ...@@ -3521,3 +3528,129 @@ class TestMatMul(utt.InferShapeTester):
[x1, x2], [x1, x2],
self.op_class, self.op_class,
) )
class TestSoftmax(utt.InferShapeTester):
@pytest.mark.parametrize("axis", [None, 0, 1, 2, 3, -1, -2])
def test_perform(self, axis):
x = tensor4("x")
rng = np.random.default_rng(utt.fetch_seed())
xv = rng.standard_normal((2, 3, 4, 5)).astype(config.floatX)
f = function([x], softmax(x, axis=axis))
assert np.allclose(f(xv), scipy_softmax(xv, axis=axis))
@pytest.mark.parametrize("column", [0, 1, 2, 3])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_grad(self, axis, column):
def f(a):
return softmax(a, axis=axis)[:, column]
rng = np.random.default_rng(utt.fetch_seed())
utt.verify_grad(f, [rng.random((3, 4, 2))])
def test_infer_shape(self):
admat = matrix()
rng = np.random.default_rng(utt.fetch_seed())
admat_val = rng.random((3, 4)).astype(config.floatX)
self._compile_and_check(
[admat], [Softmax(axis=-1)(admat)], [admat_val], Softmax
)
def test_vector_perform(self):
x = vector()
f = function([x], softmax(x, axis=None))
rng = np.random.default_rng(utt.fetch_seed())
xv = rng.standard_normal((6,)).astype(config.floatX)
assert np.allclose(f(xv), scipy_softmax(xv))
def test_vector_grad(self):
def f(a):
return softmax(a, axis=None)
rng = np.random.default_rng(utt.fetch_seed())
utt.verify_grad(f, [rng.random((4))])
def test_valid_axis(self):
with pytest.raises(TypeError):
Softmax(1.5)
x = [tensor3()] * LogSoftmax.nin
Softmax(2)(*x)
Softmax(-3)(*x)
with pytest.raises(ValueError):
Softmax(3)(*x)
with pytest.raises(ValueError):
Softmax(-4)(*x)
class TestLogSoftmax(utt.InferShapeTester):
@pytest.mark.parametrize("column", [0, 1, 2, 3])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_matrix_grad(self, axis, column):
def f(a):
return log_softmax(a, axis=axis)[:, column]
rng = np.random.default_rng(utt.fetch_seed())
utt.verify_grad(f, [rng.random((3, 4))])
def test_vector_perform(self):
x = vector()
f = function([x], log_softmax(x, axis=None))
rng = np.random.default_rng(utt.fetch_seed())
xv = rng.standard_normal((6,)).astype(config.floatX)
assert np.allclose(f(xv), scipy_log_softmax(xv))
def test_vector_grad(self):
def f(a):
return log_softmax(a, axis=None)
rng = np.random.default_rng(utt.fetch_seed())
utt.verify_grad(f, [rng.random((4,))])
def test_valid_axis(self):
with pytest.raises(TypeError):
LogSoftmax(1.5)
x = [tensor3()] * LogSoftmax.nin
LogSoftmax(2)(*x)
LogSoftmax(-3)(*x)
with pytest.raises(ValueError):
LogSoftmax(3)(*x)
with pytest.raises(ValueError):
LogSoftmax(-4)(*x)
class TestSoftmaxGrad(utt.InferShapeTester):
def test_infer_shape(self):
admat = matrix()
bdmat = matrix()
rng = np.random.default_rng(utt.fetch_seed())
admat_val = rng.random((3, 4)).astype(config.floatX)
bdmat_val = rng.random((3, 4)).astype(config.floatX)
self._compile_and_check(
[admat, bdmat],
[SoftmaxGrad(axis=-1)(admat, bdmat)],
[admat_val, bdmat_val],
SoftmaxGrad,
)
def test_valid_axis(self):
with pytest.raises(TypeError):
SoftmaxGrad(1.5)
x = [tensor3()] * SoftmaxGrad.nin
SoftmaxGrad(2)(*x)
SoftmaxGrad(-3)(*x)
with pytest.raises(ValueError):
SoftmaxGrad(3)(*x)
with pytest.raises(ValueError):
SoftmaxGrad(-4)(*x)
...@@ -385,7 +385,7 @@ class TestRopLop(RopLopChecker): ...@@ -385,7 +385,7 @@ class TestRopLop(RopLopChecker):
self.check_mat_rop_lop(self.mx.sum(axis=1), (self.mat_in_shape[0],)) self.check_mat_rop_lop(self.mx.sum(axis=1), (self.mat_in_shape[0],))
def test_softmax(self): def test_softmax(self):
self.check_rop_lop(aesara.tensor.nnet.softmax(self.x), self.in_shape) self.check_rop_lop(aesara.tensor.math.softmax(self.x), self.in_shape)
def test_alloc(self): def test_alloc(self):
# Alloc of the sum of x into a vector # Alloc of the sum of x into a vector
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论