提交 ff1a3a9d authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Brandon T. Willard

Move `Softmax`, `LogSoftmax`, `SoftmaxGrad` to new `aesara.tensor.special`

上级 e2202bc7
...@@ -3,7 +3,7 @@ import jax.numpy as jnp ...@@ -3,7 +3,7 @@ import jax.numpy as jnp
from aesara.link.jax.dispatch.basic import jax_funcify, jnp_safe_copy from aesara.link.jax.dispatch.basic import jax_funcify, jnp_safe_copy
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import LogSoftmax, Softmax, SoftmaxGrad from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@jax_funcify.register(Elemwise) @jax_funcify.register(Elemwise)
......
...@@ -38,13 +38,8 @@ from aesara.scalar.basic import ( ...@@ -38,13 +38,8 @@ from aesara.scalar.basic import (
from aesara.scalar.basic import add as add_as from aesara.scalar.basic import add as add_as
from aesara.scalar.basic import scalar_maximum from aesara.scalar.basic import scalar_maximum
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import ( from aesara.tensor.math import MaxAndArgmax, MulWithoutZeros
LogSoftmax, from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
MaxAndArgmax,
MulWithoutZeros,
Softmax,
SoftmaxGrad,
)
@singledispatch @singledispatch
......
...@@ -113,6 +113,7 @@ import aesara.tensor.rewriting ...@@ -113,6 +113,7 @@ import aesara.tensor.rewriting
# isort: off # isort: off
from aesara.tensor import linalg # noqa from aesara.tensor import linalg # noqa
from aesara.tensor import special
# For backward compatibility # For backward compatibility
from aesara.tensor import nlinalg # noqa from aesara.tensor import nlinalg # noqa
......
差异被折叠。
...@@ -24,10 +24,7 @@ from aesara.tensor.elemwise import DimShuffle, Elemwise ...@@ -24,10 +24,7 @@ from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.extra_ops import Unique from aesara.tensor.extra_ops import Unique
from aesara.tensor.math import ( from aesara.tensor.math import (
LogSoftmax,
MaxAndArgmax, MaxAndArgmax,
Softmax,
SoftmaxGrad,
Sum, Sum,
add, add,
dot, dot,
...@@ -35,13 +32,11 @@ from aesara.tensor.math import ( ...@@ -35,13 +32,11 @@ from aesara.tensor.math import (
exp, exp,
expm1, expm1,
log, log,
log_softmax,
max_and_argmax, max_and_argmax,
mul, mul,
neg, neg,
or_, or_,
sigmoid, sigmoid,
softmax,
softplus, softplus,
) )
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
...@@ -54,15 +49,9 @@ from aesara.tensor.rewriting.basic import ( ...@@ -54,15 +49,9 @@ from aesara.tensor.rewriting.basic import (
) )
from aesara.tensor.rewriting.math import local_mul_canonizer from aesara.tensor.rewriting.math import local_mul_canonizer
from aesara.tensor.shape import Shape, shape_padleft from aesara.tensor.shape import Shape, shape_padleft
from aesara.tensor.special import Softmax, SoftmaxGrad, log_softmax, softmax
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor
from aesara.tensor.type import ( from aesara.tensor.type import TensorType, discrete_dtypes, float_dtypes, integer_dtypes
TensorType,
discrete_dtypes,
float_dtypes,
integer_dtypes,
values_eq_approx_remove_inf,
values_eq_approx_remove_nan,
)
class SoftmaxWithBias(COp): class SoftmaxWithBias(COp):
...@@ -327,71 +316,6 @@ softmax_grad_legacy = SoftmaxGrad(axis=-1) ...@@ -327,71 +316,6 @@ softmax_grad_legacy = SoftmaxGrad(axis=-1)
softmax_legacy = Softmax(axis=-1) softmax_legacy = Softmax(axis=-1)
# This is not registered in stabilize, as it cause some crossentropy
# optimization to not be inserted.
@register_specialize("stabilize", "fast_compile")
@node_rewriter([Elemwise])
def local_logsoftmax(fgraph, node):
    """
    Detect Log(Softmax(x)) and replace it with LogSoftmax(x)

    Note: only forward pass is affected
    """
    # Match an `Elemwise` log whose single input is produced by a `Softmax`.
    if (
        isinstance(node.op, Elemwise)
        and isinstance(node.op.scalar_op, aes.Log)
        and len(node.inputs) == 1
        and node.inputs[0].owner is not None
        and isinstance(node.inputs[0].owner.op, Softmax)
    ):
        # Input of the matched `Softmax`, reused directly by `LogSoftmax`.
        inVars = node.inputs[0].owner.inputs[0]
        # Preserve the axis of the original `Softmax`.
        new_op = LogSoftmax(axis=node.inputs[0].owner.op.axis)
        ret = new_op(inVars)
        # `log(softmax(x))` can produce `-inf` where the rewritten graph does
        # not; ignore `inf` when comparing outputs (e.g. in `DebugMode`).
        ret.tag.values_eq_approx = values_eq_approx_remove_inf
        copy_stack_trace([node.inputs[0], node.outputs[0]], ret)
        return [ret]
# This is not registered in stabilize, as it cause some crossentropy
# optimization to not be inserted.
@register_specialize("stabilize", "fast_compile")
@node_rewriter([SoftmaxGrad])
def local_logsoftmax_grad(fgraph, node):
    """
    Detect Log(Softmax(x))'s grad and replace it with LogSoftmax(x)'s grad

    Note: only grad is affected
    """
    # Match `SoftmaxGrad(dy / sm, sm)` where `sm` is the output of a
    # `Softmax`: this is the gradient graph produced by `log(softmax(x))`.
    if (
        isinstance(node.op, SoftmaxGrad)
        and len(node.inputs) == 2
        and node.inputs[0].owner is not None
        and node.inputs[0].owner.op == true_div
        and len(node.inputs[0].owner.inputs) >= 2
        and node.inputs[0].owner.inputs[1].owner is not None
        and isinstance(node.inputs[0].owner.inputs[1].owner.op, Softmax)
        and node.inputs[1] == node.inputs[0].owner.inputs[1]
        and not (
            # skip if it will be optimized by
            # local_advanced_indexing_crossentropy_onehot_grad
            node.inputs[0].owner.op == true_div
            and node.inputs[0].owner.inputs[0].owner is not None
            and isinstance(
                node.inputs[0].owner.inputs[0].owner.op, AdvancedIncSubtensor
            )
            # the rewrite only applies to legacy SoftmaxGrad
            and node.op == softmax_grad_legacy
            and node.inputs[0].owner.inputs[1].ndim == 2
        )
    ):
        # get parameters from unoptimized op
        # `grads` is the numerator (incoming gradient), `sm` the softmax output.
        grads, sm = node.inputs[0].owner.inputs
        ret = grads - at_sum(grads, axis=sm.owner.op.axis, keepdims=True) * sm
        # The un-rewritten gradient can produce NaNs that this form avoids;
        # ignore NaNs when comparing outputs (e.g. in `DebugMode`).
        ret.tag.values_eq_approx = values_eq_approx_remove_nan
        copy_stack_trace(node.outputs[0], ret)
        return [ret]
@register_specialize("fast_compile") @register_specialize("fast_compile")
@node_rewriter([softmax_legacy]) @node_rewriter([softmax_legacy])
def local_softmax_with_bias(fgraph, node): def local_softmax_with_bias(fgraph, node):
...@@ -2211,12 +2135,12 @@ def confusion_matrix(actual, pred): ...@@ -2211,12 +2135,12 @@ def confusion_matrix(actual, pred):
DEPRECATED_NAMES = [ DEPRECATED_NAMES = [
( (
"softmax", "softmax",
"`aesara.tensor.nnet.basic.softmax` has been moved to `aesara.tensor.math.softmax`.", "`aesara.tensor.nnet.basic.softmax` has been moved to `aesara.tensor.special.softmax`.",
softmax, softmax,
), ),
( (
"logsoftmax", "logsoftmax",
"`aesara.tensor.nnet.basic.logsoftmax` has been moved to `aesara.tensor.math.logsoftmax`.", "`aesara.tensor.nnet.basic.logsoftmax` has been moved to `aesara.tensor.special.log_softmax`.",
log_softmax, log_softmax,
), ),
] ]
......
...@@ -3,5 +3,6 @@ import aesara.tensor.rewriting.elemwise ...@@ -3,5 +3,6 @@ import aesara.tensor.rewriting.elemwise
import aesara.tensor.rewriting.extra_ops import aesara.tensor.rewriting.extra_ops
import aesara.tensor.rewriting.math import aesara.tensor.rewriting.math
import aesara.tensor.rewriting.shape import aesara.tensor.rewriting.shape
import aesara.tensor.rewriting.special
import aesara.tensor.rewriting.subtensor import aesara.tensor.rewriting.subtensor
import aesara.tensor.rewriting.uncanonicalize import aesara.tensor.rewriting.uncanonicalize
from aesara import scalar as aes
from aesara.graph.rewriting.basic import copy_stack_trace, node_rewriter
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.math import Sum, exp, true_div
from aesara.tensor.math import sum as at_sum
from aesara.tensor.rewriting.basic import register_specialize
from aesara.tensor.rewriting.math import local_mul_canonizer
from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad, softmax_grad_legacy
from aesara.tensor.subtensor import AdvancedIncSubtensor
from aesara.tensor.type import values_eq_approx_remove_inf, values_eq_approx_remove_nan
# This is not registered in stabilize, as it cause some crossentropy
# optimization to not be inserted.
@register_specialize("stabilize", "fast_compile")
@node_rewriter([Elemwise])
def local_logsoftmax(fgraph, node):
    """
    Detect Log(Softmax(x)) and replace it with LogSoftmax(x)

    Note: only forward pass is affected
    """
    # Match an `Elemwise` log whose single input is produced by a `Softmax`.
    if (
        isinstance(node.op, Elemwise)
        and isinstance(node.op.scalar_op, aes.Log)
        and len(node.inputs) == 1
        and node.inputs[0].owner is not None
        and isinstance(node.inputs[0].owner.op, Softmax)
    ):
        # Input of the matched `Softmax`, reused directly by `LogSoftmax`.
        inVars = node.inputs[0].owner.inputs[0]
        # Preserve the axis of the original `Softmax`.
        new_op = LogSoftmax(axis=node.inputs[0].owner.op.axis)
        ret = new_op(inVars)
        # `log(softmax(x))` can produce `-inf` where the rewritten graph does
        # not; ignore `inf` when comparing outputs (e.g. in `DebugMode`).
        ret.tag.values_eq_approx = values_eq_approx_remove_inf
        copy_stack_trace([node.inputs[0], node.outputs[0]], ret)
        return [ret]
# This is not registered in stabilize, as it cause some crossentropy
# optimization to not be inserted.
@register_specialize("stabilize", "fast_compile")
@node_rewriter([SoftmaxGrad])
def local_logsoftmax_grad(fgraph, node):
    """
    Detect Log(Softmax(x))'s grad and replace it with LogSoftmax(x)'s grad

    Note: only grad is affected
    """
    # Match `SoftmaxGrad(dy / sm, sm)` where `sm` is the output of a
    # `Softmax`: this is the gradient graph produced by `log(softmax(x))`.
    if (
        isinstance(node.op, SoftmaxGrad)
        and len(node.inputs) == 2
        and node.inputs[0].owner is not None
        and node.inputs[0].owner.op == true_div
        and len(node.inputs[0].owner.inputs) >= 2
        and node.inputs[0].owner.inputs[1].owner is not None
        and isinstance(node.inputs[0].owner.inputs[1].owner.op, Softmax)
        and node.inputs[1] == node.inputs[0].owner.inputs[1]
        and not (
            # skip if it will be optimized by
            # local_advanced_indexing_crossentropy_onehot_grad
            node.inputs[0].owner.op == true_div
            and node.inputs[0].owner.inputs[0].owner is not None
            and isinstance(
                node.inputs[0].owner.inputs[0].owner.op, AdvancedIncSubtensor
            )
            # the rewrite only applies to legacy SoftmaxGrad
            and node.op == softmax_grad_legacy
            and node.inputs[0].owner.inputs[1].ndim == 2
        )
    ):
        # get parameters from unoptimized op
        # `grads` is the numerator (incoming gradient), `sm` the softmax output.
        grads, sm = node.inputs[0].owner.inputs
        ret = grads - at_sum(grads, axis=sm.owner.op.axis, keepdims=True) * sm
        # The un-rewritten gradient can produce NaNs that this form avoids;
        # ignore NaNs when comparing outputs (e.g. in `DebugMode`).
        ret.tag.values_eq_approx = values_eq_approx_remove_nan
        copy_stack_trace(node.outputs[0], ret)
        return [ret]
def softmax_simplifier(numerators, denominators):
    """Replace ``exp(x) / sum(exp(x), axis)`` pairs with ``Softmax(axis)(x)``.

    Registered as a simplifier of `local_mul_canonizer`: `numerators` and
    `denominators` are the factor lists of a canonicalized division.  Each
    matched numerator/denominator pair is removed from its list and a single
    `Softmax` term is appended to `numerators`.  The (possibly updated)
    lists are returned.
    """
    # Iterate over a copy so the list can be mutated while looping.
    for numerator in list(numerators):
        # `Softmax` only applies to floating-point expressions.
        if not numerator.type.dtype.startswith("float"):
            continue

        if not (numerator.owner and numerator.owner.op == exp):
            continue

        matching_denom = None

        for denominator in denominators:
            # Division with dimshuffle
            if denominator.owner and isinstance(denominator.owner.op, DimShuffle):
                ds_order = denominator.owner.op.new_order
                # Check that at most only one dimension is being reintroduced by
                # a dimshuffle. The cases where all dimensions are reintroduced
                # after a complete sum reduction end up in the else branch
                if ds_order.count("x") != 1:
                    continue

                # Check that dimshuffle does not change order of original dims
                ds_order_without_x = tuple(dim for dim in ds_order if dim != "x")
                if tuple(sorted(ds_order_without_x)) != ds_order_without_x:
                    continue

                new_dim = ds_order.index("x")
                z = denominator.owner.inputs[0]
                if z.owner and isinstance(z.owner.op, Sum):
                    sum_axis = z.owner.op.axis
                    # Check that reintroduced dim was the one reduced
                    if (
                        (sum_axis is not None)
                        and (len(sum_axis) == 1)
                        and (sum_axis[0] == new_dim)
                    ):
                        # The denominator must reduce this very numerator.
                        if z.owner.inputs[0] is numerator:
                            (sum_axis,) = sum_axis
                            matching_denom = denominator
                            break

            # Division without dimshuffle
            else:
                z = denominator
                if z.owner and isinstance(z.owner.op, Sum):
                    sum_axis = z.owner.op.axis
                    # Filter out partial summations over more than one axis
                    # The cases where all axis of summation are explicitly given
                    # as in `sum(matrix, axis=(0, 1))` are eventually rewritten
                    # to `sum(matrix)` and this branch is not a blocker
                    if sum_axis is not None and len(sum_axis) != 1:
                        continue

                    # The denominator must reduce this very numerator.
                    if z.owner.inputs[0] is numerator:
                        if sum_axis is not None:
                            (sum_axis,) = sum_axis
                        matching_denom = denominator
                        break

        if matching_denom:
            # Apply `Softmax` directly to the argument of the matched `exp`.
            softmax = Softmax(axis=sum_axis)(numerator.owner.inputs[0])
            copy_stack_trace(numerator, softmax)
            numerators.remove(numerator)
            denominators.remove(matching_denom)
            numerators.append(softmax)

    return numerators, denominators


local_mul_canonizer.add_simplifier(softmax_simplifier, "softmax_simplifier")
差异被折叠。
...@@ -104,7 +104,7 @@ ...@@ -104,7 +104,7 @@
"\n", "\n",
"wy = th.shared(rng.normal(0, 1, (nhiddens, noutputs)))\n", "wy = th.shared(rng.normal(0, 1, (nhiddens, noutputs)))\n",
"by = th.shared(np.zeros(noutputs), borrow=True)\n", "by = th.shared(np.zeros(noutputs), borrow=True)\n",
"y = at.math.softmax(at.dot(h, wy) + by)\n", "y = at.special.softmax(at.dot(h, wy) + by)\n",
"\n", "\n",
"predict = th.function([x], y)" "predict = th.function([x], y)"
] ]
......
...@@ -67,7 +67,7 @@ hidden layer and a softmax output layer. ...@@ -67,7 +67,7 @@ hidden layer and a softmax output layer.
wy = th.shared(rng.normal(0, 1, (nhiddens, noutputs))) wy = th.shared(rng.normal(0, 1, (nhiddens, noutputs)))
by = th.shared(np.zeros(noutputs), borrow=True) by = th.shared(np.zeros(noutputs), borrow=True)
y = at.math.softmax(at.dot(h, wy) + by) y = at.special.softmax(at.dot(h, wy) + by)
predict = th.function([x], y) predict = th.function([x], y)
......
...@@ -3,6 +3,7 @@ import numpy as np ...@@ -3,6 +3,7 @@ import numpy as np
import aesara.tensor as at import aesara.tensor as at
from aesara import shared from aesara import shared
from aesara.compile.builders import OpFromGraph from aesara.compile.builders import OpFromGraph
from aesara.tensor.special import softmax
from aesara.tensor.type import dmatrix, scalars from aesara.tensor.type import dmatrix, scalars
...@@ -24,8 +25,7 @@ class Mlp: ...@@ -24,8 +25,7 @@ class Mlp:
wy = shared(self.rng.normal(0, 1, (nhiddens, noutputs))) wy = shared(self.rng.normal(0, 1, (nhiddens, noutputs)))
by = shared(np.zeros(noutputs), borrow=True) by = shared(np.zeros(noutputs), borrow=True)
y = at.softmax(at.dot(h, wy) + by) y = softmax(at.dot(h, wy) + by)
self.inputs = [x] self.inputs = [x]
self.outputs = [y] self.outputs = [y]
......
...@@ -5,10 +5,10 @@ from aesara.configdefaults import config ...@@ -5,10 +5,10 @@ from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value from aesara.graph.op import get_test_value
from aesara.tensor import elemwise as at_elemwise from aesara.tensor import elemwise as at_elemwise
from aesara.tensor.math import SoftmaxGrad
from aesara.tensor.math import all as at_all from aesara.tensor.math import all as at_all
from aesara.tensor.math import log_softmax, prod, softmax from aesara.tensor.math import prod
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.special import SoftmaxGrad, log_softmax, softmax
from aesara.tensor.type import matrix, tensor, vector from aesara.tensor.type import matrix, tensor, vector
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
......
...@@ -12,19 +12,8 @@ from aesara.compile.sharedvalue import SharedVariable ...@@ -12,19 +12,8 @@ from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.basic import Constant from aesara.graph.basic import Constant
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.tensor import elemwise as at_elemwise from aesara.tensor import elemwise as at_elemwise
from aesara.tensor.math import ( from aesara.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum
All, from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
Any,
LogSoftmax,
Max,
Mean,
Min,
Prod,
ProdWithoutZeros,
Softmax,
SoftmaxGrad,
Sum,
)
from tests.link.numba.test_basic import ( from tests.link.numba.test_basic import (
compare_numba_and_py, compare_numba_and_py,
my_multi_out, my_multi_out,
......
...@@ -13,7 +13,7 @@ from aesara import pprint, shared ...@@ -13,7 +13,7 @@ from aesara import pprint, shared
from aesara.compile import optdb from aesara.compile import optdb
from aesara.compile.debugmode import DebugMode from aesara.compile.debugmode import DebugMode
from aesara.compile.function import function from aesara.compile.function import function
from aesara.compile.mode import OPT_FAST_RUN, Mode, get_default_mode, get_mode from aesara.compile.mode import Mode, get_default_mode, get_mode
from aesara.compile.ops import DeepCopyOp, deep_copy_op from aesara.compile.ops import DeepCopyOp, deep_copy_op
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, equal_computations from aesara.graph.basic import Apply, Constant, equal_computations
...@@ -33,15 +33,7 @@ from aesara.tensor.basic import Alloc, join, switch ...@@ -33,15 +33,7 @@ from aesara.tensor.basic import Alloc, join, switch
from aesara.tensor.blas import Dot22, Gemv from aesara.tensor.blas import Dot22, Gemv
from aesara.tensor.blas_c import CGemv from aesara.tensor.blas_c import CGemv
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import ( from aesara.tensor.math import Dot, MaxAndArgmax, Prod, Sum, _conj
Dot,
LogSoftmax,
MaxAndArgmax,
Prod,
SoftmaxGrad,
Sum,
_conj,
)
from aesara.tensor.math import abs as at_abs from aesara.tensor.math import abs as at_abs
from aesara.tensor.math import add from aesara.tensor.math import add
from aesara.tensor.math import all as at_all from aesara.tensor.math import all as at_all
...@@ -84,17 +76,7 @@ from aesara.tensor.math import minimum, mul, neg, neq ...@@ -84,17 +76,7 @@ from aesara.tensor.math import minimum, mul, neg, neq
from aesara.tensor.math import pow as at_pow from aesara.tensor.math import pow as at_pow
from aesara.tensor.math import prod, rad2deg, reciprocal from aesara.tensor.math import prod, rad2deg, reciprocal
from aesara.tensor.math import round as at_round from aesara.tensor.math import round as at_round
from aesara.tensor.math import ( from aesara.tensor.math import sgn, sigmoid, sin, sinh, softplus, sqr, sqrt, sub
sgn,
sigmoid,
sin,
sinh,
softmax,
softplus,
sqr,
sqrt,
sub,
)
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tan, tanh, true_div, xor from aesara.tensor.math import tan, tanh, true_div, xor
from aesara.tensor.rewriting.elemwise import local_dimshuffle_lift from aesara.tensor.rewriting.elemwise import local_dimshuffle_lift
...@@ -4596,97 +4578,6 @@ class TestSigmoidUtils: ...@@ -4596,97 +4578,6 @@ class TestSigmoidUtils:
assert is_1pexp(1 + 2 * exp_op(x), False) is None assert is_1pexp(1 + 2 * exp_op(x), False) is None
class TestLogSoftmaxRewrites:
    """Tests for the ``log(softmax(x)) -> log_softmax(x)`` rewrites."""

    @pytest.mark.parametrize("axis", [None, 0, -1])
    def test_local_logsoftmax_rewrite(self, axis):
        """Test the `Logsoftmax` substitution.

        Check that ``Log(Softmax(x))`` is substituted with ``Logsoftmax(x)``. Note that
        only the forward pass is checked (i.e., doesn't check the gradient)
        """

        x = matrix("x")
        sm = softmax(x, axis=axis)
        logsm = log(sm)
        f = function([x], logsm)
        # The compiled graph should end in a single `LogSoftmax` `Op`.
        assert isinstance(f.maker.fgraph.outputs[0].owner.op, LogSoftmax)
        assert check_stack_trace(f, ops_to_check=LogSoftmax)

    @pytest.mark.parametrize("axis", [None, 0, -1])
    def test_local_logsoftmax_grad_rewrite(self, axis):
        """Test the `Logsoftmax`'s grad substitution.

        Check that ``Log(Softmax(x))``'s grad is substituted with ``Logsoftmax(x)``'s
        grad and that the new operation does not explode for big inputs.

        Note that only the grad is checked.
        """

        m = config.mode
        m = get_mode(m)
        # The un-rewritten gradient overflows for the large inputs used
        # below, so do not fail on non-finite intermediate values.
        m.check_isfinite = False
        # some inputs that are large to make the gradient explode in the non
        # rewritten case
        rng = np.random.default_rng(utt.fetch_seed())
        a = np.exp(10 * rng.random((5, 10)).astype(config.floatX))

        def myfunc(x):
            sm = softmax(x, axis=axis)
            logsm = log(sm)
            return logsm

        # We set step to 0.1 because for big values we need a big epsilon
        utt.verify_grad(myfunc, [a], eps=0.1, mode=m)
        sa = shared(a)
        f = function([], myfunc(sa))
        assert check_stack_trace(f, ops_to_check="all")

    def test_logsoftmax_grad_true_div_elemwise(self):
        """
        Checks that the gradient of an expression similar to a ``log(softmax)`` but
        with a different elemwise operation than true_div is not rewritten.
        """

        x = matrix("x")
        y = log(softmax(x))
        g = aesara.tensor.grad(y.sum(), x)

        softmax_grad_node = g.owner
        assert softmax_grad_node.op == SoftmaxGrad(axis=-1)
        true_div_node = softmax_grad_node.inputs[0].owner
        assert true_div_node.op == true_div

        # We replace the elemwise true_div op by an elemwise add.
        new_g = SoftmaxGrad(axis=-1)(
            add(*true_div_node.inputs), softmax_grad_node.inputs[1]
        )

        fgraph = FunctionGraph([x], [new_g])
        optdb.query(OPT_FAST_RUN).rewrite(fgraph)

        # The `SoftmaxGrad` must survive: the pattern did not match.
        assert SoftmaxGrad(axis=-1) in [n.op for n in fgraph.toposort()]
def test_log1mexp_stabilization():
    """Check that ``log(1 - exp(x))`` is rewritten to the stable `log1mexp`."""
    mode = Mode("py").including("stabilize")

    x = vector()
    f = function([x], log(1 - exp(x)), mode=mode)
    nodes = [node.op for node in f.maker.fgraph.toposort()]
    assert nodes == [at.log1mexp]

    # Check values that would under or overflow without rewriting
    assert f([-(2.0**-55)]) != -np.inf
    overflow_value = -500.0 if config.floatX == "float64" else -100.0
    assert f([overflow_value]) < 0

    # Check values around the switch point np.log(0.5)
    assert np.allclose(
        f(np.array([-0.8, -0.6], dtype=config.floatX)),
        np.log(1 - np.exp([-0.8, -0.6])),
    )
def test_local_logit_sigmoid(): def test_local_logit_sigmoid():
"""Test that graphs of the form ``logit(sigmoid(x))`` and ``sigmoid(logit(x))`` get rewritten to ``x``.""" """Test that graphs of the form ``logit(sigmoid(x))`` and ``sigmoid(logit(x))`` get rewritten to ``x``."""
...@@ -4727,24 +4618,3 @@ def test_deprecations(): ...@@ -4727,24 +4618,3 @@ def test_deprecations():
"""Make sure we can import from deprecated modules.""" """Make sure we can import from deprecated modules."""
with pytest.deprecated_call(): with pytest.deprecated_call():
from aesara.tensor.math_opt import AlgebraicCanonizer # noqa: F401 F811 from aesara.tensor.math_opt import AlgebraicCanonizer # noqa: F401 F811
def test_log_softmax_stabilization():
    """Check that ``log(softmax(x))`` is rewritten into a stable `LogSoftmax`."""
    mode = aesara.compile.mode.get_default_mode()
    mode = mode.including("local_log_softmax", "specialize")

    x = matrix()
    y = softmax(x)
    z = log(y)

    f = aesara.function([x], z, mode=mode)
    assert check_stack_trace(f, ops_to_check="all")

    # Check that the softmax has been rewritten
    for node in f.maker.fgraph.toposort():
        assert not isinstance(node.op, y.owner.op.__class__)

    # Call the function so debug mode can verify the rewritten version matches
    # the un-rewritten version.
    # `np.cast` was removed in NumPy 2.0; use `ndarray.astype` instead.
    rng = np.random.default_rng(utt.fetch_seed())
    f(rng.random((2, 3)).astype(config.floatX))
import numpy as np
import pytest
import aesara
import aesara.tensor as at
from aesara import shared
from aesara.compile import optdb
from aesara.compile.function import function
from aesara.compile.mode import OPT_FAST_RUN, Mode, get_mode
from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.graph.rewriting.basic import check_stack_trace
from aesara.graph.rewriting.db import RewriteDatabaseQuery
from aesara.tensor.math import add, exp, log, true_div
from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad, softmax
from aesara.tensor.type import matrix
from tests import unittest_tools as utt
class TestLogSoftmaxRewrites:
    """Tests for the ``log(softmax(x)) -> log_softmax(x)`` rewrites."""

    @pytest.mark.parametrize("axis", [None, 0, -1])
    def test_local_logsoftmax_rewrite(self, axis):
        """Test the `Logsoftmax` substitution.

        Check that ``Log(Softmax(x))`` is substituted with ``Logsoftmax(x)``. Note that
        only the forward pass is checked (i.e., doesn't check the gradient)
        """

        x = matrix("x")
        sm = softmax(x, axis=axis)
        logsm = log(sm)
        f = function([x], logsm)
        # The compiled graph should end in a single `LogSoftmax` `Op`.
        assert isinstance(f.maker.fgraph.outputs[0].owner.op, LogSoftmax)
        assert check_stack_trace(f, ops_to_check=LogSoftmax)

    @pytest.mark.parametrize("axis", [None, 0, -1])
    def test_local_logsoftmax_grad_rewrite(self, axis):
        """Test the `Logsoftmax`'s grad substitution.

        Check that ``Log(Softmax(x))``'s grad is substituted with ``Logsoftmax(x)``'s
        grad and that the new operation does not explode for big inputs.

        Note that only the grad is checked.
        """

        m = config.mode
        m = get_mode(m)
        # The un-rewritten gradient overflows for the large inputs used
        # below, so do not fail on non-finite intermediate values.
        m.check_isfinite = False
        # some inputs that are large to make the gradient explode in the non
        # rewritten case
        rng = np.random.default_rng(utt.fetch_seed())
        a = np.exp(10 * rng.random((5, 10)).astype(config.floatX))

        def myfunc(x):
            sm = softmax(x, axis=axis)
            logsm = log(sm)
            return logsm

        # We set step to 0.1 because for big values we need a big epsilon
        utt.verify_grad(myfunc, [a], eps=0.1, mode=m)
        sa = shared(a)
        f = function([], myfunc(sa))
        assert check_stack_trace(f, ops_to_check="all")

    def test_logsoftmax_grad_true_div_elemwise(self):
        """
        Checks that the gradient of an expression similar to a ``log(softmax)`` but
        with a different elemwise operation than true_div is not rewritten.
        """

        x = matrix("x")
        y = log(softmax(x))
        g = aesara.tensor.grad(y.sum(), x)

        softmax_grad_node = g.owner
        assert softmax_grad_node.op == SoftmaxGrad(axis=-1)
        true_div_node = softmax_grad_node.inputs[0].owner
        assert true_div_node.op == true_div

        # We replace the elemwise true_div op by an elemwise add.
        new_g = SoftmaxGrad(axis=-1)(
            add(*true_div_node.inputs), softmax_grad_node.inputs[1]
        )

        fgraph = FunctionGraph([x], [new_g])
        optdb.query(OPT_FAST_RUN).rewrite(fgraph)

        # The `SoftmaxGrad` must survive: the pattern did not match.
        assert SoftmaxGrad(axis=-1) in [n.op for n in fgraph.toposort()]
def test_log1mexp_stabilization():
    """Check that ``log(1 - exp(x))`` is rewritten to the stable `log1mexp`."""
    # NOTE(review): this file's visible module-level imports only bring in
    # `matrix` from `aesara.tensor.type`; import `vector` locally so the
    # test does not raise a `NameError`.
    from aesara.tensor.type import vector

    mode = Mode("py").including("stabilize")

    x = vector()
    f = function([x], log(1 - exp(x)), mode=mode)
    nodes = [node.op for node in f.maker.fgraph.toposort()]
    assert nodes == [at.log1mexp]

    # Check values that would under or overflow without rewriting
    assert f([-(2.0**-55)]) != -np.inf
    overflow_value = -500.0 if config.floatX == "float64" else -100.0
    assert f([overflow_value]) < 0

    # Check values around the switch point np.log(0.5)
    assert np.allclose(
        f(np.array([-0.8, -0.6], dtype=config.floatX)),
        np.log(1 - np.exp([-0.8, -0.6])),
    )
def test_log_softmax_stabilization():
    """Check that ``log(softmax(x))`` is rewritten into a stable `LogSoftmax`."""
    mode = aesara.compile.mode.get_default_mode()
    mode = mode.including("local_log_softmax", "specialize")

    x = matrix()
    y = softmax(x)
    z = log(y)

    f = aesara.function([x], z, mode=mode)
    assert check_stack_trace(f, ops_to_check="all")

    # Check that the softmax has been rewritten
    for node in f.maker.fgraph.toposort():
        assert not isinstance(node.op, y.owner.op.__class__)

    # Call the function so debug mode can verify the rewritten version matches
    # the un-rewritten version.
    # `np.cast` was removed in NumPy 2.0; use `ndarray.astype` instead.
    rng = np.random.default_rng(utt.fetch_seed())
    f(rng.random((2, 3)).astype(config.floatX))
def test_softmax_graph():
    """Make sure that softmax expressions are turned into
    a softmax Op.

    """
    rng = np.random.default_rng(utt.fetch_seed())
    x = aesara.shared(rng.normal(size=(3, 4)))

    def softmax_graph(c):
        # Hand-written softmax: exp(c) normalized over the last axis.  The
        # `softmax_simplifier` should collapse this into a single `Softmax`.
        return exp(c) / exp(c).sum(axis=-1, keepdims=True)

    def f(inputs):
        y = softmax_graph(x)
        # Back-propagate the provided `known_grads` through the graph.
        return aesara.grad(None, x, known_grads={y: inputs})

    utt.verify_grad(f, [rng.random((3, 4))])
...@@ -8,9 +8,7 @@ from itertools import product ...@@ -8,9 +8,7 @@ from itertools import product
import numpy as np import numpy as np
import pytest import pytest
from numpy.testing import assert_array_equal from numpy.testing import assert_array_equal
from scipy.special import log_softmax as scipy_log_softmax
from scipy.special import logsumexp as scipy_logsumexp from scipy.special import logsumexp as scipy_logsumexp
from scipy.special import softmax as scipy_softmax
import aesara.scalar as aes import aesara.scalar as aes
from aesara.compile.debugmode import DebugMode from aesara.compile.debugmode import DebugMode
...@@ -36,14 +34,11 @@ from aesara.tensor.elemwise import CAReduce, Elemwise ...@@ -36,14 +34,11 @@ from aesara.tensor.elemwise import CAReduce, Elemwise
from aesara.tensor.math import ( from aesara.tensor.math import (
Argmax, Argmax,
Dot, Dot,
LogSoftmax,
MatMul, MatMul,
MaxAndArgmax, MaxAndArgmax,
Mean, Mean,
Prod, Prod,
ProdWithoutZeros, ProdWithoutZeros,
Softmax,
SoftmaxGrad,
Sum, Sum,
_allclose, _allclose,
_dot, _dot,
...@@ -84,7 +79,6 @@ from aesara.tensor.math import ( ...@@ -84,7 +79,6 @@ from aesara.tensor.math import (
log1p, log1p,
log2, log2,
log10, log10,
log_softmax,
logaddexp, logaddexp,
logsumexp, logsumexp,
matmul, matmul,
...@@ -110,7 +104,6 @@ from aesara.tensor.math import ( ...@@ -110,7 +104,6 @@ from aesara.tensor.math import (
sin, sin,
sinh, sinh,
smallest, smallest,
softmax,
sqr, sqr,
sqrt, sqrt,
sub, sub,
...@@ -3528,129 +3521,3 @@ class TestMatMul(utt.InferShapeTester): ...@@ -3528,129 +3521,3 @@ class TestMatMul(utt.InferShapeTester):
[x1, x2], [x1, x2],
self.op_class, self.op_class,
) )
class TestSoftmax(utt.InferShapeTester):
    """Tests for the `Softmax` `Op` and the `softmax` helper."""

    @pytest.mark.parametrize("axis", [None, 0, 1, 2, 3, -1, -2])
    def test_perform(self, axis):
        # Compare against SciPy's reference implementation on a 4D input.
        x = tensor4("x")
        rng = np.random.default_rng(utt.fetch_seed())
        xv = rng.standard_normal((2, 3, 4, 5)).astype(config.floatX)

        f = function([x], softmax(x, axis=axis))
        assert np.allclose(f(xv), scipy_softmax(xv, axis=axis))

    @pytest.mark.parametrize("column", [0, 1, 2, 3])
    @pytest.mark.parametrize("axis", [None, 0, 1])
    def test_grad(self, axis, column):
        def f(a):
            return softmax(a, axis=axis)[:, column]

        rng = np.random.default_rng(utt.fetch_seed())
        utt.verify_grad(f, [rng.random((3, 4, 2))])

    def test_infer_shape(self):
        admat = matrix()
        rng = np.random.default_rng(utt.fetch_seed())
        admat_val = rng.random((3, 4)).astype(config.floatX)
        self._compile_and_check(
            [admat], [Softmax(axis=-1)(admat)], [admat_val], Softmax
        )

    def test_vector_perform(self):
        x = vector()
        f = function([x], softmax(x, axis=None))

        rng = np.random.default_rng(utt.fetch_seed())
        xv = rng.standard_normal((6,)).astype(config.floatX)
        assert np.allclose(f(xv), scipy_softmax(xv))

    def test_vector_grad(self):
        def f(a):
            return softmax(a, axis=None)

        rng = np.random.default_rng(utt.fetch_seed())
        utt.verify_grad(f, [rng.random((4))])

    def test_valid_axis(self):
        # A non-integer axis is rejected at construction time.
        with pytest.raises(TypeError):
            Softmax(1.5)

        # Was `LogSoftmax.nin` (a copy-paste from `TestLogSoftmax`); both
        # values are identical, but use `Softmax.nin` for consistency with
        # the `Op` under test.
        x = [tensor3()] * Softmax.nin
        Softmax(2)(*x)
        Softmax(-3)(*x)

        # Axes outside the input's rank are rejected when the `Op` is applied.
        with pytest.raises(ValueError):
            Softmax(3)(*x)

        with pytest.raises(ValueError):
            Softmax(-4)(*x)
class TestLogSoftmax(utt.InferShapeTester):
    """Tests for the `LogSoftmax` `Op` and the `log_softmax` helper."""

    @pytest.mark.parametrize("column", [0, 1, 2, 3])
    @pytest.mark.parametrize("axis", [None, 0, 1])
    def test_matrix_grad(self, axis, column):
        def f(a):
            return log_softmax(a, axis=axis)[:, column]

        rng = np.random.default_rng(utt.fetch_seed())
        utt.verify_grad(f, [rng.random((3, 4))])

    def test_vector_perform(self):
        # Compare against SciPy's reference implementation on a vector.
        x = vector()
        f = function([x], log_softmax(x, axis=None))

        rng = np.random.default_rng(utt.fetch_seed())
        xv = rng.standard_normal((6,)).astype(config.floatX)
        assert np.allclose(f(xv), scipy_log_softmax(xv))

    def test_vector_grad(self):
        def f(a):
            return log_softmax(a, axis=None)

        rng = np.random.default_rng(utt.fetch_seed())
        utt.verify_grad(f, [rng.random((4,))])

    def test_valid_axis(self):
        # A non-integer axis is rejected at construction time.
        with pytest.raises(TypeError):
            LogSoftmax(1.5)

        x = [tensor3()] * LogSoftmax.nin
        LogSoftmax(2)(*x)
        LogSoftmax(-3)(*x)

        # Axes outside the input's rank are rejected when the `Op` is applied.
        with pytest.raises(ValueError):
            LogSoftmax(3)(*x)

        with pytest.raises(ValueError):
            LogSoftmax(-4)(*x)
class TestSoftmaxGrad(utt.InferShapeTester):
    """Tests for the `SoftmaxGrad` `Op`."""

    def test_infer_shape(self):
        admat = matrix()
        bdmat = matrix()
        rng = np.random.default_rng(utt.fetch_seed())
        admat_val = rng.random((3, 4)).astype(config.floatX)
        bdmat_val = rng.random((3, 4)).astype(config.floatX)
        self._compile_and_check(
            [admat, bdmat],
            [SoftmaxGrad(axis=-1)(admat, bdmat)],
            [admat_val, bdmat_val],
            SoftmaxGrad,
        )

    def test_valid_axis(self):
        # A non-integer axis is rejected at construction time.
        with pytest.raises(TypeError):
            SoftmaxGrad(1.5)

        x = [tensor3()] * SoftmaxGrad.nin
        SoftmaxGrad(2)(*x)
        SoftmaxGrad(-3)(*x)

        # Axes outside the input's rank are rejected when the `Op` is applied.
        with pytest.raises(ValueError):
            SoftmaxGrad(3)(*x)

        with pytest.raises(ValueError):
            SoftmaxGrad(-4)(*x)
import numpy as np
import pytest
from scipy.special import log_softmax as scipy_log_softmax
from scipy.special import softmax as scipy_softmax
from aesara.compile.function import function
from aesara.configdefaults import config
from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad, log_softmax, softmax
from aesara.tensor.type import matrix, tensor3, tensor4, vector
from tests import unittest_tools as utt
class TestSoftmax(utt.InferShapeTester):
    """Tests for the `Softmax` `Op` and the `softmax` helper."""

    @pytest.mark.parametrize("axis", [None, 0, 1, 2, 3, -1, -2])
    def test_perform(self, axis):
        # Compare against SciPy's reference implementation on a 4D input.
        x = tensor4("x")
        rng = np.random.default_rng(utt.fetch_seed())
        xv = rng.standard_normal((2, 3, 4, 5)).astype(config.floatX)

        f = function([x], softmax(x, axis=axis))
        assert np.allclose(f(xv), scipy_softmax(xv, axis=axis))

    @pytest.mark.parametrize("column", [0, 1, 2, 3])
    @pytest.mark.parametrize("axis", [None, 0, 1])
    def test_grad(self, axis, column):
        def f(a):
            return softmax(a, axis=axis)[:, column]

        rng = np.random.default_rng(utt.fetch_seed())
        utt.verify_grad(f, [rng.random((3, 4, 2))])

    def test_infer_shape(self):
        admat = matrix()
        rng = np.random.default_rng(utt.fetch_seed())
        admat_val = rng.random((3, 4)).astype(config.floatX)
        self._compile_and_check(
            [admat], [Softmax(axis=-1)(admat)], [admat_val], Softmax
        )

    def test_vector_perform(self):
        x = vector()
        f = function([x], softmax(x, axis=None))

        rng = np.random.default_rng(utt.fetch_seed())
        xv = rng.standard_normal((6,)).astype(config.floatX)
        assert np.allclose(f(xv), scipy_softmax(xv))

    def test_vector_grad(self):
        def f(a):
            return softmax(a, axis=None)

        rng = np.random.default_rng(utt.fetch_seed())
        utt.verify_grad(f, [rng.random((4))])

    def test_valid_axis(self):
        # A non-integer axis is rejected at construction time.
        with pytest.raises(TypeError):
            Softmax(1.5)

        # Was `LogSoftmax.nin` (a copy-paste from `TestLogSoftmax`); both
        # values are identical, but use `Softmax.nin` for consistency with
        # the `Op` under test.
        x = [tensor3()] * Softmax.nin
        Softmax(2)(*x)
        Softmax(-3)(*x)

        # Axes outside the input's rank are rejected when the `Op` is applied.
        with pytest.raises(ValueError):
            Softmax(3)(*x)

        with pytest.raises(ValueError):
            Softmax(-4)(*x)
class TestLogSoftmax(utt.InferShapeTester):
    """Tests for the `LogSoftmax` `Op` and the `log_softmax` helper."""

    @pytest.mark.parametrize("column", [0, 1, 2, 3])
    @pytest.mark.parametrize("axis", [None, 0, 1])
    def test_matrix_grad(self, axis, column):
        def f(a):
            return log_softmax(a, axis=axis)[:, column]

        rng = np.random.default_rng(utt.fetch_seed())
        utt.verify_grad(f, [rng.random((3, 4))])

    def test_vector_perform(self):
        # Compare against SciPy's reference implementation on a vector.
        x = vector()
        f = function([x], log_softmax(x, axis=None))

        rng = np.random.default_rng(utt.fetch_seed())
        xv = rng.standard_normal((6,)).astype(config.floatX)
        assert np.allclose(f(xv), scipy_log_softmax(xv))

    def test_vector_grad(self):
        def f(a):
            return log_softmax(a, axis=None)

        rng = np.random.default_rng(utt.fetch_seed())
        utt.verify_grad(f, [rng.random((4,))])

    def test_valid_axis(self):
        # A non-integer axis is rejected at construction time.
        with pytest.raises(TypeError):
            LogSoftmax(1.5)

        x = [tensor3()] * LogSoftmax.nin
        LogSoftmax(2)(*x)
        LogSoftmax(-3)(*x)

        # Axes outside the input's rank are rejected when the `Op` is applied.
        with pytest.raises(ValueError):
            LogSoftmax(3)(*x)

        with pytest.raises(ValueError):
            LogSoftmax(-4)(*x)
class TestSoftmaxGrad(utt.InferShapeTester):
    """Tests for the `SoftmaxGrad` `Op`."""

    def test_infer_shape(self):
        admat = matrix()
        bdmat = matrix()
        rng = np.random.default_rng(utt.fetch_seed())
        admat_val = rng.random((3, 4)).astype(config.floatX)
        bdmat_val = rng.random((3, 4)).astype(config.floatX)
        self._compile_and_check(
            [admat, bdmat],
            [SoftmaxGrad(axis=-1)(admat, bdmat)],
            [admat_val, bdmat_val],
            SoftmaxGrad,
        )

    def test_valid_axis(self):
        # A non-integer axis is rejected at construction time.
        with pytest.raises(TypeError):
            SoftmaxGrad(1.5)

        x = [tensor3()] * SoftmaxGrad.nin
        SoftmaxGrad(2)(*x)
        SoftmaxGrad(-3)(*x)

        # Axes outside the input's rank are rejected when the `Op` is applied.
        with pytest.raises(ValueError):
            SoftmaxGrad(3)(*x)

        with pytest.raises(ValueError):
            SoftmaxGrad(-4)(*x)
...@@ -328,7 +328,7 @@ class TestRopLop(RopLopChecker): ...@@ -328,7 +328,7 @@ class TestRopLop(RopLopChecker):
self.check_mat_rop_lop(self.mx.sum(axis=1), (self.mat_in_shape[0],)) self.check_mat_rop_lop(self.mx.sum(axis=1), (self.mat_in_shape[0],))
def test_softmax(self): def test_softmax(self):
self.check_rop_lop(aesara.tensor.math.softmax(self.x), self.in_shape) self.check_rop_lop(aesara.tensor.special.softmax(self.x), self.in_shape)
def test_alloc(self): def test_alloc(self):
# Alloc of the sum of x into a vector # Alloc of the sum of x into a vector
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论