提交 ff1a3a9d authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Brandon T. Willard

Move `Softmax`, `LogSoftmax`, `SoftmaxGrad` to new `aesara.tensor.special`

上级 e2202bc7
......@@ -3,7 +3,7 @@ import jax.numpy as jnp
from aesara.link.jax.dispatch.basic import jax_funcify, jnp_safe_copy
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import LogSoftmax, Softmax, SoftmaxGrad
from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@jax_funcify.register(Elemwise)
......
......@@ -38,13 +38,8 @@ from aesara.scalar.basic import (
from aesara.scalar.basic import add as add_as
from aesara.scalar.basic import scalar_maximum
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import (
LogSoftmax,
MaxAndArgmax,
MulWithoutZeros,
Softmax,
SoftmaxGrad,
)
from aesara.tensor.math import MaxAndArgmax, MulWithoutZeros
from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@singledispatch
......
......@@ -113,6 +113,7 @@ import aesara.tensor.rewriting
# isort: off
from aesara.tensor import linalg # noqa
from aesara.tensor import special
# For backward compatibility
from aesara.tensor import nlinalg # noqa
......
差异被折叠。
......@@ -24,10 +24,7 @@ from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.extra_ops import Unique
from aesara.tensor.math import (
LogSoftmax,
MaxAndArgmax,
Softmax,
SoftmaxGrad,
Sum,
add,
dot,
......@@ -35,13 +32,11 @@ from aesara.tensor.math import (
exp,
expm1,
log,
log_softmax,
max_and_argmax,
mul,
neg,
or_,
sigmoid,
softmax,
softplus,
)
from aesara.tensor.math import sum as at_sum
......@@ -54,15 +49,9 @@ from aesara.tensor.rewriting.basic import (
)
from aesara.tensor.rewriting.math import local_mul_canonizer
from aesara.tensor.shape import Shape, shape_padleft
from aesara.tensor.special import Softmax, SoftmaxGrad, log_softmax, softmax
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor
from aesara.tensor.type import (
TensorType,
discrete_dtypes,
float_dtypes,
integer_dtypes,
values_eq_approx_remove_inf,
values_eq_approx_remove_nan,
)
from aesara.tensor.type import TensorType, discrete_dtypes, float_dtypes, integer_dtypes
class SoftmaxWithBias(COp):
......@@ -327,71 +316,6 @@ softmax_grad_legacy = SoftmaxGrad(axis=-1)
softmax_legacy = Softmax(axis=-1)
# This is not registered in stabilize, as it cause some crossentropy
# optimization to not be inserted.
@register_specialize("stabilize", "fast_compile")
@node_rewriter([Elemwise])
def local_logsoftmax(fgraph, node):
"""
Detect Log(Softmax(x)) and replace it with LogSoftmax(x)
Note: only forward pass is affected
"""
if (
isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, aes.Log)
and len(node.inputs) == 1
and node.inputs[0].owner is not None
and isinstance(node.inputs[0].owner.op, Softmax)
):
inVars = node.inputs[0].owner.inputs[0]
new_op = LogSoftmax(axis=node.inputs[0].owner.op.axis)
ret = new_op(inVars)
ret.tag.values_eq_approx = values_eq_approx_remove_inf
copy_stack_trace([node.inputs[0], node.outputs[0]], ret)
return [ret]
# This is not registered in stabilize, as it cause some crossentropy
# optimization to not be inserted.
@register_specialize("stabilize", "fast_compile")
@node_rewriter([SoftmaxGrad])
def local_logsoftmax_grad(fgraph, node):
"""
Detect Log(Softmax(x))'s grad and replace it with LogSoftmax(x)'s grad
Note: only grad is affected
"""
if (
isinstance(node.op, SoftmaxGrad)
and len(node.inputs) == 2
and node.inputs[0].owner is not None
and node.inputs[0].owner.op == true_div
and len(node.inputs[0].owner.inputs) >= 2
and node.inputs[0].owner.inputs[1].owner is not None
and isinstance(node.inputs[0].owner.inputs[1].owner.op, Softmax)
and node.inputs[1] == node.inputs[0].owner.inputs[1]
and not (
# skip if it will be optimized by
# local_advanced_indexing_crossentropy_onehot_grad
node.inputs[0].owner.op == true_div
and node.inputs[0].owner.inputs[0].owner is not None
and isinstance(
node.inputs[0].owner.inputs[0].owner.op, AdvancedIncSubtensor
)
# the rewrite only applies to legacy SoftmaxGrad
and node.op == softmax_grad_legacy
and node.inputs[0].owner.inputs[1].ndim == 2
)
):
# get parameters from unoptimized op
grads, sm = node.inputs[0].owner.inputs
ret = grads - at_sum(grads, axis=sm.owner.op.axis, keepdims=True) * sm
ret.tag.values_eq_approx = values_eq_approx_remove_nan
copy_stack_trace(node.outputs[0], ret)
return [ret]
@register_specialize("fast_compile")
@node_rewriter([softmax_legacy])
def local_softmax_with_bias(fgraph, node):
......@@ -2211,12 +2135,12 @@ def confusion_matrix(actual, pred):
DEPRECATED_NAMES = [
(
"softmax",
"`aesara.tensor.nnet.basic.softmax` has been moved to `aesara.tensor.math.softmax`.",
"`aesara.tensor.nnet.basic.softmax` has been moved to `aesara.tensor.special.softmax`.",
softmax,
),
(
"logsoftmax",
"`aesara.tensor.nnet.basic.logsoftmax` has been moved to `aesara.tensor.math.logsoftmax`.",
"`aesara.tensor.nnet.basic.logsoftmax` has been moved to `aesara.tensor.special.log_softmax`.",
log_softmax,
),
]
......
......@@ -3,5 +3,6 @@ import aesara.tensor.rewriting.elemwise
import aesara.tensor.rewriting.extra_ops
import aesara.tensor.rewriting.math
import aesara.tensor.rewriting.shape
import aesara.tensor.rewriting.special
import aesara.tensor.rewriting.subtensor
import aesara.tensor.rewriting.uncanonicalize
from aesara import scalar as aes
from aesara.graph.rewriting.basic import copy_stack_trace, node_rewriter
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.math import Sum, exp, true_div
from aesara.tensor.math import sum as at_sum
from aesara.tensor.rewriting.basic import register_specialize
from aesara.tensor.rewriting.math import local_mul_canonizer
from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from aesara.tensor.subtensor import AdvancedIncSubtensor
from aesara.tensor.type import (
    values_eq_approx_remove_inf,
    values_eq_approx_remove_nan,
)
# This is not registered in stabilize, as it cause some crossentropy
# optimization to not be inserted.
@register_specialize("stabilize", "fast_compile")
@node_rewriter([Elemwise])
def local_logsoftmax(fgraph, node):
    """Rewrite ``log(Softmax(x))`` as ``LogSoftmax(x)``.

    Only the forward pass is rewritten; the gradient case is handled by
    `local_logsoftmax_grad`.
    """
    # Eligible pattern: a unary elementwise `log` whose sole input is the
    # output of a `Softmax` node.
    if not isinstance(node.op, Elemwise):
        return None
    if not isinstance(node.op.scalar_op, aes.Log):
        return None
    if len(node.inputs) != 1:
        return None
    softmax_apply = node.inputs[0].owner
    if softmax_apply is None or not isinstance(softmax_apply.op, Softmax):
        return None

    x = softmax_apply.inputs[0]
    replacement = LogSoftmax(axis=softmax_apply.op.axis)(x)
    # `log(softmax(...))` can produce `-inf` where the fused op stays finite;
    # relax the value comparison accordingly (e.g. for DebugMode).
    replacement.tag.values_eq_approx = values_eq_approx_remove_inf
    copy_stack_trace([node.inputs[0], node.outputs[0]], replacement)
    return [replacement]
# This is not registered in stabilize, as it cause some crossentropy
# optimization to not be inserted.
@register_specialize("stabilize", "fast_compile")
@node_rewriter([SoftmaxGrad])
def local_logsoftmax_grad(fgraph, node):
    """Detect ``Log(Softmax(x))``'s grad and replace it with ``LogSoftmax(x)``'s grad.

    Note: only the gradient is affected; the forward pass is handled by
    `local_logsoftmax`.
    """
    if (
        isinstance(node.op, SoftmaxGrad)
        and len(node.inputs) == 2
        and node.inputs[0].owner is not None
        and node.inputs[0].owner.op == true_div
        and len(node.inputs[0].owner.inputs) >= 2
        and node.inputs[0].owner.inputs[1].owner is not None
        and isinstance(node.inputs[0].owner.inputs[1].owner.op, Softmax)
        and node.inputs[1] == node.inputs[0].owner.inputs[1]
        and not (
            # skip if it will be optimized by
            # local_advanced_indexing_crossentropy_onehot_grad
            node.inputs[0].owner.op == true_div
            and node.inputs[0].owner.inputs[0].owner is not None
            and isinstance(
                node.inputs[0].owner.inputs[0].owner.op, AdvancedIncSubtensor
            )
            # The skip only applies to the legacy (axis=-1) `SoftmaxGrad`.
            # Compare against a fresh instance instead of the
            # `softmax_grad_legacy` global (defined in `aesara.tensor.nnet.basic`,
            # not importable here); `Op.__eq__` compares `__props__`, so
            # `SoftmaxGrad(axis=-1)` is equivalent.
            and node.op == SoftmaxGrad(axis=-1)
            and node.inputs[0].owner.inputs[1].ndim == 2
        )
    ):
        # get parameters from unoptimized op
        grads, sm = node.inputs[0].owner.inputs
        # d/dx log(softmax(x)) applied to `grads`:
        ret = grads - at_sum(grads, axis=sm.owner.op.axis, keepdims=True) * sm
        # The un-rewritten expression can produce NaN where this one does not;
        # relax value comparison so DebugMode does not flag the difference.
        ret.tag.values_eq_approx = values_eq_approx_remove_nan
        copy_stack_trace(node.outputs[0], ret)
        return [ret]
def softmax_simplifier(numerators, denominators):
    """Replace ``exp(x) / exp(x).sum(axis)`` patterns with a single `Softmax`.

    Registered as a simplifier on `local_mul_canonizer`, which supplies the
    flattened numerator and denominator term lists of a division; each matched
    (numerator, denominator) pair is removed and replaced by one
    ``Softmax(axis=...)`` output appended to the numerators.
    """
    for numerator in list(numerators):
        # Softmax is only meaningful for floating-point inputs.
        if not numerator.type.dtype.startswith("float"):
            continue
        # The numerator must literally be `exp(x)`.
        if not (numerator.owner and numerator.owner.op == exp):
            continue
        matching_denom = None
        for denominator in denominators:
            # Division with dimshuffle
            if denominator.owner and isinstance(denominator.owner.op, DimShuffle):
                ds_order = denominator.owner.op.new_order
                # Check that at most only one dimension is being reintroduced by
                # a dimshuffle. The cases where all dimensions are reintroduced
                # after a complete sum reduction end up in the else branch
                if ds_order.count("x") != 1:
                    continue
                # Check that dimshuffle does not change order of original dims
                ds_order_without_x = tuple(dim for dim in ds_order if dim != "x")
                if tuple(sorted(ds_order_without_x)) != ds_order_without_x:
                    continue
                new_dim = ds_order.index("x")
                z = denominator.owner.inputs[0]
                if z.owner and isinstance(z.owner.op, Sum):
                    sum_axis = z.owner.op.axis
                    # Check that reintroduced dim was the one reduced
                    if (
                        (sum_axis is not None)
                        and (len(sum_axis) == 1)
                        and (sum_axis[0] == new_dim)
                    ):
                        # The sum must reduce this exact numerator (identity,
                        # not mere structural equality).
                        if z.owner.inputs[0] is numerator:
                            (sum_axis,) = sum_axis
                            matching_denom = denominator
                            break
            # Division without dimshuffle
            else:
                z = denominator
                if z.owner and isinstance(z.owner.op, Sum):
                    sum_axis = z.owner.op.axis
                    # Filter out partial summations over more than one axis
                    # The cases where all axis of summation are explicitly given
                    # as in `sum(matrix, axis=(0, 1))` are eventually rewritten
                    # to `sum(matrix)` and this branch is not a blocker
                    if sum_axis is not None and len(sum_axis) != 1:
                        continue
                    if z.owner.inputs[0] is numerator:
                        if sum_axis is not None:
                            (sum_axis,) = sum_axis
                        matching_denom = denominator
                        break
        if matching_denom:
            # `sum_axis is None` here means a full (flattened) softmax.
            softmax = Softmax(axis=sum_axis)(numerator.owner.inputs[0])
            copy_stack_trace(numerator, softmax)
            numerators.remove(numerator)
            denominators.remove(matching_denom)
            numerators.append(softmax)
    return numerators, denominators


local_mul_canonizer.add_simplifier(softmax_simplifier, "softmax_simplifier")
差异被折叠。
......@@ -104,7 +104,7 @@
"\n",
"wy = th.shared(rng.normal(0, 1, (nhiddens, noutputs)))\n",
"by = th.shared(np.zeros(noutputs), borrow=True)\n",
"y = at.math.softmax(at.dot(h, wy) + by)\n",
"y = at.special.softmax(at.dot(h, wy) + by)\n",
"\n",
"predict = th.function([x], y)"
]
......
......@@ -67,7 +67,7 @@ hidden layer and a softmax output layer.
wy = th.shared(rng.normal(0, 1, (nhiddens, noutputs)))
by = th.shared(np.zeros(noutputs), borrow=True)
y = at.math.softmax(at.dot(h, wy) + by)
y = at.special.softmax(at.dot(h, wy) + by)
predict = th.function([x], y)
......
......@@ -3,6 +3,7 @@ import numpy as np
import aesara.tensor as at
from aesara import shared
from aesara.compile.builders import OpFromGraph
from aesara.tensor.special import softmax
from aesara.tensor.type import dmatrix, scalars
......@@ -24,8 +25,7 @@ class Mlp:
wy = shared(self.rng.normal(0, 1, (nhiddens, noutputs)))
by = shared(np.zeros(noutputs), borrow=True)
y = at.softmax(at.dot(h, wy) + by)
y = softmax(at.dot(h, wy) + by)
self.inputs = [x]
self.outputs = [y]
......
......@@ -5,10 +5,10 @@ from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value
from aesara.tensor import elemwise as at_elemwise
from aesara.tensor.math import SoftmaxGrad
from aesara.tensor.math import all as at_all
from aesara.tensor.math import log_softmax, prod, softmax
from aesara.tensor.math import prod
from aesara.tensor.math import sum as at_sum
from aesara.tensor.special import SoftmaxGrad, log_softmax, softmax
from aesara.tensor.type import matrix, tensor, vector
from tests.link.jax.test_basic import compare_jax_and_py
......
......@@ -12,19 +12,8 @@ from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.basic import Constant
from aesara.graph.fg import FunctionGraph
from aesara.tensor import elemwise as at_elemwise
from aesara.tensor.math import (
All,
Any,
LogSoftmax,
Max,
Mean,
Min,
Prod,
ProdWithoutZeros,
Softmax,
SoftmaxGrad,
Sum,
)
from aesara.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum
from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from tests.link.numba.test_basic import (
compare_numba_and_py,
my_multi_out,
......
......@@ -13,7 +13,7 @@ from aesara import pprint, shared
from aesara.compile import optdb
from aesara.compile.debugmode import DebugMode
from aesara.compile.function import function
from aesara.compile.mode import OPT_FAST_RUN, Mode, get_default_mode, get_mode
from aesara.compile.mode import Mode, get_default_mode, get_mode
from aesara.compile.ops import DeepCopyOp, deep_copy_op
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, equal_computations
......@@ -33,15 +33,7 @@ from aesara.tensor.basic import Alloc, join, switch
from aesara.tensor.blas import Dot22, Gemv
from aesara.tensor.blas_c import CGemv
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import (
Dot,
LogSoftmax,
MaxAndArgmax,
Prod,
SoftmaxGrad,
Sum,
_conj,
)
from aesara.tensor.math import Dot, MaxAndArgmax, Prod, Sum, _conj
from aesara.tensor.math import abs as at_abs
from aesara.tensor.math import add
from aesara.tensor.math import all as at_all
......@@ -84,17 +76,7 @@ from aesara.tensor.math import minimum, mul, neg, neq
from aesara.tensor.math import pow as at_pow
from aesara.tensor.math import prod, rad2deg, reciprocal
from aesara.tensor.math import round as at_round
from aesara.tensor.math import (
sgn,
sigmoid,
sin,
sinh,
softmax,
softplus,
sqr,
sqrt,
sub,
)
from aesara.tensor.math import sgn, sigmoid, sin, sinh, softplus, sqr, sqrt, sub
from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tan, tanh, true_div, xor
from aesara.tensor.rewriting.elemwise import local_dimshuffle_lift
......@@ -4596,97 +4578,6 @@ class TestSigmoidUtils:
assert is_1pexp(1 + 2 * exp_op(x), False) is None
class TestLogSoftmaxRewrites:
@pytest.mark.parametrize("axis", [None, 0, -1])
def test_local_logsoftmax_rewrite(self, axis):
"""Test the `Logsoftmax` substitution.
Check that ``Log(Softmax(x))`` is substituted with ``Logsoftmax(x)``. Note that
only the forward pass is checked (i.e., doesn't check the gradient)
"""
x = matrix("x")
sm = softmax(x, axis=axis)
logsm = log(sm)
f = function([x], logsm)
assert isinstance(f.maker.fgraph.outputs[0].owner.op, LogSoftmax)
assert check_stack_trace(f, ops_to_check=LogSoftmax)
@pytest.mark.parametrize("axis", [None, 0, -1])
def test_local_logsoftmax_grad_rewrite(self, axis):
"""Test the `Logsoftmax`'s grad substitution.
Check that ``Log(Softmax(x))``'s grad is substituted with ``Logsoftmax(x)``'s
grad and that the new operation does not explode for big inputs.
Note that only the grad is checked.
"""
m = config.mode
m = get_mode(m)
m.check_isfinite = False
# some inputs that are large to make the gradient explode in the non
# rewritten case
rng = np.random.default_rng(utt.fetch_seed())
a = np.exp(10 * rng.random((5, 10)).astype(config.floatX))
def myfunc(x):
sm = softmax(x, axis=axis)
logsm = log(sm)
return logsm
# We set step to 0.1 because for big values we need a big epsilon
utt.verify_grad(myfunc, [a], eps=0.1, mode=m)
sa = shared(a)
f = function([], myfunc(sa))
assert check_stack_trace(f, ops_to_check="all")
def test_logsoftmax_grad_true_div_elemwise(self):
"""
Checks that the gradient of an expression similar to a ``log(softmax)`` but
with a different elemwise operation than true_div is not rewritten.
"""
x = matrix("x")
y = log(softmax(x))
g = aesara.tensor.grad(y.sum(), x)
softmax_grad_node = g.owner
assert softmax_grad_node.op == SoftmaxGrad(axis=-1)
true_div_node = softmax_grad_node.inputs[0].owner
assert true_div_node.op == true_div
# We replace the elemwise true_div op by an elemwise add.
new_g = SoftmaxGrad(axis=-1)(
add(*true_div_node.inputs), softmax_grad_node.inputs[1]
)
fgraph = FunctionGraph([x], [new_g])
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert SoftmaxGrad(axis=-1) in [n.op for n in fgraph.toposort()]
def test_log1mexp_stabilization():
mode = Mode("py").including("stabilize")
x = vector()
f = function([x], log(1 - exp(x)), mode=mode)
nodes = [node.op for node in f.maker.fgraph.toposort()]
assert nodes == [at.log1mexp]
# Check values that would under or overflow without rewriting
assert f([-(2.0**-55)]) != -np.inf
overflow_value = -500.0 if config.floatX == "float64" else -100.0
assert f([overflow_value]) < 0
# Check values around the switch point np.log(0.5)
assert np.allclose(
f(np.array([-0.8, -0.6], dtype=config.floatX)),
np.log(1 - np.exp([-0.8, -0.6])),
)
def test_local_logit_sigmoid():
"""Test that graphs of the form ``logit(sigmoid(x))`` and ``sigmoid(logit(x))`` get rewritten to ``x``."""
......@@ -4727,24 +4618,3 @@ def test_deprecations():
"""Make sure we can import from deprecated modules."""
with pytest.deprecated_call():
from aesara.tensor.math_opt import AlgebraicCanonizer # noqa: F401 F811
def test_log_softmax_stabilization():
mode = aesara.compile.mode.get_default_mode()
mode = mode.including("local_log_softmax", "specialize")
x = matrix()
y = softmax(x)
z = log(y)
f = aesara.function([x], z, mode=mode)
assert check_stack_trace(f, ops_to_check="all")
# Check that the softmax has been rewritten
for node in f.maker.fgraph.toposort():
assert not isinstance(node.op, y.owner.op.__class__)
# Call the function so debug mode can verify the rewritten version matches
# the un-rewritten version
rng = np.random.default_rng(utt.fetch_seed())
f(np.cast[config.floatX](rng.random((2, 3))))
import numpy as np
import pytest
import aesara
import aesara.tensor as at
from aesara import shared
from aesara.compile import optdb
from aesara.compile.function import function
from aesara.compile.mode import OPT_FAST_RUN, Mode, get_mode
from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.graph.rewriting.basic import check_stack_trace
from aesara.graph.rewriting.db import RewriteDatabaseQuery
from aesara.tensor.math import add, exp, log, true_div
from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad, softmax
from aesara.tensor.type import matrix
from tests import unittest_tools as utt
class TestLogSoftmaxRewrites:
    """Tests for the `local_logsoftmax` and `local_logsoftmax_grad` rewrites."""

    @pytest.mark.parametrize("axis", [None, 0, -1])
    def test_local_logsoftmax_rewrite(self, axis):
        """Test the `Logsoftmax` substitution.

        Check that ``Log(Softmax(x))`` is substituted with ``Logsoftmax(x)``. Note that
        only the forward pass is checked (i.e., doesn't check the gradient)
        """
        x = matrix("x")
        sm = softmax(x, axis=axis)
        logsm = log(sm)
        f = function([x], logsm)
        # The rewrite should have fused `log(softmax(x))` into one `LogSoftmax`.
        assert isinstance(f.maker.fgraph.outputs[0].owner.op, LogSoftmax)
        assert check_stack_trace(f, ops_to_check=LogSoftmax)

    @pytest.mark.parametrize("axis", [None, 0, -1])
    def test_local_logsoftmax_grad_rewrite(self, axis):
        """Test the `Logsoftmax`'s grad substitution.

        Check that ``Log(Softmax(x))``'s grad is substituted with ``Logsoftmax(x)``'s
        grad and that the new operation does not explode for big inputs.

        Note that only the grad is checked.
        """
        m = config.mode
        m = get_mode(m)
        # Disable finiteness checks: the point is precisely that the
        # un-rewritten gradient overflows on these inputs.
        m.check_isfinite = False
        # some inputs that are large to make the gradient explode in the non
        # rewritten case
        rng = np.random.default_rng(utt.fetch_seed())
        a = np.exp(10 * rng.random((5, 10)).astype(config.floatX))

        def myfunc(x):
            sm = softmax(x, axis=axis)
            logsm = log(sm)
            return logsm

        # We set step to 0.1 because for big values we need a big epsilon
        utt.verify_grad(myfunc, [a], eps=0.1, mode=m)
        sa = shared(a)
        f = function([], myfunc(sa))
        assert check_stack_trace(f, ops_to_check="all")

    def test_logsoftmax_grad_true_div_elemwise(self):
        """
        Checks that the gradient of an expression similar to a ``log(softmax)`` but
        with a different elemwise operation than true_div is not rewritten.
        """
        x = matrix("x")
        y = log(softmax(x))
        g = aesara.tensor.grad(y.sum(), x)
        softmax_grad_node = g.owner
        assert softmax_grad_node.op == SoftmaxGrad(axis=-1)
        true_div_node = softmax_grad_node.inputs[0].owner
        assert true_div_node.op == true_div
        # We replace the elemwise true_div op by an elemwise add.
        new_g = SoftmaxGrad(axis=-1)(
            add(*true_div_node.inputs), softmax_grad_node.inputs[1]
        )
        fgraph = FunctionGraph([x], [new_g])
        optdb.query(OPT_FAST_RUN).rewrite(fgraph)
        # Since the inner op is not `true_div`, the `SoftmaxGrad` must survive.
        assert SoftmaxGrad(axis=-1) in [n.op for n in fgraph.toposort()]
def test_log1mexp_stabilization():
    """``log(1 - exp(x))`` should be stabilized into a single ``log1mexp``."""
    stabilized_mode = Mode("py").including("stabilize")
    v = vector()
    f = function([v], log(1 - exp(v)), mode=stabilized_mode)

    # The whole expression must collapse to exactly one `log1mexp` node.
    assert [n.op for n in f.maker.fgraph.toposort()] == [at.log1mexp]

    # Check values that would under- or overflow without the rewrite.
    assert f([-(2.0**-55)]) != -np.inf
    big_negative = -500.0 if config.floatX == "float64" else -100.0
    assert f([big_negative]) < 0

    # Check values around the switch point ``np.log(0.5)``.
    probe = np.array([-0.8, -0.6], dtype=config.floatX)
    assert np.allclose(f(probe), np.log(1 - np.exp([-0.8, -0.6])))
def test_log_softmax_stabilization():
    """``log(softmax(x))`` should be rewritten away under the stabilizing mode."""
    mode = aesara.compile.mode.get_default_mode()
    mode = mode.including("local_log_softmax", "specialize")

    x = matrix()
    y = softmax(x)
    z = log(y)

    f = aesara.function([x], z, mode=mode)
    assert check_stack_trace(f, ops_to_check="all")

    # Check that the softmax has been rewritten
    for node in f.maker.fgraph.toposort():
        assert not isinstance(node.op, y.owner.op.__class__)

    # Call the function so debug mode can verify the rewritten version matches
    # the un-rewritten version
    rng = np.random.default_rng(utt.fetch_seed())
    # `np.cast[dtype](...)` was removed in NumPy 2.0; `astype` is the
    # supported, equivalent spelling.
    f(rng.random((2, 3)).astype(config.floatX))
def test_softmax_graph():
    """Make sure that softmax expressions are turned into
    a softmax Op.

    """
    rng = np.random.default_rng(utt.fetch_seed())
    x = aesara.shared(rng.normal(size=(3, 4)))

    def softmax_graph(c):
        # Hand-written softmax; the canonicalization should detect this pattern.
        return exp(c) / exp(c).sum(axis=-1, keepdims=True)

    def f(inputs):
        # Differentiate through the rewritten graph via `known_grads`.
        y = softmax_graph(x)
        return aesara.grad(None, x, known_grads={y: inputs})

    utt.verify_grad(f, [rng.random((3, 4))])
......@@ -8,9 +8,7 @@ from itertools import product
import numpy as np
import pytest
from numpy.testing import assert_array_equal
from scipy.special import log_softmax as scipy_log_softmax
from scipy.special import logsumexp as scipy_logsumexp
from scipy.special import softmax as scipy_softmax
import aesara.scalar as aes
from aesara.compile.debugmode import DebugMode
......@@ -36,14 +34,11 @@ from aesara.tensor.elemwise import CAReduce, Elemwise
from aesara.tensor.math import (
Argmax,
Dot,
LogSoftmax,
MatMul,
MaxAndArgmax,
Mean,
Prod,
ProdWithoutZeros,
Softmax,
SoftmaxGrad,
Sum,
_allclose,
_dot,
......@@ -84,7 +79,6 @@ from aesara.tensor.math import (
log1p,
log2,
log10,
log_softmax,
logaddexp,
logsumexp,
matmul,
......@@ -110,7 +104,6 @@ from aesara.tensor.math import (
sin,
sinh,
smallest,
softmax,
sqr,
sqrt,
sub,
......@@ -3528,129 +3521,3 @@ class TestMatMul(utt.InferShapeTester):
[x1, x2],
self.op_class,
)
class TestSoftmax(utt.InferShapeTester):
@pytest.mark.parametrize("axis", [None, 0, 1, 2, 3, -1, -2])
def test_perform(self, axis):
x = tensor4("x")
rng = np.random.default_rng(utt.fetch_seed())
xv = rng.standard_normal((2, 3, 4, 5)).astype(config.floatX)
f = function([x], softmax(x, axis=axis))
assert np.allclose(f(xv), scipy_softmax(xv, axis=axis))
@pytest.mark.parametrize("column", [0, 1, 2, 3])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_grad(self, axis, column):
def f(a):
return softmax(a, axis=axis)[:, column]
rng = np.random.default_rng(utt.fetch_seed())
utt.verify_grad(f, [rng.random((3, 4, 2))])
def test_infer_shape(self):
admat = matrix()
rng = np.random.default_rng(utt.fetch_seed())
admat_val = rng.random((3, 4)).astype(config.floatX)
self._compile_and_check(
[admat], [Softmax(axis=-1)(admat)], [admat_val], Softmax
)
def test_vector_perform(self):
x = vector()
f = function([x], softmax(x, axis=None))
rng = np.random.default_rng(utt.fetch_seed())
xv = rng.standard_normal((6,)).astype(config.floatX)
assert np.allclose(f(xv), scipy_softmax(xv))
def test_vector_grad(self):
def f(a):
return softmax(a, axis=None)
rng = np.random.default_rng(utt.fetch_seed())
utt.verify_grad(f, [rng.random((4))])
def test_valid_axis(self):
with pytest.raises(TypeError):
Softmax(1.5)
x = [tensor3()] * LogSoftmax.nin
Softmax(2)(*x)
Softmax(-3)(*x)
with pytest.raises(ValueError):
Softmax(3)(*x)
with pytest.raises(ValueError):
Softmax(-4)(*x)
class TestLogSoftmax(utt.InferShapeTester):
@pytest.mark.parametrize("column", [0, 1, 2, 3])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_matrix_grad(self, axis, column):
def f(a):
return log_softmax(a, axis=axis)[:, column]
rng = np.random.default_rng(utt.fetch_seed())
utt.verify_grad(f, [rng.random((3, 4))])
def test_vector_perform(self):
x = vector()
f = function([x], log_softmax(x, axis=None))
rng = np.random.default_rng(utt.fetch_seed())
xv = rng.standard_normal((6,)).astype(config.floatX)
assert np.allclose(f(xv), scipy_log_softmax(xv))
def test_vector_grad(self):
def f(a):
return log_softmax(a, axis=None)
rng = np.random.default_rng(utt.fetch_seed())
utt.verify_grad(f, [rng.random((4,))])
def test_valid_axis(self):
with pytest.raises(TypeError):
LogSoftmax(1.5)
x = [tensor3()] * LogSoftmax.nin
LogSoftmax(2)(*x)
LogSoftmax(-3)(*x)
with pytest.raises(ValueError):
LogSoftmax(3)(*x)
with pytest.raises(ValueError):
LogSoftmax(-4)(*x)
class TestSoftmaxGrad(utt.InferShapeTester):
def test_infer_shape(self):
admat = matrix()
bdmat = matrix()
rng = np.random.default_rng(utt.fetch_seed())
admat_val = rng.random((3, 4)).astype(config.floatX)
bdmat_val = rng.random((3, 4)).astype(config.floatX)
self._compile_and_check(
[admat, bdmat],
[SoftmaxGrad(axis=-1)(admat, bdmat)],
[admat_val, bdmat_val],
SoftmaxGrad,
)
def test_valid_axis(self):
with pytest.raises(TypeError):
SoftmaxGrad(1.5)
x = [tensor3()] * SoftmaxGrad.nin
SoftmaxGrad(2)(*x)
SoftmaxGrad(-3)(*x)
with pytest.raises(ValueError):
SoftmaxGrad(3)(*x)
with pytest.raises(ValueError):
SoftmaxGrad(-4)(*x)
import numpy as np
import pytest
from scipy.special import log_softmax as scipy_log_softmax
from scipy.special import softmax as scipy_softmax
from aesara.compile.function import function
from aesara.configdefaults import config
from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad, log_softmax, softmax
from aesara.tensor.type import matrix, tensor3, tensor4, vector
from tests import unittest_tools as utt
class TestSoftmax(utt.InferShapeTester):
    """Tests for the `Softmax` op and the `softmax` helper."""

    @pytest.mark.parametrize("axis", [None, 0, 1, 2, 3, -1, -2])
    def test_perform(self, axis):
        x = tensor4("x")
        rng = np.random.default_rng(utt.fetch_seed())
        xv = rng.standard_normal((2, 3, 4, 5)).astype(config.floatX)
        f = function([x], softmax(x, axis=axis))
        # Results should match SciPy's reference implementation.
        assert np.allclose(f(xv), scipy_softmax(xv, axis=axis))

    @pytest.mark.parametrize("column", [0, 1, 2, 3])
    @pytest.mark.parametrize("axis", [None, 0, 1])
    def test_grad(self, axis, column):
        def f(a):
            return softmax(a, axis=axis)[:, column]

        rng = np.random.default_rng(utt.fetch_seed())
        utt.verify_grad(f, [rng.random((3, 4, 2))])

    def test_infer_shape(self):
        admat = matrix()
        rng = np.random.default_rng(utt.fetch_seed())
        admat_val = rng.random((3, 4)).astype(config.floatX)
        self._compile_and_check(
            [admat], [Softmax(axis=-1)(admat)], [admat_val], Softmax
        )

    def test_vector_perform(self):
        x = vector()
        f = function([x], softmax(x, axis=None))
        rng = np.random.default_rng(utt.fetch_seed())
        xv = rng.standard_normal((6,)).astype(config.floatX)
        assert np.allclose(f(xv), scipy_softmax(xv))

    def test_vector_grad(self):
        def f(a):
            return softmax(a, axis=None)

        rng = np.random.default_rng(utt.fetch_seed())
        utt.verify_grad(f, [rng.random((4))])

    def test_valid_axis(self):
        # Non-integer axes are rejected outright.
        with pytest.raises(TypeError):
            Softmax(1.5)

        # Was `LogSoftmax.nin` (copy-paste from `TestLogSoftmax`); use this
        # class's op so the test stays correct if the arities ever diverge.
        x = [tensor3()] * Softmax.nin
        Softmax(2)(*x)
        Softmax(-3)(*x)
        with pytest.raises(ValueError):
            Softmax(3)(*x)
        with pytest.raises(ValueError):
            Softmax(-4)(*x)
class TestLogSoftmax(utt.InferShapeTester):
    """Tests for the `LogSoftmax` op and the `log_softmax` helper."""

    @pytest.mark.parametrize("column", [0, 1, 2, 3])
    @pytest.mark.parametrize("axis", [None, 0, 1])
    def test_matrix_grad(self, axis, column):
        def f(a):
            return log_softmax(a, axis=axis)[:, column]

        rng = np.random.default_rng(utt.fetch_seed())
        utt.verify_grad(f, [rng.random((3, 4))])

    def test_vector_perform(self):
        x = vector()
        f = function([x], log_softmax(x, axis=None))
        rng = np.random.default_rng(utt.fetch_seed())
        xv = rng.standard_normal((6,)).astype(config.floatX)
        # Results should match SciPy's reference implementation.
        assert np.allclose(f(xv), scipy_log_softmax(xv))

    def test_vector_grad(self):
        def f(a):
            return log_softmax(a, axis=None)

        rng = np.random.default_rng(utt.fetch_seed())
        utt.verify_grad(f, [rng.random((4,))])

    def test_valid_axis(self):
        # Non-integer axes are rejected outright.
        with pytest.raises(TypeError):
            LogSoftmax(1.5)

        x = [tensor3()] * LogSoftmax.nin
        # Axes inside [-ndim, ndim) are accepted; anything outside raises.
        LogSoftmax(2)(*x)
        LogSoftmax(-3)(*x)
        with pytest.raises(ValueError):
            LogSoftmax(3)(*x)
        with pytest.raises(ValueError):
            LogSoftmax(-4)(*x)
class TestSoftmaxGrad(utt.InferShapeTester):
    """Tests for the `SoftmaxGrad` op."""

    def test_infer_shape(self):
        dy = matrix()
        sm = matrix()
        rng = np.random.default_rng(utt.fetch_seed())
        dy_val = rng.random((3, 4)).astype(config.floatX)
        sm_val = rng.random((3, 4)).astype(config.floatX)
        self._compile_and_check(
            [dy, sm],
            [SoftmaxGrad(axis=-1)(dy, sm)],
            [dy_val, sm_val],
            SoftmaxGrad,
        )

    def test_valid_axis(self):
        # Non-integer axes are rejected outright.
        with pytest.raises(TypeError):
            SoftmaxGrad(1.5)

        inputs = [tensor3()] * SoftmaxGrad.nin
        # Axes inside [-ndim, ndim) are accepted...
        SoftmaxGrad(2)(*inputs)
        SoftmaxGrad(-3)(*inputs)
        # ...while anything outside that range raises at call time.
        with pytest.raises(ValueError):
            SoftmaxGrad(3)(*inputs)
        with pytest.raises(ValueError):
            SoftmaxGrad(-4)(*inputs)
......@@ -328,7 +328,7 @@ class TestRopLop(RopLopChecker):
self.check_mat_rop_lop(self.mx.sum(axis=1), (self.mat_in_shape[0],))
def test_softmax(self):
self.check_rop_lop(aesara.tensor.math.softmax(self.x), self.in_shape)
self.check_rop_lop(aesara.tensor.special.softmax(self.x), self.in_shape)
def test_alloc(self):
# Alloc of the sum of x into a vector
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论