提交 ff1a3a9d authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Brandon T. Willard

Move `Softmax`, `LogSoftmax`, `SoftmaxGrad` to new `aesara.tensor.special`

上级 e2202bc7
......@@ -3,7 +3,7 @@ import jax.numpy as jnp
from aesara.link.jax.dispatch.basic import jax_funcify, jnp_safe_copy
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import LogSoftmax, Softmax, SoftmaxGrad
from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@jax_funcify.register(Elemwise)
......
......@@ -38,13 +38,8 @@ from aesara.scalar.basic import (
from aesara.scalar.basic import add as add_as
from aesara.scalar.basic import scalar_maximum
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import (
LogSoftmax,
MaxAndArgmax,
MulWithoutZeros,
Softmax,
SoftmaxGrad,
)
from aesara.tensor.math import MaxAndArgmax, MulWithoutZeros
from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@singledispatch
......
......@@ -113,6 +113,7 @@ import aesara.tensor.rewriting
# isort: off
from aesara.tensor import linalg # noqa
from aesara.tensor import special
# For backward compatibility
from aesara.tensor import nlinalg # noqa
......
import builtins
import warnings
from textwrap import dedent
from typing import TYPE_CHECKING, Optional
import numpy as np
import scipy
from aesara import config, printing
from aesara import scalar as aes
......@@ -2990,766 +2988,6 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
return MatMul(dtype=dtype)(x1, x2)
class SoftmaxGrad(COp):
"""
Gradient wrt x of the Softmax Op.
"""
nin = 2
nout = 1
__props__ = ("axis",)
def __init__(self, axis):
if axis is not None and not isinstance(axis, int):
raise TypeError("axis must be an integer or `None`")
self.axis = axis
def make_node(self, dy, sm):
dy = as_tensor_variable(dy)
sm = as_tensor_variable(sm)
if self.axis is not None and (self.axis >= sm.ndim or self.axis < -sm.ndim):
raise ValueError(
f"SoftmaxGrad axis(={self.axis}) out of bounds for {sm.ndim}D array {sm}"
)
return Apply(self, [dy, sm], [sm.type()])
def perform(self, node, input_storage, output_storage):
dy, sm = input_storage
dy_times_sm = dy * sm
dx = dy_times_sm - np.sum(dy_times_sm, axis=self.axis, keepdims=True) * sm
output_storage[0][0] = dx
def grad(self, inp, grads):
dy, sm = inp
(g,) = grads
tmp = g + neg(sum(g * sm, axis=self.axis, keepdims=True))
g_dy = tmp * sm
tmp2 = sum(dy * sm, axis=self.axis, keepdims=True)
g_sm = tmp * dy - g * tmp2
return g_dy, g_sm
def infer_shape(self, fgraph, node, shape):
return [shape[1]]
def c_code_cache_version(self):
return (4,)
def c_code(self, node, name, inp, out, sub):
dy, sm = inp
(dx,) = out
axis = self.axis if self.axis is not None else np.MAXDIMS
fail = sub["fail"]
return dedent(
f"""
PyArrayObject* op[3];
npy_uint32 op_flags[3];
npy_uint32 iter_flags;
NpyIter* iter;
NpyIter_IterNextFunc* get_next;
char** data_ptr;
int sm_ndim = PyArray_NDIM({sm});
int axis = {axis};
int iterate_axis = !(axis == NPY_MAXDIMS || sm_ndim == 1);
// Validate inputs
if ((PyArray_TYPE({dy}) != NPY_DOUBLE) &&
(PyArray_TYPE({dy}) != NPY_FLOAT))
{{
PyErr_SetString(PyExc_TypeError, "types should be float or float64");
{fail};
}}
if ((PyArray_TYPE({sm}) != NPY_DOUBLE) &&
(PyArray_TYPE({sm}) != NPY_FLOAT))
{{
PyErr_SetString(PyExc_TypeError, "types should be float or float64");
{fail};
}}
if (axis < 0) axis = sm_ndim + axis;
if ((axis < 0) || (iterate_axis && (axis > sm_ndim)))
{{
PyErr_SetString(PyExc_ValueError, "invalid axis in SoftmaxGrad");
{fail};
}}
if (({dx} == NULL)
|| !(PyArray_CompareLists(PyArray_DIMS({dx}), PyArray_DIMS({sm}), sm_ndim)))
{{
Py_XDECREF({dx});
{dx} = (PyArrayObject*)PyArray_SimpleNew(sm_ndim,
PyArray_DIMS({sm}),
PyArray_TYPE({sm}));
if (!{dx})
{{
PyErr_SetString(PyExc_MemoryError, "failed to alloc SoftMaxGrad dx output");
{fail};
}}
}}
// Create numpy iterator
op[0] = {dy};
op[1] = {sm};
op[2] = {dx};
op_flags[0] = NPY_ITER_READONLY;
op_flags[1] = NPY_ITER_READONLY;
op_flags[2] = NPY_ITER_READWRITE;
iter_flags = (iterate_axis)? NPY_ITER_MULTI_INDEX : 0;
iter = NpyIter_MultiNew(
3,
op,
iter_flags,
NPY_KEEPORDER,
NPY_NO_CASTING,
op_flags,
NULL
);
if (iter == NULL)
{{
PyErr_SetString(PyExc_MemoryError, "failed to create softmax iterator");
{fail};
}}
// SoftmaxGrad is applied across the entire array
if (!iterate_axis)
{{
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain SoftMaxGrad GetIterNext");
{fail};
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
// Compute and accumulate dy * sm
dtype_{dx} sum_dy_times_sm = 0.0;
do
{{
dtype_{dy}* dy_ptr = (dtype_{dy}*)data_ptr[0];
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
dtype_{dx}* dx_ptr = (dtype_{dx}*)data_ptr[2];
*dx_ptr = (dtype_{dx})((*dy_ptr) * (*sm_ptr));
sum_dy_times_sm += *dx_ptr;
}} while(get_next(iter));
// Reset Iterator
if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to reset softmax iterator");
{fail};
}}
// Subtract sum(dy*sm) * sm
do
{{
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
dtype_{dx}* dx_ptr = (dtype_{dx}*)data_ptr[2];
*dx_ptr -= sum_dy_times_sm * ((dtype_{dx})(*sm_ptr));
}} while(get_next(iter));
}}
// SoftmaxGrad is applied across a specific axis
else {{
// Collect axis strides and remove it from iteration
npy_intp axis_size = PyArray_DIM({sm}, axis);
npy_intp* axis_stride = NpyIter_GetAxisStrideArray(iter, axis);
if (axis_stride == NULL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain softmax axis strides");
{fail};
}}
npy_intp dy_axis_stride = axis_stride[0] / sizeof(dtype_{dy});
npy_intp sm_axis_stride = axis_stride[1] / sizeof(dtype_{sm});
npy_intp dx_axis_stride = axis_stride[2] / sizeof(dtype_{dx});
if (NpyIter_RemoveAxis(iter, axis) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to remove SoftmaxGrad axis from iterator");
{fail};
}}
// Iterate over remaining axes
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain SoftamGrad GetIterNext");
{fail};
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
do
{{
dtype_{dy}* dy_axis = (dtype_{dy}*)data_ptr[0];
dtype_{sm}* sm_axis = (dtype_{sm}*)data_ptr[1];
dtype_{dx}* dx_axis = (dtype_{dx}*)data_ptr[2];
// Compute and accumulate dy * sm
dtype_{dx} sum_dy_times_sm = 0.0;
for (npy_intp i = 0; i < axis_size; i++)
{{
dx_axis[i * dx_axis_stride] = (dtype_{dx})(dy_axis[i * dy_axis_stride] * sm_axis[i * sm_axis_stride]);
sum_dy_times_sm += dx_axis[i * dx_axis_stride];
}}
// Subtract sum(dy*sm) * sm
for (npy_intp i = 0; i < axis_size; i++)
{{
dx_axis[i * dx_axis_stride] -= sum_dy_times_sm * (dtype_{dx})(sm_axis[i * sm_axis_stride]);
}}
}} while(get_next(iter));
}}
NpyIter_Deallocate(iter);
"""
)
class Softmax(COp):
r"""
Softmax activation function
:math:`\\varphi(\\mathbf{x})_j =
\\frac{e^{\mathbf{x}_j}}{\sum_{k=1}^K e^{\mathbf{x}_k}}`
where :math:`K` is the total number of neurons in the layer. This
activation function gets applied row-wise.
"""
nin = 1
nout = 1
__props__ = ("axis",)
def __init__(self, axis):
if axis is not None and not isinstance(axis, int):
raise TypeError("axis must be an integer or `None`")
self.axis = axis
def make_node(self, x):
x = as_tensor_variable(x)
if self.axis is not None and (self.axis >= x.ndim or self.axis < -x.ndim):
raise ValueError(
f"Softmax axis(={self.axis}) out of bounds for {x.ndim}D array {x}"
)
return Apply(self, [x], [x.type()])
def perform(self, node, input_storage, output_storage):
(x,) = input_storage
(z,) = output_storage
z[0] = scipy.special.softmax(x, axis=self.axis)
def L_op(self, inp, outputs, grads):
(x,) = inp
(g_sm,) = grads
return [SoftmaxGrad(axis=self.axis)(g_sm, outputs[0])]
def R_op(self, inputs, eval_points):
# I think the Jacobian is symmetric so the R_op
# is the same as the grad
if None in eval_points:
return [None]
return self.L_op(inputs, [self(*inputs)], eval_points)
def infer_shape(self, fgraph, node, shape):
return shape
def c_headers(self, **kwargs):
return ["<iostream>", "<cmath>"]
def c_code(self, node, name, inp, out, sub):
(x,) = inp
(sm,) = out
axis = self.axis if self.axis is not None else np.MAXDIMS
fail = sub["fail"]
# dtype = node.inputs[0].type.dtype_specs()[1]
# TODO: put this into a templated function, in the support code
# TODO: declare the max of each row as an Op output
# TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1]
return dedent(
f"""
PyArrayObject* op[2];
npy_uint32 op_flags[2];
npy_uint32 iter_flags;
NpyIter* iter;
NpyIter_IterNextFunc* get_next;
char** data_ptr;
int x_ndim = PyArray_NDIM({x});
int axis = {axis};
int iterate_axis = !(axis == NPY_MAXDIMS || x_ndim == 1);
// Validate inputs
if ((PyArray_TYPE({x}) != NPY_DOUBLE) &&
(PyArray_TYPE({x}) != NPY_FLOAT))
{{
PyErr_SetString(PyExc_TypeError, "not a float");
{fail}
}}
if (axis < 0) axis = x_ndim + axis;
if ((axis < 0) || (iterate_axis && (axis > x_ndim)))
{{
PyErr_SetString(PyExc_ValueError, "invalid axis in Softmax");
{fail}
}}
// Allocate Output Array
if (({sm}) == NULL || !(PyArray_CompareLists(PyArray_DIMS({sm}), PyArray_DIMS({x}), x_ndim)))
{{
Py_XDECREF({sm});
{sm} = (PyArrayObject*)PyArray_SimpleNew(x_ndim, PyArray_DIMS({x}), PyArray_TYPE({x}));
if(!{sm}) {{
PyErr_SetString(PyExc_MemoryError, "failed to alloc Softmax output");
{fail}
}}
}}
// Create numpy iterator
op[0] = {x};
op[1] = {sm};
op_flags[0] = NPY_ITER_READONLY;
op_flags[1] = NPY_ITER_READWRITE;
iter_flags = (iterate_axis)? NPY_ITER_MULTI_INDEX : 0;
iter = NpyIter_MultiNew(
2,
op,
iter_flags,
NPY_KEEPORDER,
NPY_NO_CASTING,
op_flags,
NULL
);
if (iter == NULL)
{{
PyErr_SetString(PyExc_MemoryError, "failed to create Softmax iterator");
{fail}
}}
// Softmax is applied across the entire array
if (!iterate_axis)
{{
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain Softmax GetIterNext");
{fail}
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
// Find axis max
dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
dtype_{x} max = *x_ptr;
if (get_next(iter))
{{
do
{{
dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
max = (*x_ptr > max)? *x_ptr : max;
}} while(get_next(iter));
}}
// Reset Iterator
if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to reset Softmax iterator");
{fail}
}}
// Compute and accumulate exp(x-max(x)) exponent
double sum_exp_dev = 0.0;
do
{{
dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
*sm_ptr = (dtype_{sm}) exp(*x_ptr - max);
sum_exp_dev += *sm_ptr;
}} while(get_next(iter));
// Reset Iterator
if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to reset Softmax iterator");
{fail}
}}
// Divide by sum(exp(x-max(x)))
double inv_sum_exp_dev = 1.0 / sum_exp_dev;
do
{{
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
*sm_ptr *= inv_sum_exp_dev;
}} while(get_next(iter));
}}
// Softmax is applied across a specific axis
else {{
// Collect axis strides and remove it from iteration
npy_intp axis_size = PyArray_DIM({x}, axis);
npy_intp* axis_stride = NpyIter_GetAxisStrideArray(iter, axis);
if (axis_stride == NULL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain Softmax axis strides");
{fail}
}}
npy_intp x_axis_stride = axis_stride[0] / sizeof(dtype_{x});
npy_intp sm_axis_stride = axis_stride[1] / sizeof(dtype_{sm});
if (NpyIter_RemoveAxis(iter, axis) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to remove softmax axis from iterator");
{fail}
}}
// Iterate over remaining axes
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain softmax GetIterNext");
{fail}
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
do
{{
dtype_{x}* x_axis = (dtype_{x}*)data_ptr[0];
dtype_{sm}* sm_axis = (dtype_{sm}*)data_ptr[1];
// Find axis max
dtype_{x} max = x_axis[0];
for (npy_intp i = 1; i < axis_size; i++)
{{
dtype_{x} x_val = x_axis[i * x_axis_stride];
max = (x_val > max)? x_val : max;
}}
// Compute and accumulate exp(x-max(x)) exponent
dtype_{sm} sum_exp_dev = 0.0;
for (npy_intp i = 0; i < axis_size; i++)
{{
sm_axis[i * sm_axis_stride] = (dtype_{sm}) exp(x_axis[i * x_axis_stride] - max);
sum_exp_dev += sm_axis[i * sm_axis_stride];
}}
// Divide by sum(exp(x-max(x)))
dtype_{sm} inv_sum_exp_dev = 1.0 / sum_exp_dev;
for (npy_intp i = 0; i < axis_size; i++)
{{
sm_axis[i * sm_axis_stride] *= inv_sum_exp_dev;
}}
}} while(get_next(iter));
}}
NpyIter_Deallocate(iter);
"""
)
@staticmethod
def c_code_cache_version():
return (4,)
UNSET_AXIS = object()
def softmax(c, axis=UNSET_AXIS):
if axis is UNSET_AXIS:
warnings.warn(
"Softmax now accepts an axis argument. For backwards-compatibility it defaults to -1 when not specified, "
"but in the future the default will be `None`.\nTo suppress this warning specify axis explicitly.",
FutureWarning,
)
axis = -1
c = as_tensor_variable(c)
if c.ndim == 1:
# TODO: Create Specific warning type that can be suppressed?
warnings.warn(
"Softmax no longer converts a vector to a row matrix.",
UserWarning,
)
return Softmax(axis=axis)(c)
class LogSoftmax(COp):
r"""
LogSoftmax activation function
:math:`\\varphi(\\mathbf{x})_j =
\\e^{(\mathbf{x}_j - log{\sum_{k=1}^K e^{\mathbf{x}_k})}}
where :math:`K` is the total number of neurons in the layer. This
activation function gets applied row-wise.
"""
nin = 1
nout = 1
__props__ = ("axis",)
def __init__(self, axis):
if axis is not None and not isinstance(axis, int):
raise TypeError("axis must be an integer or `None`")
self.axis = axis
def make_node(self, x):
x = as_tensor_variable(x)
if self.axis is not None and (self.axis >= x.ndim or self.axis < -x.ndim):
raise ValueError(
f"LogSoftmax axis(={self.axis}) out of bounds for {x.ndim}D array {x}"
)
return Apply(self, [x], [x.type()])
def perform(self, node, input_storage, output_storage):
(x,) = input_storage
(z,) = output_storage
z[0] = scipy.special.log_softmax(x, axis=self.axis)
def grad(self, inp, grads):
(x,) = inp
sm = Softmax(axis=self.axis)(x)
return [grads[0] - sum(grads[0], axis=self.axis, keepdims=True) * sm]
def R_op(self, inputs, eval_points):
# I think the Jacobian is symmetric so the R_op
# is the same as the grad
if None in eval_points:
return [None]
return self.grad(inputs, eval_points)
def infer_shape(self, fgraph, node, shape):
return shape
def c_headers(self, **kwargs):
return ["<cmath>"]
def c_code(self, node, name, inp, out, sub):
(x,) = inp
(sm,) = out
axis = self.axis if self.axis is not None else np.MAXDIMS
fail = sub["fail"]
return dedent(
f"""
PyArrayObject* op[2];
npy_uint32 op_flags[2];
npy_uint32 iter_flags;
NpyIter* iter;
NpyIter_IterNextFunc* get_next;
char** data_ptr;
int x_ndim = PyArray_NDIM({x});
int axis = {axis};
int iterate_axis = !(axis == NPY_MAXDIMS || x_ndim == 1);
// Validate inputs
if ((PyArray_TYPE({x}) != NPY_DOUBLE) &&
(PyArray_TYPE({x}) != NPY_FLOAT))
{{
PyErr_SetString(PyExc_TypeError, "not a float");
{fail}
}}
if (axis < 0) axis = x_ndim + axis;
if ((axis < 0) || (iterate_axis && (axis > x_ndim)))
{{
PyErr_SetString(PyExc_ValueError, "invalid axis in LogSoftmax");
{fail}
}}
// Allocate Output Array
if (({sm}) == NULL || !(PyArray_CompareLists(PyArray_DIMS({sm}), PyArray_DIMS({x}), x_ndim)))
{{
Py_XDECREF({sm});
{sm} = (PyArrayObject*)PyArray_SimpleNew(x_ndim, PyArray_DIMS({x}), PyArray_TYPE({x}));
if(!{sm}) {{
PyErr_SetString(PyExc_MemoryError, "failed to alloc LogSoftmax output");
{fail}
}}
}}
// Create numpy iterator
op[0] = {x};
op[1] = {sm};
op_flags[0] = NPY_ITER_READONLY;
op_flags[1] = NPY_ITER_READWRITE;
iter_flags = (iterate_axis)? NPY_ITER_MULTI_INDEX : 0;
iter = NpyIter_MultiNew(
2,
op,
iter_flags,
NPY_KEEPORDER,
NPY_NO_CASTING,
op_flags,
NULL
);
if (iter == NULL)
{{
PyErr_SetString(PyExc_MemoryError, "failed to create LogSoftmax iterator");
{fail}
}}
// LogSoftmax is applied across the entire array
if (!iterate_axis)
{{
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain LogSoftmax GetIterNext");
{fail}
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
// Find axis max
dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
dtype_{x} max = *x_ptr;
if (get_next(iter))
{{
do
{{
dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
max = (*x_ptr > max)? *x_ptr : max;
}} while(get_next(iter));
}}
// Reset Iterator
if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to reset LogSoftmax iterator");
{fail}
}}
// Compute xdev and sum(exp(xdev))
dtype_{sm} sum_exp_xdev = 0.0;
do
{{
dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
*sm_ptr = (dtype_{sm})((*x_ptr) - max);
sum_exp_xdev += exp(*sm_ptr);
}} while(get_next(iter));
// Reset Iterator
if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to reset LogSoftmax iterator");
{fail}
}}
// Subtract log(sum(exp(xdev)))
dtype_{sm} log_sum_exp_xdev = log(sum_exp_xdev);
do
{{
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
*sm_ptr -= log_sum_exp_xdev;
}} while(get_next(iter));
}}
// LogSoftmax is applied across a specific axis
else {{
// Collect axis strides and remove it from iteration
npy_intp axis_size = PyArray_DIM({x}, axis);
npy_intp* axis_stride = NpyIter_GetAxisStrideArray(iter, axis);
if (axis_stride == NULL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain LogSoftmax axis strides");
{fail}
}}
npy_intp x_axis_stride = axis_stride[0] / sizeof(dtype_{x});
npy_intp sm_axis_stride = axis_stride[1] / sizeof(dtype_{sm});
if (NpyIter_RemoveAxis(iter, axis) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to remove LogSoftmax axis from iterator");
{fail}
}}
// Iterate over remaining axes
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain LogSoftmax GetIterNext");
{fail}
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
do
{{
dtype_{x}* x_axis = (dtype_{x}*)data_ptr[0];
dtype_{sm}* sm_axis = (dtype_{sm}*)data_ptr[1];
// Find axis max
dtype_{x} max = x_axis[0];
for (npy_intp i = 1; i < axis_size; i++)
{{
dtype_{x} x_val = x_axis[i * x_axis_stride];
max = (x_val > max)? x_val : max;
}}
// Compute xdev and sum(exp(xdev))
dtype_{sm} sum_exp_xdev = 0.0;
for (npy_intp i = 0; i < axis_size; i++)
{{
sm_axis[i * sm_axis_stride] = (dtype_{x})(x_axis[i * x_axis_stride] - max);
sum_exp_xdev += exp(sm_axis[i * sm_axis_stride]);
}}
// Subtract log(sum(exp(xdev))
dtype_{sm} log_sum_exp_xdev = log(sum_exp_xdev);
for (npy_intp i = 0; i < axis_size; i++)
{{
sm_axis[i * sm_axis_stride] -= log_sum_exp_xdev;
}}
}} while(get_next(iter));
}}
NpyIter_Deallocate(iter);
"""
)
@staticmethod
def c_code_cache_version():
return (1,)
def logsoftmax(c, axis=UNSET_AXIS):
if axis is UNSET_AXIS:
warnings.warn(
"logsoftmax now accepts an axis argument. For backwards-compatibility it defaults to -1 when not specified, "
"but in the future the default will be `None`.\nTo suppress this warning specify axis explicitly.",
FutureWarning,
)
axis = -1
c = as_tensor_variable(c)
if c.ndim == 1:
# TODO: Create Specific warning type that can be suppressed?
warnings.warn(
"Softmax no longer converts a vector to a row matrix.",
UserWarning,
)
return LogSoftmax(axis=axis)(c)
log_softmax = logsoftmax # scipy name
__all__ = [
"max_and_argmax",
"max",
......@@ -3878,18 +3116,11 @@ __all__ = [
"power",
"logaddexp",
"logsumexp",
"softmax",
"log_softmax",
]
DEPRECATED_NAMES = [
("abs_", "`abs_` is deprecated; use `abs` instead.", abs),
("inv", "`inv` is deprecated; use `reciprocal` instead.", reciprocal),
(
"logsoftmax",
"`logsoftmax` is deprecated; use `reciprocal` instead.",
log_softmax,
),
]
......
......@@ -24,10 +24,7 @@ from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.extra_ops import Unique
from aesara.tensor.math import (
LogSoftmax,
MaxAndArgmax,
Softmax,
SoftmaxGrad,
Sum,
add,
dot,
......@@ -35,13 +32,11 @@ from aesara.tensor.math import (
exp,
expm1,
log,
log_softmax,
max_and_argmax,
mul,
neg,
or_,
sigmoid,
softmax,
softplus,
)
from aesara.tensor.math import sum as at_sum
......@@ -54,15 +49,9 @@ from aesara.tensor.rewriting.basic import (
)
from aesara.tensor.rewriting.math import local_mul_canonizer
from aesara.tensor.shape import Shape, shape_padleft
from aesara.tensor.special import Softmax, SoftmaxGrad, log_softmax, softmax
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor
from aesara.tensor.type import (
TensorType,
discrete_dtypes,
float_dtypes,
integer_dtypes,
values_eq_approx_remove_inf,
values_eq_approx_remove_nan,
)
from aesara.tensor.type import TensorType, discrete_dtypes, float_dtypes, integer_dtypes
class SoftmaxWithBias(COp):
......@@ -327,71 +316,6 @@ softmax_grad_legacy = SoftmaxGrad(axis=-1)
softmax_legacy = Softmax(axis=-1)
# This is not registered in stabilize, as it cause some crossentropy
# optimization to not be inserted.
@register_specialize("stabilize", "fast_compile")
@node_rewriter([Elemwise])
def local_logsoftmax(fgraph, node):
"""
Detect Log(Softmax(x)) and replace it with LogSoftmax(x)
Note: only forward pass is affected
"""
if (
isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, aes.Log)
and len(node.inputs) == 1
and node.inputs[0].owner is not None
and isinstance(node.inputs[0].owner.op, Softmax)
):
inVars = node.inputs[0].owner.inputs[0]
new_op = LogSoftmax(axis=node.inputs[0].owner.op.axis)
ret = new_op(inVars)
ret.tag.values_eq_approx = values_eq_approx_remove_inf
copy_stack_trace([node.inputs[0], node.outputs[0]], ret)
return [ret]
# This is not registered in stabilize, as it cause some crossentropy
# optimization to not be inserted.
@register_specialize("stabilize", "fast_compile")
@node_rewriter([SoftmaxGrad])
def local_logsoftmax_grad(fgraph, node):
"""
Detect Log(Softmax(x))'s grad and replace it with LogSoftmax(x)'s grad
Note: only grad is affected
"""
if (
isinstance(node.op, SoftmaxGrad)
and len(node.inputs) == 2
and node.inputs[0].owner is not None
and node.inputs[0].owner.op == true_div
and len(node.inputs[0].owner.inputs) >= 2
and node.inputs[0].owner.inputs[1].owner is not None
and isinstance(node.inputs[0].owner.inputs[1].owner.op, Softmax)
and node.inputs[1] == node.inputs[0].owner.inputs[1]
and not (
# skip if it will be optimized by
# local_advanced_indexing_crossentropy_onehot_grad
node.inputs[0].owner.op == true_div
and node.inputs[0].owner.inputs[0].owner is not None
and isinstance(
node.inputs[0].owner.inputs[0].owner.op, AdvancedIncSubtensor
)
# the rewrite only applies to legacy SoftmaxGrad
and node.op == softmax_grad_legacy
and node.inputs[0].owner.inputs[1].ndim == 2
)
):
# get parameters from unoptimized op
grads, sm = node.inputs[0].owner.inputs
ret = grads - at_sum(grads, axis=sm.owner.op.axis, keepdims=True) * sm
ret.tag.values_eq_approx = values_eq_approx_remove_nan
copy_stack_trace(node.outputs[0], ret)
return [ret]
@register_specialize("fast_compile")
@node_rewriter([softmax_legacy])
def local_softmax_with_bias(fgraph, node):
......@@ -2211,12 +2135,12 @@ def confusion_matrix(actual, pred):
DEPRECATED_NAMES = [
(
"softmax",
"`aesara.tensor.nnet.basic.softmax` has been moved to `aesara.tensor.math.softmax`.",
"`aesara.tensor.nnet.basic.softmax` has been moved to `aesara.tensor.special.softmax`.",
softmax,
),
(
"logsoftmax",
"`aesara.tensor.nnet.basic.logsoftmax` has been moved to `aesara.tensor.math.logsoftmax`.",
"`aesara.tensor.nnet.basic.logsoftmax` has been moved to `aesara.tensor.special.log_softmax`.",
log_softmax,
),
]
......
......@@ -3,5 +3,6 @@ import aesara.tensor.rewriting.elemwise
import aesara.tensor.rewriting.extra_ops
import aesara.tensor.rewriting.math
import aesara.tensor.rewriting.shape
import aesara.tensor.rewriting.special
import aesara.tensor.rewriting.subtensor
import aesara.tensor.rewriting.uncanonicalize
from aesara.tensor.rewriting.basic import (
register_specialize,
)
from aesara import scalar as aes
from aesara.tensor.math import true_div, exp, Sum
from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from aesara.tensor.rewriting.math import local_mul_canonizer
from aesara.graph.rewriting.basic import node_rewriter, copy_stack_trace
from aesara.tensor.subtensor import AdvancedIncSubtensor
from aesara.tensor.elemwise import Elemwise, DimShuffle
# This is not registered in stabilize, as it cause some crossentropy
# optimization to not be inserted.
@register_specialize("stabilize", "fast_compile")
@node_rewriter([Elemwise])
def local_logsoftmax(fgraph, node):
"""
Detect Log(Softmax(x)) and replace it with LogSoftmax(x)
Note: only forward pass is affected
"""
if (
isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, aes.Log)
and len(node.inputs) == 1
and node.inputs[0].owner is not None
and isinstance(node.inputs[0].owner.op, Softmax)
):
inVars = node.inputs[0].owner.inputs[0]
new_op = LogSoftmax(axis=node.inputs[0].owner.op.axis)
ret = new_op(inVars)
ret.tag.values_eq_approx = values_eq_approx_remove_inf
copy_stack_trace([node.inputs[0], node.outputs[0]], ret)
return [ret]
# This is not registered in stabilize, as it cause some crossentropy
# optimization to not be inserted.
@register_specialize("stabilize", "fast_compile")
@node_rewriter([SoftmaxGrad])
def local_logsoftmax_grad(fgraph, node):
"""
Detect Log(Softmax(x))'s grad and replace it with LogSoftmax(x)'s grad
Note: only grad is affected
"""
if (
isinstance(node.op, SoftmaxGrad)
and len(node.inputs) == 2
and node.inputs[0].owner is not None
and node.inputs[0].owner.op == true_div
and len(node.inputs[0].owner.inputs) >= 2
and node.inputs[0].owner.inputs[1].owner is not None
and isinstance(node.inputs[0].owner.inputs[1].owner.op, Softmax)
and node.inputs[1] == node.inputs[0].owner.inputs[1]
and not (
# skip if it will be optimized by
# local_advanced_indexing_crossentropy_onehot_grad
node.inputs[0].owner.op == true_div
and node.inputs[0].owner.inputs[0].owner is not None
and isinstance(
node.inputs[0].owner.inputs[0].owner.op, AdvancedIncSubtensor
)
# the rewrite only applies to legacy SoftmaxGrad
and node.op == softmax_grad_legacy
and node.inputs[0].owner.inputs[1].ndim == 2
)
):
# get parameters from unoptimized op
grads, sm = node.inputs[0].owner.inputs
ret = grads - at_sum(grads, axis=sm.owner.op.axis, keepdims=True) * sm
ret.tag.values_eq_approx = values_eq_approx_remove_nan
copy_stack_trace(node.outputs[0], ret)
return [ret]
def softmax_simplifier(numerators, denominators):
for numerator in list(numerators):
if not numerator.type.dtype.startswith("float"):
continue
if not (numerator.owner and numerator.owner.op == exp):
continue
matching_denom = None
for denominator in denominators:
# Division with dimshuffle
if denominator.owner and isinstance(denominator.owner.op, DimShuffle):
ds_order = denominator.owner.op.new_order
# Check that at most only one dimension is being reintroduced by
# a dimshuffle. The cases where all dimensions are reintroduced
# after a complete sum reduction end up in the else branch
if ds_order.count("x") != 1:
continue
# Check that dimshuffle does not change order of original dims
ds_order_without_x = tuple(dim for dim in ds_order if dim != "x")
if tuple(sorted(ds_order_without_x)) != ds_order_without_x:
continue
new_dim = ds_order.index("x")
z = denominator.owner.inputs[0]
if z.owner and isinstance(z.owner.op, Sum):
sum_axis = z.owner.op.axis
# Check that reintroduced dim was the one reduced
if (
(sum_axis is not None)
and (len(sum_axis) == 1)
and (sum_axis[0] == new_dim)
):
if z.owner.inputs[0] is numerator:
(sum_axis,) = sum_axis
matching_denom = denominator
break
# Division without dimshuffle
else:
z = denominator
if z.owner and isinstance(z.owner.op, Sum):
sum_axis = z.owner.op.axis
# Filter out partial summations over more than one axis
# The cases where all axis of summation are explicitly given
# as in `sum(matrix, axis=(0, 1))` are eventually rewritten
# to `sum(matrix)` and this branch is not a blocker
if sum_axis is not None and len(sum_axis) != 1:
continue
if z.owner.inputs[0] is numerator:
if sum_axis is not None:
(sum_axis,) = sum_axis
matching_denom = denominator
break
if matching_denom:
softmax = Softmax(axis=sum_axis)(numerator.owner.inputs[0])
copy_stack_trace(numerator, softmax)
numerators.remove(numerator)
denominators.remove(matching_denom)
numerators.append(softmax)
return numerators, denominators
local_mul_canonizer.add_simplifier(softmax_simplifier, "softmax_simplifier")
import warnings
from textwrap import dedent
import numpy as np
import scipy
from aesara.graph.basic import Apply
from aesara.link.c.op import COp
from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.math import neg, sum
class SoftmaxGrad(COp):
"""
Gradient wrt x of the Softmax Op.
"""
nin = 2
nout = 1
__props__ = ("axis",)
def __init__(self, axis):
if axis is not None and not isinstance(axis, int):
raise TypeError("axis must be an integer or `None`")
self.axis = axis
def make_node(self, dy, sm):
dy = as_tensor_variable(dy)
sm = as_tensor_variable(sm)
if self.axis is not None and (self.axis >= sm.ndim or self.axis < -sm.ndim):
raise ValueError(
f"SoftmaxGrad axis(={self.axis}) out of bounds for {sm.ndim}D array {sm}"
)
return Apply(self, [dy, sm], [sm.type()])
def perform(self, node, input_storage, output_storage):
dy, sm = input_storage
dy_times_sm = dy * sm
dx = dy_times_sm - np.sum(dy_times_sm, axis=self.axis, keepdims=True) * sm
output_storage[0][0] = dx
def grad(self, inp, grads):
dy, sm = inp
(g,) = grads
tmp = g + neg(sum(g * sm, axis=self.axis, keepdims=True))
g_dy = tmp * sm
tmp2 = sum(dy * sm, axis=self.axis, keepdims=True)
g_sm = tmp * dy - g * tmp2
return g_dy, g_sm
def infer_shape(self, fgraph, node, shape):
return [shape[1]]
def c_code_cache_version(self):
return (4,)
def c_code(self, node, name, inp, out, sub):
dy, sm = inp
(dx,) = out
axis = self.axis if self.axis is not None else np.MAXDIMS
fail = sub["fail"]
return dedent(
f"""
PyArrayObject* op[3];
npy_uint32 op_flags[3];
npy_uint32 iter_flags;
NpyIter* iter;
NpyIter_IterNextFunc* get_next;
char** data_ptr;
int sm_ndim = PyArray_NDIM({sm});
int axis = {axis};
int iterate_axis = !(axis == NPY_MAXDIMS || sm_ndim == 1);
// Validate inputs
if ((PyArray_TYPE({dy}) != NPY_DOUBLE) &&
(PyArray_TYPE({dy}) != NPY_FLOAT))
{{
PyErr_SetString(PyExc_TypeError, "types should be float or float64");
{fail};
}}
if ((PyArray_TYPE({sm}) != NPY_DOUBLE) &&
(PyArray_TYPE({sm}) != NPY_FLOAT))
{{
PyErr_SetString(PyExc_TypeError, "types should be float or float64");
{fail};
}}
if (axis < 0) axis = sm_ndim + axis;
if ((axis < 0) || (iterate_axis && (axis > sm_ndim)))
{{
PyErr_SetString(PyExc_ValueError, "invalid axis in SoftmaxGrad");
{fail};
}}
if (({dx} == NULL)
|| !(PyArray_CompareLists(PyArray_DIMS({dx}), PyArray_DIMS({sm}), sm_ndim)))
{{
Py_XDECREF({dx});
{dx} = (PyArrayObject*)PyArray_SimpleNew(sm_ndim,
PyArray_DIMS({sm}),
PyArray_TYPE({sm}));
if (!{dx})
{{
PyErr_SetString(PyExc_MemoryError, "failed to alloc SoftMaxGrad dx output");
{fail};
}}
}}
// Create numpy iterator
op[0] = {dy};
op[1] = {sm};
op[2] = {dx};
op_flags[0] = NPY_ITER_READONLY;
op_flags[1] = NPY_ITER_READONLY;
op_flags[2] = NPY_ITER_READWRITE;
iter_flags = (iterate_axis)? NPY_ITER_MULTI_INDEX : 0;
iter = NpyIter_MultiNew(
3,
op,
iter_flags,
NPY_KEEPORDER,
NPY_NO_CASTING,
op_flags,
NULL
);
if (iter == NULL)
{{
PyErr_SetString(PyExc_MemoryError, "failed to create softmax iterator");
{fail};
}}
// SoftmaxGrad is applied across the entire array
if (!iterate_axis)
{{
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain SoftMaxGrad GetIterNext");
{fail};
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
// Compute and accumulate dy * sm
dtype_{dx} sum_dy_times_sm = 0.0;
do
{{
dtype_{dy}* dy_ptr = (dtype_{dy}*)data_ptr[0];
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
dtype_{dx}* dx_ptr = (dtype_{dx}*)data_ptr[2];
*dx_ptr = (dtype_{dx})((*dy_ptr) * (*sm_ptr));
sum_dy_times_sm += *dx_ptr;
}} while(get_next(iter));
// Reset Iterator
if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to reset softmax iterator");
{fail};
}}
// Subtract sum(dy*sm) * sm
do
{{
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
dtype_{dx}* dx_ptr = (dtype_{dx}*)data_ptr[2];
*dx_ptr -= sum_dy_times_sm * ((dtype_{dx})(*sm_ptr));
}} while(get_next(iter));
}}
// SoftmaxGrad is applied across a specific axis
else {{
// Collect axis strides and remove it from iteration
npy_intp axis_size = PyArray_DIM({sm}, axis);
npy_intp* axis_stride = NpyIter_GetAxisStrideArray(iter, axis);
if (axis_stride == NULL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain softmax axis strides");
{fail};
}}
npy_intp dy_axis_stride = axis_stride[0] / sizeof(dtype_{dy});
npy_intp sm_axis_stride = axis_stride[1] / sizeof(dtype_{sm});
npy_intp dx_axis_stride = axis_stride[2] / sizeof(dtype_{dx});
if (NpyIter_RemoveAxis(iter, axis) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to remove SoftmaxGrad axis from iterator");
{fail};
}}
// Iterate over remaining axes
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain SoftamGrad GetIterNext");
{fail};
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
do
{{
dtype_{dy}* dy_axis = (dtype_{dy}*)data_ptr[0];
dtype_{sm}* sm_axis = (dtype_{sm}*)data_ptr[1];
dtype_{dx}* dx_axis = (dtype_{dx}*)data_ptr[2];
// Compute and accumulate dy * sm
dtype_{dx} sum_dy_times_sm = 0.0;
for (npy_intp i = 0; i < axis_size; i++)
{{
dx_axis[i * dx_axis_stride] = (dtype_{dx})(dy_axis[i * dy_axis_stride] * sm_axis[i * sm_axis_stride]);
sum_dy_times_sm += dx_axis[i * dx_axis_stride];
}}
// Subtract sum(dy*sm) * sm
for (npy_intp i = 0; i < axis_size; i++)
{{
dx_axis[i * dx_axis_stride] -= sum_dy_times_sm * (dtype_{dx})(sm_axis[i * sm_axis_stride]);
}}
}} while(get_next(iter));
}}
NpyIter_Deallocate(iter);
"""
)
class Softmax(COp):
r"""
Softmax activation function
:math:`\\varphi(\\mathbf{x})_j =
\\frac{e^{\mathbf{x}_j}}{\sum_{k=1}^K e^{\mathbf{x}_k}}`
where :math:`K` is the total number of neurons in the layer. This
activation function gets applied row-wise.
"""
nin = 1
nout = 1
__props__ = ("axis",)
def __init__(self, axis):
if axis is not None and not isinstance(axis, int):
raise TypeError("axis must be an integer or `None`")
self.axis = axis
def make_node(self, x):
x = as_tensor_variable(x)
if self.axis is not None and (self.axis >= x.ndim or self.axis < -x.ndim):
raise ValueError(
f"Softmax axis(={self.axis}) out of bounds for {x.ndim}D array {x}"
)
return Apply(self, [x], [x.type()])
def perform(self, node, input_storage, output_storage):
(x,) = input_storage
(z,) = output_storage
z[0] = scipy.special.softmax(x, axis=self.axis)
def L_op(self, inp, outputs, grads):
(x,) = inp
(g_sm,) = grads
return [SoftmaxGrad(axis=self.axis)(g_sm, outputs[0])]
def R_op(self, inputs, eval_points):
# I think the Jacobian is symmetric so the R_op
# is the same as the grad
if None in eval_points:
return [None]
return self.L_op(inputs, [self(*inputs)], eval_points)
def infer_shape(self, fgraph, node, shape):
return shape
def c_headers(self, **kwargs):
return ["<iostream>", "<cmath>"]
def c_code(self, node, name, inp, out, sub):
(x,) = inp
(sm,) = out
axis = self.axis if self.axis is not None else np.MAXDIMS
fail = sub["fail"]
# dtype = node.inputs[0].type.dtype_specs()[1]
# TODO: put this into a templated function, in the support code
# TODO: declare the max of each row as an Op output
# TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1]
return dedent(
f"""
PyArrayObject* op[2];
npy_uint32 op_flags[2];
npy_uint32 iter_flags;
NpyIter* iter;
NpyIter_IterNextFunc* get_next;
char** data_ptr;
int x_ndim = PyArray_NDIM({x});
int axis = {axis};
int iterate_axis = !(axis == NPY_MAXDIMS || x_ndim == 1);
// Validate inputs
if ((PyArray_TYPE({x}) != NPY_DOUBLE) &&
(PyArray_TYPE({x}) != NPY_FLOAT))
{{
PyErr_SetString(PyExc_TypeError, "not a float");
{fail}
}}
if (axis < 0) axis = x_ndim + axis;
if ((axis < 0) || (iterate_axis && (axis > x_ndim)))
{{
PyErr_SetString(PyExc_ValueError, "invalid axis in Softmax");
{fail}
}}
// Allocate Output Array
if (({sm}) == NULL || !(PyArray_CompareLists(PyArray_DIMS({sm}), PyArray_DIMS({x}), x_ndim)))
{{
Py_XDECREF({sm});
{sm} = (PyArrayObject*)PyArray_SimpleNew(x_ndim, PyArray_DIMS({x}), PyArray_TYPE({x}));
if(!{sm}) {{
PyErr_SetString(PyExc_MemoryError, "failed to alloc Softmax output");
{fail}
}}
}}
// Create numpy iterator
op[0] = {x};
op[1] = {sm};
op_flags[0] = NPY_ITER_READONLY;
op_flags[1] = NPY_ITER_READWRITE;
iter_flags = (iterate_axis)? NPY_ITER_MULTI_INDEX : 0;
iter = NpyIter_MultiNew(
2,
op,
iter_flags,
NPY_KEEPORDER,
NPY_NO_CASTING,
op_flags,
NULL
);
if (iter == NULL)
{{
PyErr_SetString(PyExc_MemoryError, "failed to create Softmax iterator");
{fail}
}}
// Softmax is applied across the entire array
if (!iterate_axis)
{{
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain Softmax GetIterNext");
{fail}
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
// Find axis max
dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
dtype_{x} max = *x_ptr;
if (get_next(iter))
{{
do
{{
dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
max = (*x_ptr > max)? *x_ptr : max;
}} while(get_next(iter));
}}
// Reset Iterator
if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to reset Softmax iterator");
{fail}
}}
// Compute and accumulate exp(x-max(x)) exponent
double sum_exp_dev = 0.0;
do
{{
dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
*sm_ptr = (dtype_{sm}) exp(*x_ptr - max);
sum_exp_dev += *sm_ptr;
}} while(get_next(iter));
// Reset Iterator
if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to reset Softmax iterator");
{fail}
}}
// Divide by sum(exp(x-max(x)))
double inv_sum_exp_dev = 1.0 / sum_exp_dev;
do
{{
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
*sm_ptr *= inv_sum_exp_dev;
}} while(get_next(iter));
}}
// Softmax is applied across a specific axis
else {{
// Collect axis strides and remove it from iteration
npy_intp axis_size = PyArray_DIM({x}, axis);
npy_intp* axis_stride = NpyIter_GetAxisStrideArray(iter, axis);
if (axis_stride == NULL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain Softmax axis strides");
{fail}
}}
npy_intp x_axis_stride = axis_stride[0] / sizeof(dtype_{x});
npy_intp sm_axis_stride = axis_stride[1] / sizeof(dtype_{sm});
if (NpyIter_RemoveAxis(iter, axis) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to remove softmax axis from iterator");
{fail}
}}
// Iterate over remaining axes
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain softmax GetIterNext");
{fail}
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
do
{{
dtype_{x}* x_axis = (dtype_{x}*)data_ptr[0];
dtype_{sm}* sm_axis = (dtype_{sm}*)data_ptr[1];
// Find axis max
dtype_{x} max = x_axis[0];
for (npy_intp i = 1; i < axis_size; i++)
{{
dtype_{x} x_val = x_axis[i * x_axis_stride];
max = (x_val > max)? x_val : max;
}}
// Compute and accumulate exp(x-max(x)) exponent
dtype_{sm} sum_exp_dev = 0.0;
for (npy_intp i = 0; i < axis_size; i++)
{{
sm_axis[i * sm_axis_stride] = (dtype_{sm}) exp(x_axis[i * x_axis_stride] - max);
sum_exp_dev += sm_axis[i * sm_axis_stride];
}}
// Divide by sum(exp(x-max(x)))
dtype_{sm} inv_sum_exp_dev = 1.0 / sum_exp_dev;
for (npy_intp i = 0; i < axis_size; i++)
{{
sm_axis[i * sm_axis_stride] *= inv_sum_exp_dev;
}}
}} while(get_next(iter));
}}
NpyIter_Deallocate(iter);
"""
)
@staticmethod
def c_code_cache_version():
return (4,)
UNSET_AXIS = object()
def softmax(c, axis=UNSET_AXIS):
if axis is UNSET_AXIS:
warnings.warn(
"Softmax now accepts an axis argument. For backwards-compatibility it defaults to -1 when not specified, "
"but in the future the default will be `None`.\nTo suppress this warning specify axis explicitly.",
FutureWarning,
)
axis = -1
c = as_tensor_variable(c)
if c.ndim == 1:
# TODO: Create Specific warning type that can be suppressed?
warnings.warn(
"Softmax no longer converts a vector to a row matrix.",
UserWarning,
)
return Softmax(axis=axis)(c)
class LogSoftmax(COp):
r"""
LogSoftmax activation function
:math:`\\varphi(\\mathbf{x})_j =
\\e^{(\mathbf{x}_j - log{\sum_{k=1}^K e^{\mathbf{x}_k})}}
where :math:`K` is the total number of neurons in the layer. This
activation function gets applied row-wise.
"""
nin = 1
nout = 1
__props__ = ("axis",)
def __init__(self, axis):
if axis is not None and not isinstance(axis, int):
raise TypeError("axis must be an integer or `None`")
self.axis = axis
def make_node(self, x):
x = as_tensor_variable(x)
if self.axis is not None and (self.axis >= x.ndim or self.axis < -x.ndim):
raise ValueError(
f"LogSoftmax axis(={self.axis}) out of bounds for {x.ndim}D array {x}"
)
return Apply(self, [x], [x.type()])
def perform(self, node, input_storage, output_storage):
(x,) = input_storage
(z,) = output_storage
z[0] = scipy.special.log_softmax(x, axis=self.axis)
def grad(self, inp, grads):
(x,) = inp
sm = Softmax(axis=self.axis)(x)
return [grads[0] - sum(grads[0], axis=self.axis, keepdims=True) * sm]
def R_op(self, inputs, eval_points):
# I think the Jacobian is symmetric so the R_op
# is the same as the grad
if None in eval_points:
return [None]
return self.grad(inputs, eval_points)
def infer_shape(self, fgraph, node, shape):
return shape
def c_headers(self, **kwargs):
return ["<cmath>"]
def c_code(self, node, name, inp, out, sub):
(x,) = inp
(sm,) = out
axis = self.axis if self.axis is not None else np.MAXDIMS
fail = sub["fail"]
return dedent(
f"""
PyArrayObject* op[2];
npy_uint32 op_flags[2];
npy_uint32 iter_flags;
NpyIter* iter;
NpyIter_IterNextFunc* get_next;
char** data_ptr;
int x_ndim = PyArray_NDIM({x});
int axis = {axis};
int iterate_axis = !(axis == NPY_MAXDIMS || x_ndim == 1);
// Validate inputs
if ((PyArray_TYPE({x}) != NPY_DOUBLE) &&
(PyArray_TYPE({x}) != NPY_FLOAT))
{{
PyErr_SetString(PyExc_TypeError, "not a float");
{fail}
}}
if (axis < 0) axis = x_ndim + axis;
if ((axis < 0) || (iterate_axis && (axis > x_ndim)))
{{
PyErr_SetString(PyExc_ValueError, "invalid axis in LogSoftmax");
{fail}
}}
// Allocate Output Array
if (({sm}) == NULL || !(PyArray_CompareLists(PyArray_DIMS({sm}), PyArray_DIMS({x}), x_ndim)))
{{
Py_XDECREF({sm});
{sm} = (PyArrayObject*)PyArray_SimpleNew(x_ndim, PyArray_DIMS({x}), PyArray_TYPE({x}));
if(!{sm}) {{
PyErr_SetString(PyExc_MemoryError, "failed to alloc LogSoftmax output");
{fail}
}}
}}
// Create numpy iterator
op[0] = {x};
op[1] = {sm};
op_flags[0] = NPY_ITER_READONLY;
op_flags[1] = NPY_ITER_READWRITE;
iter_flags = (iterate_axis)? NPY_ITER_MULTI_INDEX : 0;
iter = NpyIter_MultiNew(
2,
op,
iter_flags,
NPY_KEEPORDER,
NPY_NO_CASTING,
op_flags,
NULL
);
if (iter == NULL)
{{
PyErr_SetString(PyExc_MemoryError, "failed to create LogSoftmax iterator");
{fail}
}}
// LogSoftmax is applied across the entire array
if (!iterate_axis)
{{
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain LogSoftmax GetIterNext");
{fail}
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
// Find axis max
dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
dtype_{x} max = *x_ptr;
if (get_next(iter))
{{
do
{{
dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
max = (*x_ptr > max)? *x_ptr : max;
}} while(get_next(iter));
}}
// Reset Iterator
if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to reset LogSoftmax iterator");
{fail}
}}
// Compute xdev and sum(exp(xdev))
dtype_{sm} sum_exp_xdev = 0.0;
do
{{
dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
*sm_ptr = (dtype_{sm})((*x_ptr) - max);
sum_exp_xdev += exp(*sm_ptr);
}} while(get_next(iter));
// Reset Iterator
if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to reset LogSoftmax iterator");
{fail}
}}
// Subtract log(sum(exp(xdev)))
dtype_{sm} log_sum_exp_xdev = log(sum_exp_xdev);
do
{{
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
*sm_ptr -= log_sum_exp_xdev;
}} while(get_next(iter));
}}
// LogSoftmax is applied across a specific axis
else {{
// Collect axis strides and remove it from iteration
npy_intp axis_size = PyArray_DIM({x}, axis);
npy_intp* axis_stride = NpyIter_GetAxisStrideArray(iter, axis);
if (axis_stride == NULL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain LogSoftmax axis strides");
{fail}
}}
npy_intp x_axis_stride = axis_stride[0] / sizeof(dtype_{x});
npy_intp sm_axis_stride = axis_stride[1] / sizeof(dtype_{sm});
if (NpyIter_RemoveAxis(iter, axis) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to remove LogSoftmax axis from iterator");
{fail}
}}
// Iterate over remaining axes
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain LogSoftmax GetIterNext");
{fail}
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
do
{{
dtype_{x}* x_axis = (dtype_{x}*)data_ptr[0];
dtype_{sm}* sm_axis = (dtype_{sm}*)data_ptr[1];
// Find axis max
dtype_{x} max = x_axis[0];
for (npy_intp i = 1; i < axis_size; i++)
{{
dtype_{x} x_val = x_axis[i * x_axis_stride];
max = (x_val > max)? x_val : max;
}}
// Compute xdev and sum(exp(xdev))
dtype_{sm} sum_exp_xdev = 0.0;
for (npy_intp i = 0; i < axis_size; i++)
{{
sm_axis[i * sm_axis_stride] = (dtype_{x})(x_axis[i * x_axis_stride] - max);
sum_exp_xdev += exp(sm_axis[i * sm_axis_stride]);
}}
// Subtract log(sum(exp(xdev))
dtype_{sm} log_sum_exp_xdev = log(sum_exp_xdev);
for (npy_intp i = 0; i < axis_size; i++)
{{
sm_axis[i * sm_axis_stride] -= log_sum_exp_xdev;
}}
}} while(get_next(iter));
}}
NpyIter_Deallocate(iter);
"""
)
@staticmethod
def c_code_cache_version():
return (1,)
def log_softmax(c, axis=UNSET_AXIS):
if axis is UNSET_AXIS:
warnings.warn(
"logsoftmax now accepts an axis argument. For backwards-compatibility it defaults to -1 when not specified, "
"but in the future the default will be `None`.\nTo suppress this warning specify axis explicitly.",
FutureWarning,
)
axis = -1
c = as_tensor_variable(c)
if c.ndim == 1:
# TODO: Create Specific warning type that can be suppressed?
warnings.warn(
"Softmax no longer converts a vector to a row matrix.",
UserWarning,
)
return LogSoftmax(axis=axis)(c)
__all__ = [
"softmax",
"log_softmax",
]
......@@ -104,7 +104,7 @@
"\n",
"wy = th.shared(rng.normal(0, 1, (nhiddens, noutputs)))\n",
"by = th.shared(np.zeros(noutputs), borrow=True)\n",
"y = at.math.softmax(at.dot(h, wy) + by)\n",
"y = at.special.softmax(at.dot(h, wy) + by)\n",
"\n",
"predict = th.function([x], y)"
]
......
......@@ -67,7 +67,7 @@ hidden layer and a softmax output layer.
wy = th.shared(rng.normal(0, 1, (nhiddens, noutputs)))
by = th.shared(np.zeros(noutputs), borrow=True)
y = at.math.softmax(at.dot(h, wy) + by)
y = at.special.softmax(at.dot(h, wy) + by)
predict = th.function([x], y)
......
......@@ -3,6 +3,7 @@ import numpy as np
import aesara.tensor as at
from aesara import shared
from aesara.compile.builders import OpFromGraph
from aesara.tensor.special import softmax
from aesara.tensor.type import dmatrix, scalars
......@@ -24,8 +25,7 @@ class Mlp:
wy = shared(self.rng.normal(0, 1, (nhiddens, noutputs)))
by = shared(np.zeros(noutputs), borrow=True)
y = at.softmax(at.dot(h, wy) + by)
y = softmax(at.dot(h, wy) + by)
self.inputs = [x]
self.outputs = [y]
......
......@@ -5,10 +5,10 @@ from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value
from aesara.tensor import elemwise as at_elemwise
from aesara.tensor.math import SoftmaxGrad
from aesara.tensor.math import all as at_all
from aesara.tensor.math import log_softmax, prod, softmax
from aesara.tensor.math import prod
from aesara.tensor.math import sum as at_sum
from aesara.tensor.special import SoftmaxGrad, log_softmax, softmax
from aesara.tensor.type import matrix, tensor, vector
from tests.link.jax.test_basic import compare_jax_and_py
......
......@@ -12,19 +12,8 @@ from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.basic import Constant
from aesara.graph.fg import FunctionGraph
from aesara.tensor import elemwise as at_elemwise
from aesara.tensor.math import (
All,
Any,
LogSoftmax,
Max,
Mean,
Min,
Prod,
ProdWithoutZeros,
Softmax,
SoftmaxGrad,
Sum,
)
from aesara.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum
from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from tests.link.numba.test_basic import (
compare_numba_and_py,
my_multi_out,
......
......@@ -13,7 +13,7 @@ from aesara import pprint, shared
from aesara.compile import optdb
from aesara.compile.debugmode import DebugMode
from aesara.compile.function import function
from aesara.compile.mode import OPT_FAST_RUN, Mode, get_default_mode, get_mode
from aesara.compile.mode import Mode, get_default_mode, get_mode
from aesara.compile.ops import DeepCopyOp, deep_copy_op
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, equal_computations
......@@ -33,15 +33,7 @@ from aesara.tensor.basic import Alloc, join, switch
from aesara.tensor.blas import Dot22, Gemv
from aesara.tensor.blas_c import CGemv
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import (
Dot,
LogSoftmax,
MaxAndArgmax,
Prod,
SoftmaxGrad,
Sum,
_conj,
)
from aesara.tensor.math import Dot, MaxAndArgmax, Prod, Sum, _conj
from aesara.tensor.math import abs as at_abs
from aesara.tensor.math import add
from aesara.tensor.math import all as at_all
......@@ -84,17 +76,7 @@ from aesara.tensor.math import minimum, mul, neg, neq
from aesara.tensor.math import pow as at_pow
from aesara.tensor.math import prod, rad2deg, reciprocal
from aesara.tensor.math import round as at_round
from aesara.tensor.math import (
sgn,
sigmoid,
sin,
sinh,
softmax,
softplus,
sqr,
sqrt,
sub,
)
from aesara.tensor.math import sgn, sigmoid, sin, sinh, softplus, sqr, sqrt, sub
from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tan, tanh, true_div, xor
from aesara.tensor.rewriting.elemwise import local_dimshuffle_lift
......@@ -4596,97 +4578,6 @@ class TestSigmoidUtils:
assert is_1pexp(1 + 2 * exp_op(x), False) is None
class TestLogSoftmaxRewrites:
@pytest.mark.parametrize("axis", [None, 0, -1])
def test_local_logsoftmax_rewrite(self, axis):
"""Test the `Logsoftmax` substitution.
Check that ``Log(Softmax(x))`` is substituted with ``Logsoftmax(x)``. Note that
only the forward pass is checked (i.e., doesn't check the gradient)
"""
x = matrix("x")
sm = softmax(x, axis=axis)
logsm = log(sm)
f = function([x], logsm)
assert isinstance(f.maker.fgraph.outputs[0].owner.op, LogSoftmax)
assert check_stack_trace(f, ops_to_check=LogSoftmax)
@pytest.mark.parametrize("axis", [None, 0, -1])
def test_local_logsoftmax_grad_rewrite(self, axis):
"""Test the `Logsoftmax`'s grad substitution.
Check that ``Log(Softmax(x))``'s grad is substituted with ``Logsoftmax(x)``'s
grad and that the new operation does not explode for big inputs.
Note that only the grad is checked.
"""
m = config.mode
m = get_mode(m)
m.check_isfinite = False
# some inputs that are large to make the gradient explode in the non
# rewritten case
rng = np.random.default_rng(utt.fetch_seed())
a = np.exp(10 * rng.random((5, 10)).astype(config.floatX))
def myfunc(x):
sm = softmax(x, axis=axis)
logsm = log(sm)
return logsm
# We set step to 0.1 because for big values we need a big epsilon
utt.verify_grad(myfunc, [a], eps=0.1, mode=m)
sa = shared(a)
f = function([], myfunc(sa))
assert check_stack_trace(f, ops_to_check="all")
def test_logsoftmax_grad_true_div_elemwise(self):
"""
Checks that the gradient of an expression similar to a ``log(softmax)`` but
with a different elemwise operation than true_div is not rewritten.
"""
x = matrix("x")
y = log(softmax(x))
g = aesara.tensor.grad(y.sum(), x)
softmax_grad_node = g.owner
assert softmax_grad_node.op == SoftmaxGrad(axis=-1)
true_div_node = softmax_grad_node.inputs[0].owner
assert true_div_node.op == true_div
# We replace the elemwise true_div op by an elemwise add.
new_g = SoftmaxGrad(axis=-1)(
add(*true_div_node.inputs), softmax_grad_node.inputs[1]
)
fgraph = FunctionGraph([x], [new_g])
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert SoftmaxGrad(axis=-1) in [n.op for n in fgraph.toposort()]
def test_log1mexp_stabilization():
mode = Mode("py").including("stabilize")
x = vector()
f = function([x], log(1 - exp(x)), mode=mode)
nodes = [node.op for node in f.maker.fgraph.toposort()]
assert nodes == [at.log1mexp]
# Check values that would under or overflow without rewriting
assert f([-(2.0**-55)]) != -np.inf
overflow_value = -500.0 if config.floatX == "float64" else -100.0
assert f([overflow_value]) < 0
# Check values around the switch point np.log(0.5)
assert np.allclose(
f(np.array([-0.8, -0.6], dtype=config.floatX)),
np.log(1 - np.exp([-0.8, -0.6])),
)
def test_local_logit_sigmoid():
"""Test that graphs of the form ``logit(sigmoid(x))`` and ``sigmoid(logit(x))`` get rewritten to ``x``."""
......@@ -4727,24 +4618,3 @@ def test_deprecations():
"""Make sure we can import from deprecated modules."""
with pytest.deprecated_call():
from aesara.tensor.math_opt import AlgebraicCanonizer # noqa: F401 F811
def test_log_softmax_stabilization():
mode = aesara.compile.mode.get_default_mode()
mode = mode.including("local_log_softmax", "specialize")
x = matrix()
y = softmax(x)
z = log(y)
f = aesara.function([x], z, mode=mode)
assert check_stack_trace(f, ops_to_check="all")
# Check that the softmax has been rewritten
for node in f.maker.fgraph.toposort():
assert not isinstance(node.op, y.owner.op.__class__)
# Call the function so debug mode can verify the rewritten version matches
# the un-rewritten version
rng = np.random.default_rng(utt.fetch_seed())
f(np.cast[config.floatX](rng.random((2, 3))))
import numpy as np
import pytest
import aesara
import aesara.tensor as at
from aesara import shared
from aesara.compile import optdb
from aesara.compile.function import function
from aesara.compile.mode import OPT_FAST_RUN, Mode, get_mode
from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.graph.rewriting.basic import check_stack_trace
from aesara.graph.rewriting.db import RewriteDatabaseQuery
from aesara.tensor.math import add, exp, log, true_div
from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad, softmax
from aesara.tensor.type import matrix
from tests import unittest_tools as utt
class TestLogSoftmaxRewrites:
@pytest.mark.parametrize("axis", [None, 0, -1])
def test_local_logsoftmax_rewrite(self, axis):
"""Test the `Logsoftmax` substitution.
Check that ``Log(Softmax(x))`` is substituted with ``Logsoftmax(x)``. Note that
only the forward pass is checked (i.e., doesn't check the gradient)
"""
x = matrix("x")
sm = softmax(x, axis=axis)
logsm = log(sm)
f = function([x], logsm)
assert isinstance(f.maker.fgraph.outputs[0].owner.op, LogSoftmax)
assert check_stack_trace(f, ops_to_check=LogSoftmax)
@pytest.mark.parametrize("axis", [None, 0, -1])
def test_local_logsoftmax_grad_rewrite(self, axis):
"""Test the `Logsoftmax`'s grad substitution.
Check that ``Log(Softmax(x))``'s grad is substituted with ``Logsoftmax(x)``'s
grad and that the new operation does not explode for big inputs.
Note that only the grad is checked.
"""
m = config.mode
m = get_mode(m)
m.check_isfinite = False
# some inputs that are large to make the gradient explode in the non
# rewritten case
rng = np.random.default_rng(utt.fetch_seed())
a = np.exp(10 * rng.random((5, 10)).astype(config.floatX))
def myfunc(x):
sm = softmax(x, axis=axis)
logsm = log(sm)
return logsm
# We set step to 0.1 because for big values we need a big epsilon
utt.verify_grad(myfunc, [a], eps=0.1, mode=m)
sa = shared(a)
f = function([], myfunc(sa))
assert check_stack_trace(f, ops_to_check="all")
def test_logsoftmax_grad_true_div_elemwise(self):
"""
Checks that the gradient of an expression similar to a ``log(softmax)`` but
with a different elemwise operation than true_div is not rewritten.
"""
x = matrix("x")
y = log(softmax(x))
g = aesara.tensor.grad(y.sum(), x)
softmax_grad_node = g.owner
assert softmax_grad_node.op == SoftmaxGrad(axis=-1)
true_div_node = softmax_grad_node.inputs[0].owner
assert true_div_node.op == true_div
# We replace the elemwise true_div op by an elemwise add.
new_g = SoftmaxGrad(axis=-1)(
add(*true_div_node.inputs), softmax_grad_node.inputs[1]
)
fgraph = FunctionGraph([x], [new_g])
optdb.query(OPT_FAST_RUN).rewrite(fgraph)
assert SoftmaxGrad(axis=-1) in [n.op for n in fgraph.toposort()]
def test_log1mexp_stabilization():
mode = Mode("py").including("stabilize")
x = vector()
f = function([x], log(1 - exp(x)), mode=mode)
nodes = [node.op for node in f.maker.fgraph.toposort()]
assert nodes == [at.log1mexp]
# Check values that would under or overflow without rewriting
assert f([-(2.0**-55)]) != -np.inf
overflow_value = -500.0 if config.floatX == "float64" else -100.0
assert f([overflow_value]) < 0
# Check values around the switch point np.log(0.5)
assert np.allclose(
f(np.array([-0.8, -0.6], dtype=config.floatX)),
np.log(1 - np.exp([-0.8, -0.6])),
)
def test_log_softmax_stabilization():
mode = aesara.compile.mode.get_default_mode()
mode = mode.including("local_log_softmax", "specialize")
x = matrix()
y = softmax(x)
z = log(y)
f = aesara.function([x], z, mode=mode)
assert check_stack_trace(f, ops_to_check="all")
# Check that the softmax has been rewritten
for node in f.maker.fgraph.toposort():
assert not isinstance(node.op, y.owner.op.__class__)
# Call the function so debug mode can verify the rewritten version matches
# the un-rewritten version
rng = np.random.default_rng(utt.fetch_seed())
f(np.cast[config.floatX](rng.random((2, 3))))
def test_softmax_graph():
"""Make sure that sotfmax expressions are turned into
a softmax Op.
"""
rng = np.random.default_rng(utt.fetch_seed())
x = aesara.shared(rng.normal(size=(3, 4)))
def softmax_graph(c):
return exp(c) / exp(c).sum(axis=-1, keepdims=True)
def f(inputs):
y = softmax_graph(x)
return aesara.grad(None, x, known_grads={y: inputs})
utt.verify_grad(f, [rng.random((3, 4))])
......@@ -8,9 +8,7 @@ from itertools import product
import numpy as np
import pytest
from numpy.testing import assert_array_equal
from scipy.special import log_softmax as scipy_log_softmax
from scipy.special import logsumexp as scipy_logsumexp
from scipy.special import softmax as scipy_softmax
import aesara.scalar as aes
from aesara.compile.debugmode import DebugMode
......@@ -36,14 +34,11 @@ from aesara.tensor.elemwise import CAReduce, Elemwise
from aesara.tensor.math import (
Argmax,
Dot,
LogSoftmax,
MatMul,
MaxAndArgmax,
Mean,
Prod,
ProdWithoutZeros,
Softmax,
SoftmaxGrad,
Sum,
_allclose,
_dot,
......@@ -84,7 +79,6 @@ from aesara.tensor.math import (
log1p,
log2,
log10,
log_softmax,
logaddexp,
logsumexp,
matmul,
......@@ -110,7 +104,6 @@ from aesara.tensor.math import (
sin,
sinh,
smallest,
softmax,
sqr,
sqrt,
sub,
......@@ -3528,129 +3521,3 @@ class TestMatMul(utt.InferShapeTester):
[x1, x2],
self.op_class,
)
class TestSoftmax(utt.InferShapeTester):
@pytest.mark.parametrize("axis", [None, 0, 1, 2, 3, -1, -2])
def test_perform(self, axis):
x = tensor4("x")
rng = np.random.default_rng(utt.fetch_seed())
xv = rng.standard_normal((2, 3, 4, 5)).astype(config.floatX)
f = function([x], softmax(x, axis=axis))
assert np.allclose(f(xv), scipy_softmax(xv, axis=axis))
@pytest.mark.parametrize("column", [0, 1, 2, 3])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_grad(self, axis, column):
def f(a):
return softmax(a, axis=axis)[:, column]
rng = np.random.default_rng(utt.fetch_seed())
utt.verify_grad(f, [rng.random((3, 4, 2))])
def test_infer_shape(self):
admat = matrix()
rng = np.random.default_rng(utt.fetch_seed())
admat_val = rng.random((3, 4)).astype(config.floatX)
self._compile_and_check(
[admat], [Softmax(axis=-1)(admat)], [admat_val], Softmax
)
def test_vector_perform(self):
x = vector()
f = function([x], softmax(x, axis=None))
rng = np.random.default_rng(utt.fetch_seed())
xv = rng.standard_normal((6,)).astype(config.floatX)
assert np.allclose(f(xv), scipy_softmax(xv))
def test_vector_grad(self):
def f(a):
return softmax(a, axis=None)
rng = np.random.default_rng(utt.fetch_seed())
utt.verify_grad(f, [rng.random((4))])
def test_valid_axis(self):
with pytest.raises(TypeError):
Softmax(1.5)
x = [tensor3()] * LogSoftmax.nin
Softmax(2)(*x)
Softmax(-3)(*x)
with pytest.raises(ValueError):
Softmax(3)(*x)
with pytest.raises(ValueError):
Softmax(-4)(*x)
class TestLogSoftmax(utt.InferShapeTester):
@pytest.mark.parametrize("column", [0, 1, 2, 3])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_matrix_grad(self, axis, column):
def f(a):
return log_softmax(a, axis=axis)[:, column]
rng = np.random.default_rng(utt.fetch_seed())
utt.verify_grad(f, [rng.random((3, 4))])
def test_vector_perform(self):
x = vector()
f = function([x], log_softmax(x, axis=None))
rng = np.random.default_rng(utt.fetch_seed())
xv = rng.standard_normal((6,)).astype(config.floatX)
assert np.allclose(f(xv), scipy_log_softmax(xv))
def test_vector_grad(self):
def f(a):
return log_softmax(a, axis=None)
rng = np.random.default_rng(utt.fetch_seed())
utt.verify_grad(f, [rng.random((4,))])
def test_valid_axis(self):
with pytest.raises(TypeError):
LogSoftmax(1.5)
x = [tensor3()] * LogSoftmax.nin
LogSoftmax(2)(*x)
LogSoftmax(-3)(*x)
with pytest.raises(ValueError):
LogSoftmax(3)(*x)
with pytest.raises(ValueError):
LogSoftmax(-4)(*x)
class TestSoftmaxGrad(utt.InferShapeTester):
def test_infer_shape(self):
admat = matrix()
bdmat = matrix()
rng = np.random.default_rng(utt.fetch_seed())
admat_val = rng.random((3, 4)).astype(config.floatX)
bdmat_val = rng.random((3, 4)).astype(config.floatX)
self._compile_and_check(
[admat, bdmat],
[SoftmaxGrad(axis=-1)(admat, bdmat)],
[admat_val, bdmat_val],
SoftmaxGrad,
)
def test_valid_axis(self):
with pytest.raises(TypeError):
SoftmaxGrad(1.5)
x = [tensor3()] * SoftmaxGrad.nin
SoftmaxGrad(2)(*x)
SoftmaxGrad(-3)(*x)
with pytest.raises(ValueError):
SoftmaxGrad(3)(*x)
with pytest.raises(ValueError):
SoftmaxGrad(-4)(*x)
import numpy as np
import pytest
from scipy.special import log_softmax as scipy_log_softmax
from scipy.special import softmax as scipy_softmax
from aesara.compile.function import function
from aesara.configdefaults import config
from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad, log_softmax, softmax
from aesara.tensor.type import matrix, tensor3, tensor4, vector
from tests import unittest_tools as utt
class TestSoftmax(utt.InferShapeTester):
@pytest.mark.parametrize("axis", [None, 0, 1, 2, 3, -1, -2])
def test_perform(self, axis):
x = tensor4("x")
rng = np.random.default_rng(utt.fetch_seed())
xv = rng.standard_normal((2, 3, 4, 5)).astype(config.floatX)
f = function([x], softmax(x, axis=axis))
assert np.allclose(f(xv), scipy_softmax(xv, axis=axis))
@pytest.mark.parametrize("column", [0, 1, 2, 3])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_grad(self, axis, column):
def f(a):
return softmax(a, axis=axis)[:, column]
rng = np.random.default_rng(utt.fetch_seed())
utt.verify_grad(f, [rng.random((3, 4, 2))])
def test_infer_shape(self):
admat = matrix()
rng = np.random.default_rng(utt.fetch_seed())
admat_val = rng.random((3, 4)).astype(config.floatX)
self._compile_and_check(
[admat], [Softmax(axis=-1)(admat)], [admat_val], Softmax
)
def test_vector_perform(self):
x = vector()
f = function([x], softmax(x, axis=None))
rng = np.random.default_rng(utt.fetch_seed())
xv = rng.standard_normal((6,)).astype(config.floatX)
assert np.allclose(f(xv), scipy_softmax(xv))
def test_vector_grad(self):
def f(a):
return softmax(a, axis=None)
rng = np.random.default_rng(utt.fetch_seed())
utt.verify_grad(f, [rng.random((4))])
def test_valid_axis(self):
with pytest.raises(TypeError):
Softmax(1.5)
x = [tensor3()] * LogSoftmax.nin
Softmax(2)(*x)
Softmax(-3)(*x)
with pytest.raises(ValueError):
Softmax(3)(*x)
with pytest.raises(ValueError):
Softmax(-4)(*x)
class TestLogSoftmax(utt.InferShapeTester):
@pytest.mark.parametrize("column", [0, 1, 2, 3])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_matrix_grad(self, axis, column):
def f(a):
return log_softmax(a, axis=axis)[:, column]
rng = np.random.default_rng(utt.fetch_seed())
utt.verify_grad(f, [rng.random((3, 4))])
def test_vector_perform(self):
x = vector()
f = function([x], log_softmax(x, axis=None))
rng = np.random.default_rng(utt.fetch_seed())
xv = rng.standard_normal((6,)).astype(config.floatX)
assert np.allclose(f(xv), scipy_log_softmax(xv))
def test_vector_grad(self):
def f(a):
return log_softmax(a, axis=None)
rng = np.random.default_rng(utt.fetch_seed())
utt.verify_grad(f, [rng.random((4,))])
def test_valid_axis(self):
with pytest.raises(TypeError):
LogSoftmax(1.5)
x = [tensor3()] * LogSoftmax.nin
LogSoftmax(2)(*x)
LogSoftmax(-3)(*x)
with pytest.raises(ValueError):
LogSoftmax(3)(*x)
with pytest.raises(ValueError):
LogSoftmax(-4)(*x)
class TestSoftmaxGrad(utt.InferShapeTester):
def test_infer_shape(self):
admat = matrix()
bdmat = matrix()
rng = np.random.default_rng(utt.fetch_seed())
admat_val = rng.random((3, 4)).astype(config.floatX)
bdmat_val = rng.random((3, 4)).astype(config.floatX)
self._compile_and_check(
[admat, bdmat],
[SoftmaxGrad(axis=-1)(admat, bdmat)],
[admat_val, bdmat_val],
SoftmaxGrad,
)
def test_valid_axis(self):
with pytest.raises(TypeError):
SoftmaxGrad(1.5)
x = [tensor3()] * SoftmaxGrad.nin
SoftmaxGrad(2)(*x)
SoftmaxGrad(-3)(*x)
with pytest.raises(ValueError):
SoftmaxGrad(3)(*x)
with pytest.raises(ValueError):
SoftmaxGrad(-4)(*x)
......@@ -328,7 +328,7 @@ class TestRopLop(RopLopChecker):
self.check_mat_rop_lop(self.mx.sum(axis=1), (self.mat_in_shape[0],))
def test_softmax(self):
self.check_rop_lop(aesara.tensor.math.softmax(self.x), self.in_shape)
self.check_rop_lop(aesara.tensor.special.softmax(self.x), self.in_shape)
def test_alloc(self):
# Alloc of the sum of x into a vector
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论