提交 c0614c17 authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Brandon T. Willard

Move `Softmax`, `SoftmaxGrad`, `LogSoftmax` to `aesara.tensor.math`

上级 f1cc8937
...@@ -3,7 +3,7 @@ import jax.numpy as jnp ...@@ -3,7 +3,7 @@ import jax.numpy as jnp
from aesara.link.jax.dispatch.basic import jax_funcify, jnp_safe_copy from aesara.link.jax.dispatch.basic import jax_funcify, jnp_safe_copy
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad from aesara.tensor.math import LogSoftmax, Softmax, SoftmaxGrad
@jax_funcify.register(Elemwise) @jax_funcify.register(Elemwise)
......
...@@ -38,8 +38,13 @@ from aesara.scalar.basic import ( ...@@ -38,8 +38,13 @@ from aesara.scalar.basic import (
from aesara.scalar.basic import add as add_as from aesara.scalar.basic import add as add_as
from aesara.scalar.basic import scalar_maximum from aesara.scalar.basic import scalar_maximum
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import MaxAndArgmax, MulWithoutZeros from aesara.tensor.math import (
from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad LogSoftmax,
MaxAndArgmax,
MulWithoutZeros,
Softmax,
SoftmaxGrad,
)
@singledispatch @singledispatch
......
import builtins import builtins
import warnings import warnings
from textwrap import dedent
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
import numpy as np import numpy as np
import scipy
from aesara import config, printing from aesara import config, printing
from aesara import scalar as aes from aesara import scalar as aes
...@@ -2988,6 +2990,766 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None ...@@ -2988,6 +2990,766 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
return MatMul(dtype=dtype)(x1, x2) return MatMul(dtype=dtype)(x1, x2)
class SoftmaxGrad(COp):
    """
    Gradient wrt x of the Softmax Op.

    Given ``dy`` (gradient wrt the softmax output) and ``sm`` (the softmax
    output itself), computes ``dx = dy * sm - sum(dy * sm, axis) * sm``.
    """

    nin = 2   # dy: gradient wrt softmax output; sm: softmax output
    nout = 1  # dx: gradient wrt softmax input
    __props__ = ("axis",)

    def __init__(self, axis):
        # ``axis=None`` means the softmax was taken over the entire array.
        if axis is not None and not isinstance(axis, int):
            raise TypeError("axis must be an integer or `None`")
        self.axis = axis

    def make_node(self, dy, sm):
        # The gradient output has the same type as `sm`.
        dy = as_tensor_variable(dy)
        sm = as_tensor_variable(sm)
        if self.axis is not None and (self.axis >= sm.ndim or self.axis < -sm.ndim):
            raise ValueError(
                f"SoftmaxGrad axis(={self.axis}) out of bounds for {sm.ndim}D array {sm}"
            )
        return Apply(self, [dy, sm], [sm.type()])

    def perform(self, node, input_storage, output_storage):
        # dx = dy*sm - sum(dy*sm, axis, keepdims) * sm
        dy, sm = input_storage
        dy_times_sm = dy * sm
        dx = dy_times_sm - np.sum(dy_times_sm, axis=self.axis, keepdims=True) * sm
        output_storage[0][0] = dx

    def grad(self, inp, grads):
        # Second-order gradient: derivatives of dx wrt dy and sm.
        dy, sm = inp
        (g,) = grads
        tmp = g + neg(sum(g * sm, axis=self.axis, keepdims=True))
        g_dy = tmp * sm
        tmp2 = sum(dy * sm, axis=self.axis, keepdims=True)
        g_sm = tmp * dy - g * tmp2
        return g_dy, g_sm

    def infer_shape(self, fgraph, node, shape):
        # Output shape equals the shape of the second input (`sm`).
        return [shape[1]]

    def c_code_cache_version(self):
        return (4,)

    def c_code(self, node, name, inp, out, sub):
        """C implementation using a NumPy NpyIter over the three operands.

        Two paths: a flat iteration when the op applies to the whole array
        (``axis`` sentinel is ``NPY_MAXDIMS`` or the input is 1-D), and an
        axis-removed iteration that walks each 1-D lane along ``axis``.
        Only float32/float64 inputs are accepted (checked in the C code).

        NOTE(review): some C error messages contain typos/inconsistent
        casing ("SoftamGrad", "SoftMaxGrad", "softmax iterator") — they are
        runtime strings, left untouched here; consider normalizing them.
        NOTE(review): ``np.MAXDIMS`` (compared against ``NPY_MAXDIMS``
        below) was removed in NumPy 2.0 — confirm the targeted NumPy range.
        """
        dy, sm = inp
        (dx,) = out
        # np.MAXDIMS acts as the C-side "no axis" sentinel.
        axis = self.axis if self.axis is not None else np.MAXDIMS
        fail = sub["fail"]
        return dedent(
            f"""
            PyArrayObject* op[3];
            npy_uint32 op_flags[3];
            npy_uint32 iter_flags;
            NpyIter* iter;
            NpyIter_IterNextFunc* get_next;
            char** data_ptr;

            int sm_ndim = PyArray_NDIM({sm});
            int axis = {axis};
            int iterate_axis = !(axis == NPY_MAXDIMS || sm_ndim == 1);

            // Validate inputs
            if ((PyArray_TYPE({dy}) != NPY_DOUBLE) &&
                (PyArray_TYPE({dy}) != NPY_FLOAT))
            {{
                PyErr_SetString(PyExc_TypeError, "types should be float or float64");
                {fail};
            }}
            if ((PyArray_TYPE({sm}) != NPY_DOUBLE) &&
                (PyArray_TYPE({sm}) != NPY_FLOAT))
            {{
                PyErr_SetString(PyExc_TypeError, "types should be float or float64");
                {fail};
            }}

            if (axis < 0) axis = sm_ndim + axis;
            if ((axis < 0) || (iterate_axis && (axis > sm_ndim)))
            {{
                PyErr_SetString(PyExc_ValueError, "invalid axis in SoftmaxGrad");
                {fail};
            }}

            if (({dx} == NULL)
                || !(PyArray_CompareLists(PyArray_DIMS({dx}), PyArray_DIMS({sm}), sm_ndim)))
            {{
                Py_XDECREF({dx});
                {dx} = (PyArrayObject*)PyArray_SimpleNew(sm_ndim,
                                                         PyArray_DIMS({sm}),
                                                         PyArray_TYPE({sm}));
                if (!{dx})
                {{
                    PyErr_SetString(PyExc_MemoryError, "failed to alloc SoftMaxGrad dx output");
                    {fail};
                }}
            }}

            // Create numpy iterator
            op[0] = {dy};
            op[1] = {sm};
            op[2] = {dx};
            op_flags[0] = NPY_ITER_READONLY;
            op_flags[1] = NPY_ITER_READONLY;
            op_flags[2] = NPY_ITER_READWRITE;
            iter_flags = (iterate_axis)? NPY_ITER_MULTI_INDEX : 0;
            iter = NpyIter_MultiNew(
                3,
                op,
                iter_flags,
                NPY_KEEPORDER,
                NPY_NO_CASTING,
                op_flags,
                NULL
            );

            if (iter == NULL)
            {{
                PyErr_SetString(PyExc_MemoryError, "failed to create softmax iterator");
                {fail};
            }}

            // SoftmaxGrad is applied across the entire array
            if (!iterate_axis)
            {{
                get_next = NpyIter_GetIterNext(iter, NULL);
                if (get_next == NULL)
                {{
                    NpyIter_Deallocate(iter);
                    PyErr_SetString(PyExc_RuntimeError, "Failed to obtain SoftMaxGrad GetIterNext");
                    {fail};
                }}
                data_ptr = NpyIter_GetDataPtrArray(iter);

                // Compute and accumulate dy * sm
                dtype_{dx} sum_dy_times_sm = 0.0;
                do
                {{
                    dtype_{dy}* dy_ptr = (dtype_{dy}*)data_ptr[0];
                    dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
                    dtype_{dx}* dx_ptr = (dtype_{dx}*)data_ptr[2];
                    *dx_ptr = (dtype_{dx})((*dy_ptr) * (*sm_ptr));
                    sum_dy_times_sm += *dx_ptr;
                }} while(get_next(iter));

                // Reset Iterator
                if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
                {{
                    PyErr_SetString(PyExc_RuntimeError, "Failed to reset softmax iterator");
                    {fail};
                }}

                // Subtract sum(dy*sm) * sm
                do
                {{
                    dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
                    dtype_{dx}* dx_ptr = (dtype_{dx}*)data_ptr[2];
                    *dx_ptr -= sum_dy_times_sm * ((dtype_{dx})(*sm_ptr));
                }} while(get_next(iter));
            }}
            // SoftmaxGrad is applied across a specific axis
            else {{
                // Collect axis strides and remove it from iteration
                npy_intp axis_size = PyArray_DIM({sm}, axis);
                npy_intp* axis_stride = NpyIter_GetAxisStrideArray(iter, axis);
                if (axis_stride == NULL)
                {{
                    PyErr_SetString(PyExc_RuntimeError, "Failed to obtain softmax axis strides");
                    {fail};
                }}
                npy_intp dy_axis_stride = axis_stride[0] / sizeof(dtype_{dy});
                npy_intp sm_axis_stride = axis_stride[1] / sizeof(dtype_{sm});
                npy_intp dx_axis_stride = axis_stride[2] / sizeof(dtype_{dx});

                if (NpyIter_RemoveAxis(iter, axis) == NPY_FAIL)
                {{
                    PyErr_SetString(PyExc_RuntimeError, "Failed to remove SoftmaxGrad axis from iterator");
                    {fail};
                }}

                // Iterate over remaining axes
                get_next = NpyIter_GetIterNext(iter, NULL);
                if (get_next == NULL)
                {{
                    NpyIter_Deallocate(iter);
                    PyErr_SetString(PyExc_RuntimeError, "Failed to obtain SoftamGrad GetIterNext");
                    {fail};
                }}
                data_ptr = NpyIter_GetDataPtrArray(iter);
                do
                {{
                    dtype_{dy}* dy_axis = (dtype_{dy}*)data_ptr[0];
                    dtype_{sm}* sm_axis = (dtype_{sm}*)data_ptr[1];
                    dtype_{dx}* dx_axis = (dtype_{dx}*)data_ptr[2];

                    // Compute and accumulate dy * sm
                    dtype_{dx} sum_dy_times_sm = 0.0;
                    for (npy_intp i = 0; i < axis_size; i++)
                    {{
                        dx_axis[i * dx_axis_stride] = (dtype_{dx})(dy_axis[i * dy_axis_stride] * sm_axis[i * sm_axis_stride]);
                        sum_dy_times_sm += dx_axis[i * dx_axis_stride];
                    }}

                    // Subtract sum(dy*sm) * sm
                    for (npy_intp i = 0; i < axis_size; i++)
                    {{
                        dx_axis[i * dx_axis_stride] -= sum_dy_times_sm * (dtype_{dx})(sm_axis[i * sm_axis_stride]);
                    }}
                }} while(get_next(iter));
            }}
            NpyIter_Deallocate(iter);
            """
        )
class Softmax(COp):
    r"""
    Softmax activation function

    :math:`\varphi(\mathbf{x})_j =
    \frac{e^{\mathbf{x}_j}}{\sum_{k=1}^K e^{\mathbf{x}_k}}`

    where :math:`K` is the total number of neurons in the layer. This
    activation function gets applied row-wise.
    """

    nin = 1   # x: input tensor
    nout = 1  # sm: softmax of x along `axis`
    __props__ = ("axis",)

    def __init__(self, axis):
        # ``axis=None`` applies the softmax over the whole array.
        if axis is not None and not isinstance(axis, int):
            raise TypeError("axis must be an integer or `None`")
        self.axis = axis

    def make_node(self, x):
        # The output has the same type as the input.
        x = as_tensor_variable(x)
        if self.axis is not None and (self.axis >= x.ndim or self.axis < -x.ndim):
            raise ValueError(
                f"Softmax axis(={self.axis}) out of bounds for {x.ndim}D array {x}"
            )
        return Apply(self, [x], [x.type()])

    def perform(self, node, input_storage, output_storage):
        (x,) = input_storage
        (z,) = output_storage
        # Delegate to scipy's numerically stable implementation.
        z[0] = scipy.special.softmax(x, axis=self.axis)

    def L_op(self, inp, outputs, grads):
        # Gradient is delegated to the dedicated SoftmaxGrad Op.
        (x,) = inp
        (g_sm,) = grads
        return [SoftmaxGrad(axis=self.axis)(g_sm, outputs[0])]

    def R_op(self, inputs, eval_points):
        # I think the Jacobian is symmetric so the R_op
        # is the same as the grad
        if None in eval_points:
            return [None]
        return self.L_op(inputs, [self(*inputs)], eval_points)

    def infer_shape(self, fgraph, node, shape):
        # Softmax only normalizes values: output shape == input shape.
        return shape

    def c_headers(self, **kwargs):
        return ["<iostream>", "<cmath>"]

    def c_code(self, node, name, inp, out, sub):
        """Numerically stable C softmax: ``exp(x - max(x)) / sum(exp(x - max(x)))``.

        Uses a NumPy NpyIter; one flat pass when applied to the whole
        array (``axis`` sentinel is ``NPY_MAXDIMS`` or input is 1-D),
        otherwise the ``axis`` dimension is removed from the iterator and
        each 1-D lane is processed with explicit strides.
        Only float32/float64 inputs are accepted (checked in the C code).

        NOTE(review): ``np.MAXDIMS`` was removed in NumPy 2.0 — confirm
        the targeted NumPy range.
        """
        (x,) = inp
        (sm,) = out
        # np.MAXDIMS acts as the C-side "no axis" sentinel.
        axis = self.axis if self.axis is not None else np.MAXDIMS
        fail = sub["fail"]
        # dtype = node.inputs[0].type.dtype_specs()[1]
        # TODO: put this into a templated function, in the support code
        # TODO: declare the max of each row as an Op output
        # TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1]
        return dedent(
            f"""
            PyArrayObject* op[2];
            npy_uint32 op_flags[2];
            npy_uint32 iter_flags;
            NpyIter* iter;
            NpyIter_IterNextFunc* get_next;
            char** data_ptr;

            int x_ndim = PyArray_NDIM({x});
            int axis = {axis};
            int iterate_axis = !(axis == NPY_MAXDIMS || x_ndim == 1);

            // Validate inputs
            if ((PyArray_TYPE({x}) != NPY_DOUBLE) &&
                (PyArray_TYPE({x}) != NPY_FLOAT))
            {{
                PyErr_SetString(PyExc_TypeError, "not a float");
                {fail}
            }}

            if (axis < 0) axis = x_ndim + axis;
            if ((axis < 0) || (iterate_axis && (axis > x_ndim)))
            {{
                PyErr_SetString(PyExc_ValueError, "invalid axis in Softmax");
                {fail}
            }}

            // Allocate Output Array
            if (({sm}) == NULL || !(PyArray_CompareLists(PyArray_DIMS({sm}), PyArray_DIMS({x}), x_ndim)))
            {{
                Py_XDECREF({sm});
                {sm} = (PyArrayObject*)PyArray_SimpleNew(x_ndim, PyArray_DIMS({x}), PyArray_TYPE({x}));
                if(!{sm}) {{
                    PyErr_SetString(PyExc_MemoryError, "failed to alloc Softmax output");
                    {fail}
                }}
            }}

            // Create numpy iterator
            op[0] = {x};
            op[1] = {sm};
            op_flags[0] = NPY_ITER_READONLY;
            op_flags[1] = NPY_ITER_READWRITE;
            iter_flags = (iterate_axis)? NPY_ITER_MULTI_INDEX : 0;
            iter = NpyIter_MultiNew(
                2,
                op,
                iter_flags,
                NPY_KEEPORDER,
                NPY_NO_CASTING,
                op_flags,
                NULL
            );

            if (iter == NULL)
            {{
                PyErr_SetString(PyExc_MemoryError, "failed to create Softmax iterator");
                {fail}
            }}

            // Softmax is applied across the entire array
            if (!iterate_axis)
            {{
                get_next = NpyIter_GetIterNext(iter, NULL);
                if (get_next == NULL)
                {{
                    NpyIter_Deallocate(iter);
                    PyErr_SetString(PyExc_RuntimeError, "Failed to obtain Softmax GetIterNext");
                    {fail}
                }}
                data_ptr = NpyIter_GetDataPtrArray(iter);

                // Find axis max
                dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
                dtype_{x} max = *x_ptr;
                if (get_next(iter))
                {{
                    do
                    {{
                        dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
                        max = (*x_ptr > max)? *x_ptr : max;
                    }} while(get_next(iter));
                }}

                // Reset Iterator
                if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
                {{
                    PyErr_SetString(PyExc_RuntimeError, "Failed to reset Softmax iterator");
                    {fail}
                }}

                // Compute and accumulate exp(x-max(x)) exponent
                double sum_exp_dev = 0.0;
                do
                {{
                    dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
                    dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
                    *sm_ptr = (dtype_{sm}) exp(*x_ptr - max);
                    sum_exp_dev += *sm_ptr;
                }} while(get_next(iter));

                // Reset Iterator
                if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
                {{
                    PyErr_SetString(PyExc_RuntimeError, "Failed to reset Softmax iterator");
                    {fail}
                }}

                // Divide by sum(exp(x-max(x)))
                double inv_sum_exp_dev = 1.0 / sum_exp_dev;
                do
                {{
                    dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
                    *sm_ptr *= inv_sum_exp_dev;
                }} while(get_next(iter));
            }}
            // Softmax is applied across a specific axis
            else {{
                // Collect axis strides and remove it from iteration
                npy_intp axis_size = PyArray_DIM({x}, axis);
                npy_intp* axis_stride = NpyIter_GetAxisStrideArray(iter, axis);
                if (axis_stride == NULL)
                {{
                    PyErr_SetString(PyExc_RuntimeError, "Failed to obtain Softmax axis strides");
                    {fail}
                }}
                npy_intp x_axis_stride = axis_stride[0] / sizeof(dtype_{x});
                npy_intp sm_axis_stride = axis_stride[1] / sizeof(dtype_{sm});

                if (NpyIter_RemoveAxis(iter, axis) == NPY_FAIL)
                {{
                    PyErr_SetString(PyExc_RuntimeError, "Failed to remove softmax axis from iterator");
                    {fail}
                }}

                // Iterate over remaining axes
                get_next = NpyIter_GetIterNext(iter, NULL);
                if (get_next == NULL)
                {{
                    NpyIter_Deallocate(iter);
                    PyErr_SetString(PyExc_RuntimeError, "Failed to obtain softmax GetIterNext");
                    {fail}
                }}
                data_ptr = NpyIter_GetDataPtrArray(iter);
                do
                {{
                    dtype_{x}* x_axis = (dtype_{x}*)data_ptr[0];
                    dtype_{sm}* sm_axis = (dtype_{sm}*)data_ptr[1];

                    // Find axis max
                    dtype_{x} max = x_axis[0];
                    for (npy_intp i = 1; i < axis_size; i++)
                    {{
                        dtype_{x} x_val = x_axis[i * x_axis_stride];
                        max = (x_val > max)? x_val : max;
                    }}

                    // Compute and accumulate exp(x-max(x)) exponent
                    dtype_{sm} sum_exp_dev = 0.0;
                    for (npy_intp i = 0; i < axis_size; i++)
                    {{
                        sm_axis[i * sm_axis_stride] = (dtype_{sm}) exp(x_axis[i * x_axis_stride] - max);
                        sum_exp_dev += sm_axis[i * sm_axis_stride];
                    }}

                    // Divide by sum(exp(x-max(x)))
                    dtype_{sm} inv_sum_exp_dev = 1.0 / sum_exp_dev;
                    for (npy_intp i = 0; i < axis_size; i++)
                    {{
                        sm_axis[i * sm_axis_stride] *= inv_sum_exp_dev;
                    }}
                }} while(get_next(iter));
            }}
            NpyIter_Deallocate(iter);
            """
        )

    @staticmethod
    def c_code_cache_version():
        return (4,)
UNSET_AXIS = object()  # sentinel: distinguishes "axis not passed" from an explicit axis=None


def softmax(c, axis=UNSET_AXIS):
    """Build a graph computing the softmax of `c` along `axis`.

    When `axis` is omitted, a FutureWarning is issued and ``-1`` is used
    for backwards compatibility. A UserWarning is emitted for 1-D input,
    which is no longer promoted to a row matrix.
    """
    axis_omitted = axis is UNSET_AXIS
    if axis_omitted:
        warnings.warn(
            "Softmax now accepts an axis argument. For backwards-compatibility it defaults to -1 when not specified, "
            "but in the future the default will be `None`.\nTo suppress this warning specify axis explicitly.",
            FutureWarning,
        )
        axis = -1

    tensor_input = as_tensor_variable(c)
    if tensor_input.ndim == 1:
        # TODO: Create Specific warning type that can be suppressed?
        warnings.warn(
            "Softmax no longer converts a vector to a row matrix.",
            UserWarning,
        )
    return Softmax(axis=axis)(tensor_input)
class LogSoftmax(COp):
    r"""
    LogSoftmax activation function

    :math:`\varphi(\mathbf{x})_j =
    \mathbf{x}_j - \log\left(\sum_{k=1}^K e^{\mathbf{x}_k}\right)`

    where :math:`K` is the total number of neurons in the layer. This
    activation function gets applied row-wise.
    """

    nin = 1   # x: input tensor
    nout = 1  # sm: log-softmax of x along `axis`
    __props__ = ("axis",)

    def __init__(self, axis):
        # ``axis=None`` applies the log-softmax over the whole array.
        if axis is not None and not isinstance(axis, int):
            raise TypeError("axis must be an integer or `None`")
        self.axis = axis

    def make_node(self, x):
        # The output has the same type as the input.
        x = as_tensor_variable(x)
        if self.axis is not None and (self.axis >= x.ndim or self.axis < -x.ndim):
            raise ValueError(
                f"LogSoftmax axis(={self.axis}) out of bounds for {x.ndim}D array {x}"
            )
        return Apply(self, [x], [x.type()])

    def perform(self, node, input_storage, output_storage):
        (x,) = input_storage
        (z,) = output_storage
        # Delegate to scipy's numerically stable implementation.
        z[0] = scipy.special.log_softmax(x, axis=self.axis)

    def grad(self, inp, grads):
        # d/dx log_softmax(x) applied to g: g - sum(g, axis) * softmax(x)
        (x,) = inp
        sm = Softmax(axis=self.axis)(x)
        return [grads[0] - sum(grads[0], axis=self.axis, keepdims=True) * sm]

    def R_op(self, inputs, eval_points):
        # I think the Jacobian is symmetric so the R_op
        # is the same as the grad
        if None in eval_points:
            return [None]
        return self.grad(inputs, eval_points)

    def infer_shape(self, fgraph, node, shape):
        # Output shape == input shape.
        return shape

    def c_headers(self, **kwargs):
        return ["<cmath>"]

    def c_code(self, node, name, inp, out, sub):
        """Numerically stable C log-softmax: ``(x - max(x)) - log(sum(exp(x - max(x))))``.

        Same NpyIter structure as ``Softmax.c_code``: one flat pass for the
        whole-array case, otherwise per-lane processing along ``axis``.
        Only float32/float64 inputs are accepted (checked in the C code).

        NOTE(review): in the axis branch the intermediate is cast with
        ``(dtype_x)`` instead of ``(dtype_sm)``; harmless because the
        output type equals the input type (see ``make_node``), but
        inconsistent with the rest of the kernel — consider fixing.
        NOTE(review): ``np.MAXDIMS`` was removed in NumPy 2.0 — confirm
        the targeted NumPy range.
        """
        (x,) = inp
        (sm,) = out
        # np.MAXDIMS acts as the C-side "no axis" sentinel.
        axis = self.axis if self.axis is not None else np.MAXDIMS
        fail = sub["fail"]
        return dedent(
            f"""
            PyArrayObject* op[2];
            npy_uint32 op_flags[2];
            npy_uint32 iter_flags;
            NpyIter* iter;
            NpyIter_IterNextFunc* get_next;
            char** data_ptr;

            int x_ndim = PyArray_NDIM({x});
            int axis = {axis};
            int iterate_axis = !(axis == NPY_MAXDIMS || x_ndim == 1);

            // Validate inputs
            if ((PyArray_TYPE({x}) != NPY_DOUBLE) &&
                (PyArray_TYPE({x}) != NPY_FLOAT))
            {{
                PyErr_SetString(PyExc_TypeError, "not a float");
                {fail}
            }}

            if (axis < 0) axis = x_ndim + axis;
            if ((axis < 0) || (iterate_axis && (axis > x_ndim)))
            {{
                PyErr_SetString(PyExc_ValueError, "invalid axis in LogSoftmax");
                {fail}
            }}

            // Allocate Output Array
            if (({sm}) == NULL || !(PyArray_CompareLists(PyArray_DIMS({sm}), PyArray_DIMS({x}), x_ndim)))
            {{
                Py_XDECREF({sm});
                {sm} = (PyArrayObject*)PyArray_SimpleNew(x_ndim, PyArray_DIMS({x}), PyArray_TYPE({x}));
                if(!{sm}) {{
                    PyErr_SetString(PyExc_MemoryError, "failed to alloc LogSoftmax output");
                    {fail}
                }}
            }}

            // Create numpy iterator
            op[0] = {x};
            op[1] = {sm};
            op_flags[0] = NPY_ITER_READONLY;
            op_flags[1] = NPY_ITER_READWRITE;
            iter_flags = (iterate_axis)? NPY_ITER_MULTI_INDEX : 0;
            iter = NpyIter_MultiNew(
                2,
                op,
                iter_flags,
                NPY_KEEPORDER,
                NPY_NO_CASTING,
                op_flags,
                NULL
            );

            if (iter == NULL)
            {{
                PyErr_SetString(PyExc_MemoryError, "failed to create LogSoftmax iterator");
                {fail}
            }}

            // LogSoftmax is applied across the entire array
            if (!iterate_axis)
            {{
                get_next = NpyIter_GetIterNext(iter, NULL);
                if (get_next == NULL)
                {{
                    NpyIter_Deallocate(iter);
                    PyErr_SetString(PyExc_RuntimeError, "Failed to obtain LogSoftmax GetIterNext");
                    {fail}
                }}
                data_ptr = NpyIter_GetDataPtrArray(iter);

                // Find axis max
                dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
                dtype_{x} max = *x_ptr;
                if (get_next(iter))
                {{
                    do
                    {{
                        dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
                        max = (*x_ptr > max)? *x_ptr : max;
                    }} while(get_next(iter));
                }}

                // Reset Iterator
                if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
                {{
                    PyErr_SetString(PyExc_RuntimeError, "Failed to reset LogSoftmax iterator");
                    {fail}
                }}

                // Compute xdev and sum(exp(xdev))
                dtype_{sm} sum_exp_xdev = 0.0;
                do
                {{
                    dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
                    dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
                    *sm_ptr = (dtype_{sm})((*x_ptr) - max);
                    sum_exp_xdev += exp(*sm_ptr);
                }} while(get_next(iter));

                // Reset Iterator
                if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
                {{
                    PyErr_SetString(PyExc_RuntimeError, "Failed to reset LogSoftmax iterator");
                    {fail}
                }}

                // Subtract log(sum(exp(xdev)))
                dtype_{sm} log_sum_exp_xdev = log(sum_exp_xdev);
                do
                {{
                    dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
                    *sm_ptr -= log_sum_exp_xdev;
                }} while(get_next(iter));
            }}
            // LogSoftmax is applied across a specific axis
            else {{
                // Collect axis strides and remove it from iteration
                npy_intp axis_size = PyArray_DIM({x}, axis);
                npy_intp* axis_stride = NpyIter_GetAxisStrideArray(iter, axis);
                if (axis_stride == NULL)
                {{
                    PyErr_SetString(PyExc_RuntimeError, "Failed to obtain LogSoftmax axis strides");
                    {fail}
                }}
                npy_intp x_axis_stride = axis_stride[0] / sizeof(dtype_{x});
                npy_intp sm_axis_stride = axis_stride[1] / sizeof(dtype_{sm});

                if (NpyIter_RemoveAxis(iter, axis) == NPY_FAIL)
                {{
                    PyErr_SetString(PyExc_RuntimeError, "Failed to remove LogSoftmax axis from iterator");
                    {fail}
                }}

                // Iterate over remaining axes
                get_next = NpyIter_GetIterNext(iter, NULL);
                if (get_next == NULL)
                {{
                    NpyIter_Deallocate(iter);
                    PyErr_SetString(PyExc_RuntimeError, "Failed to obtain LogSoftmax GetIterNext");
                    {fail}
                }}
                data_ptr = NpyIter_GetDataPtrArray(iter);
                do
                {{
                    dtype_{x}* x_axis = (dtype_{x}*)data_ptr[0];
                    dtype_{sm}* sm_axis = (dtype_{sm}*)data_ptr[1];

                    // Find axis max
                    dtype_{x} max = x_axis[0];
                    for (npy_intp i = 1; i < axis_size; i++)
                    {{
                        dtype_{x} x_val = x_axis[i * x_axis_stride];
                        max = (x_val > max)? x_val : max;
                    }}

                    // Compute xdev and sum(exp(xdev))
                    dtype_{sm} sum_exp_xdev = 0.0;
                    for (npy_intp i = 0; i < axis_size; i++)
                    {{
                        sm_axis[i * sm_axis_stride] = (dtype_{x})(x_axis[i * x_axis_stride] - max);
                        sum_exp_xdev += exp(sm_axis[i * sm_axis_stride]);
                    }}

                    // Subtract log(sum(exp(xdev))
                    dtype_{sm} log_sum_exp_xdev = log(sum_exp_xdev);
                    for (npy_intp i = 0; i < axis_size; i++)
                    {{
                        sm_axis[i * sm_axis_stride] -= log_sum_exp_xdev;
                    }}
                }} while(get_next(iter));
            }}
            NpyIter_Deallocate(iter);
            """
        )

    @staticmethod
    def c_code_cache_version():
        return (1,)
def logsoftmax(c, axis=UNSET_AXIS):
    """Build a graph computing the log of the softmax of `c` along `axis`.

    When `axis` is omitted, a FutureWarning is issued and ``-1`` is used
    for backwards compatibility. A UserWarning is emitted for 1-D input,
    which is no longer promoted to a row matrix.
    """
    if axis is UNSET_AXIS:
        warnings.warn(
            "logsoftmax now accepts an axis argument. For backwards-compatibility it defaults to -1 when not specified, "
            "but in the future the default will be `None`.\nTo suppress this warning specify axis explicitly.",
            FutureWarning,
        )
        axis = -1
    c = as_tensor_variable(c)
    if c.ndim == 1:
        # TODO: Create Specific warning type that can be suppressed?
        # Fixed copy-paste from `softmax`: the warning previously said "Softmax".
        warnings.warn(
            "LogSoftmax no longer converts a vector to a row matrix.",
            UserWarning,
        )
    return LogSoftmax(axis=axis)(c)


log_softmax = logsoftmax  # alias matching the scipy.special name
__all__ = [ __all__ = [
"max_and_argmax", "max_and_argmax",
"max", "max",
...@@ -3116,11 +3878,18 @@ __all__ = [ ...@@ -3116,11 +3878,18 @@ __all__ = [
"power", "power",
"logaddexp", "logaddexp",
"logsumexp", "logsumexp",
"softmax",
"log_softmax",
] ]
DEPRECATED_NAMES = [ DEPRECATED_NAMES = [
("abs_", "`abs_` is deprecated; use `abs` instead.", abs), ("abs_", "`abs_` is deprecated; use `abs` instead.", abs),
("inv", "`inv` is deprecated; use `reciprocal` instead.", reciprocal), ("inv", "`inv` is deprecated; use `reciprocal` instead.", reciprocal),
(
    "logsoftmax",
    "`logsoftmax` is deprecated; use `log_softmax` instead.",
    log_softmax,
),
] ]
......
...@@ -6,11 +6,7 @@ Notes ...@@ -6,11 +6,7 @@ Notes
TODO: factor this out into a neural-network toolbox. TODO: factor this out into a neural-network toolbox.
""" """
import warnings
from textwrap import dedent
import numpy as np import numpy as np
import scipy.special
import aesara import aesara
from aesara import scalar as aes from aesara import scalar as aes
...@@ -23,12 +19,15 @@ from aesara.link.c.op import COp ...@@ -23,12 +19,15 @@ from aesara.link.c.op import COp
from aesara.raise_op import Assert from aesara.raise_op import Assert
from aesara.scalar import UnaryScalarOp from aesara.scalar import UnaryScalarOp
from aesara.tensor import basic as at from aesara.tensor import basic as at
from aesara.tensor.basic import ARange, as_tensor_variable from aesara.tensor.basic import ARange
from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.extra_ops import Unique from aesara.tensor.extra_ops import Unique
from aesara.tensor.math import ( from aesara.tensor.math import (
LogSoftmax,
MaxAndArgmax, MaxAndArgmax,
Softmax,
SoftmaxGrad,
Sum, Sum,
add, add,
dot, dot,
...@@ -36,11 +35,13 @@ from aesara.tensor.math import ( ...@@ -36,11 +35,13 @@ from aesara.tensor.math import (
exp, exp,
expm1, expm1,
log, log,
log_softmax,
max_and_argmax, max_and_argmax,
mul, mul,
neg, neg,
or_, or_,
sigmoid, sigmoid,
softmax,
softplus, softplus,
) )
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
...@@ -320,729 +321,12 @@ class SoftmaxWithBias(COp): ...@@ -320,729 +321,12 @@ class SoftmaxWithBias(COp):
softmax_with_bias = SoftmaxWithBias() softmax_with_bias = SoftmaxWithBias()
class SoftmaxGrad(COp):
"""
Gradient wrt x of the Softmax Op.
"""
nin = 2
nout = 1
__props__ = ("axis",)
def __init__(self, axis):
if axis is not None and not isinstance(axis, int):
raise TypeError("axis must be an integer or `None`")
self.axis = axis
def make_node(self, dy, sm):
dy = at.as_tensor_variable(dy)
sm = at.as_tensor_variable(sm)
if self.axis is not None and (self.axis >= sm.ndim or self.axis < -sm.ndim):
raise ValueError(
f"SoftmaxGrad axis(={self.axis}) out of bounds for {sm.ndim}D array {sm}"
)
return Apply(self, [dy, sm], [sm.type()])
def perform(self, node, input_storage, output_storage):
dy, sm = input_storage
dy_times_sm = dy * sm
dx = dy_times_sm - np.sum(dy_times_sm, axis=self.axis, keepdims=True) * sm
output_storage[0][0] = dx
def grad(self, inp, grads):
dy, sm = inp
(g,) = grads
tmp = g + neg(at_sum(g * sm, axis=self.axis, keepdims=True))
g_dy = tmp * sm
tmp2 = at_sum(dy * sm, axis=self.axis, keepdims=True)
g_sm = tmp * dy - g * tmp2
return g_dy, g_sm
def infer_shape(self, fgraph, node, shape):
return [shape[1]]
def c_code_cache_version(self):
return (4,)
def c_code(self, node, name, inp, out, sub):
dy, sm = inp
(dx,) = out
axis = self.axis if self.axis is not None else np.MAXDIMS
fail = sub["fail"]
return dedent(
f"""
PyArrayObject* op[3];
npy_uint32 op_flags[3];
npy_uint32 iter_flags;
NpyIter* iter;
NpyIter_IterNextFunc* get_next;
char** data_ptr;
int sm_ndim = PyArray_NDIM({sm});
int axis = {axis};
int iterate_axis = !(axis == NPY_MAXDIMS || sm_ndim == 1);
// Validate inputs
if ((PyArray_TYPE({dy}) != NPY_DOUBLE) &&
(PyArray_TYPE({dy}) != NPY_FLOAT))
{{
PyErr_SetString(PyExc_TypeError, "types should be float or float64");
{fail};
}}
if ((PyArray_TYPE({sm}) != NPY_DOUBLE) &&
(PyArray_TYPE({sm}) != NPY_FLOAT))
{{
PyErr_SetString(PyExc_TypeError, "types should be float or float64");
{fail};
}}
if (axis < 0) axis = sm_ndim + axis;
if ((axis < 0) || (iterate_axis && (axis > sm_ndim)))
{{
PyErr_SetString(PyExc_ValueError, "invalid axis in SoftmaxGrad");
{fail};
}}
if (({dx} == NULL)
|| !(PyArray_CompareLists(PyArray_DIMS({dx}), PyArray_DIMS({sm}), sm_ndim)))
{{
Py_XDECREF({dx});
{dx} = (PyArrayObject*)PyArray_SimpleNew(sm_ndim,
PyArray_DIMS({sm}),
PyArray_TYPE({sm}));
if (!{dx})
{{
PyErr_SetString(PyExc_MemoryError, "failed to alloc SoftMaxGrad dx output");
{fail};
}}
}}
// Create numpy iterator
op[0] = {dy};
op[1] = {sm};
op[2] = {dx};
op_flags[0] = NPY_ITER_READONLY;
op_flags[1] = NPY_ITER_READONLY;
op_flags[2] = NPY_ITER_READWRITE;
iter_flags = (iterate_axis)? NPY_ITER_MULTI_INDEX : 0;
iter = NpyIter_MultiNew(
3,
op,
iter_flags,
NPY_KEEPORDER,
NPY_NO_CASTING,
op_flags,
NULL
);
if (iter == NULL)
{{
PyErr_SetString(PyExc_MemoryError, "failed to create softmax iterator");
{fail};
}}
// SoftmaxGrad is applied across the entire array
if (!iterate_axis)
{{
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain SoftMaxGrad GetIterNext");
{fail};
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
// Compute and accumulate dy * sm
dtype_{dx} sum_dy_times_sm = 0.0;
do
{{
dtype_{dy}* dy_ptr = (dtype_{dy}*)data_ptr[0];
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
dtype_{dx}* dx_ptr = (dtype_{dx}*)data_ptr[2];
*dx_ptr = (dtype_{dx})((*dy_ptr) * (*sm_ptr));
sum_dy_times_sm += *dx_ptr;
}} while(get_next(iter));
// Reset Iterator
if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to reset softmax iterator");
{fail};
}}
// Subtract sum(dy*sm) * sm
do
{{
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
dtype_{dx}* dx_ptr = (dtype_{dx}*)data_ptr[2];
*dx_ptr -= sum_dy_times_sm * ((dtype_{dx})(*sm_ptr));
}} while(get_next(iter));
}}
// SoftmaxGrad is applied across a specific axis
else {{
// Collect axis strides and remove it from iteration
npy_intp axis_size = PyArray_DIM({sm}, axis);
npy_intp* axis_stride = NpyIter_GetAxisStrideArray(iter, axis);
if (axis_stride == NULL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain softmax axis strides");
{fail};
}}
npy_intp dy_axis_stride = axis_stride[0] / sizeof(dtype_{dy});
npy_intp sm_axis_stride = axis_stride[1] / sizeof(dtype_{sm});
npy_intp dx_axis_stride = axis_stride[2] / sizeof(dtype_{dx});
if (NpyIter_RemoveAxis(iter, axis) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to remove SoftmaxGrad axis from iterator");
{fail};
}}
// Iterate over remaining axes
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain SoftamGrad GetIterNext");
{fail};
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
do
{{
dtype_{dy}* dy_axis = (dtype_{dy}*)data_ptr[0];
dtype_{sm}* sm_axis = (dtype_{sm}*)data_ptr[1];
dtype_{dx}* dx_axis = (dtype_{dx}*)data_ptr[2];
// Compute and accumulate dy * sm
dtype_{dx} sum_dy_times_sm = 0.0;
for (npy_intp i = 0; i < axis_size; i++)
{{
dx_axis[i * dx_axis_stride] = (dtype_{dx})(dy_axis[i * dy_axis_stride] * sm_axis[i * sm_axis_stride]);
sum_dy_times_sm += dx_axis[i * dx_axis_stride];
}}
// Subtract sum(dy*sm) * sm
for (npy_intp i = 0; i < axis_size; i++)
{{
dx_axis[i * dx_axis_stride] -= sum_dy_times_sm * (dtype_{dx})(sm_axis[i * sm_axis_stride]);
}}
}} while(get_next(iter));
}}
NpyIter_Deallocate(iter);
"""
)
softmax_grad_legacy = SoftmaxGrad(axis=-1) softmax_grad_legacy = SoftmaxGrad(axis=-1)
class Softmax(COp):
r"""
Softmax activation function
:math:`\\varphi(\\mathbf{x})_j =
\\frac{e^{\mathbf{x}_j}}{\sum_{k=1}^K e^{\mathbf{x}_k}}`
where :math:`K` is the total number of neurons in the layer. This
activation function gets applied row-wise.
"""
nin = 1
nout = 1
__props__ = ("axis",)
def __init__(self, axis):
if axis is not None and not isinstance(axis, int):
raise TypeError("axis must be an integer or `None`")
self.axis = axis
def make_node(self, x):
x = at.as_tensor_variable(x)
if self.axis is not None and (self.axis >= x.ndim or self.axis < -x.ndim):
raise ValueError(
f"Softmax axis(={self.axis}) out of bounds for {x.ndim}D array {x}"
)
return Apply(self, [x], [x.type()])
def perform(self, node, input_storage, output_storage):
(x,) = input_storage
(z,) = output_storage
z[0] = scipy.special.softmax(x, axis=self.axis)
def L_op(self, inp, outputs, grads):
(x,) = inp
(g_sm,) = grads
return [SoftmaxGrad(axis=self.axis)(g_sm, outputs[0])]
def R_op(self, inputs, eval_points):
# I think the Jacobian is symmetric so the R_op
# is the same as the grad
if None in eval_points:
return [None]
return self.L_op(inputs, [self(*inputs)], eval_points)
def infer_shape(self, fgraph, node, shape):
return shape
def c_headers(self, **kwargs):
return ["<iostream>", "<cmath>"]
def c_code(self, node, name, inp, out, sub):
(x,) = inp
(sm,) = out
axis = self.axis if self.axis is not None else np.MAXDIMS
fail = sub["fail"]
# dtype = node.inputs[0].type.dtype_specs()[1]
# TODO: put this into a templated function, in the support code
# TODO: declare the max of each row as an Op output
# TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1]
return dedent(
f"""
PyArrayObject* op[2];
npy_uint32 op_flags[2];
npy_uint32 iter_flags;
NpyIter* iter;
NpyIter_IterNextFunc* get_next;
char** data_ptr;
int x_ndim = PyArray_NDIM({x});
int axis = {axis};
int iterate_axis = !(axis == NPY_MAXDIMS || x_ndim == 1);
// Validate inputs
if ((PyArray_TYPE({x}) != NPY_DOUBLE) &&
(PyArray_TYPE({x}) != NPY_FLOAT))
{{
PyErr_SetString(PyExc_TypeError, "not a float");
{fail}
}}
if (axis < 0) axis = x_ndim + axis;
if ((axis < 0) || (iterate_axis && (axis > x_ndim)))
{{
PyErr_SetString(PyExc_ValueError, "invalid axis in Softmax");
{fail}
}}
// Allocate Output Array
if (({sm}) == NULL || !(PyArray_CompareLists(PyArray_DIMS({sm}), PyArray_DIMS({x}), x_ndim)))
{{
Py_XDECREF({sm});
{sm} = (PyArrayObject*)PyArray_SimpleNew(x_ndim, PyArray_DIMS({x}), PyArray_TYPE({x}));
if(!{sm}) {{
PyErr_SetString(PyExc_MemoryError, "failed to alloc Softmax output");
{fail}
}}
}}
// Create numpy iterator
op[0] = {x};
op[1] = {sm};
op_flags[0] = NPY_ITER_READONLY;
op_flags[1] = NPY_ITER_READWRITE;
iter_flags = (iterate_axis)? NPY_ITER_MULTI_INDEX : 0;
iter = NpyIter_MultiNew(
2,
op,
iter_flags,
NPY_KEEPORDER,
NPY_NO_CASTING,
op_flags,
NULL
);
if (iter == NULL)
{{
PyErr_SetString(PyExc_MemoryError, "failed to create Softmax iterator");
{fail}
}}
// Softmax is applied across the entire array
if (!iterate_axis)
{{
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain Softmax GetIterNext");
{fail}
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
// Find axis max
dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
dtype_{x} max = *x_ptr;
if (get_next(iter))
{{
do
{{
dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
max = (*x_ptr > max)? *x_ptr : max;
}} while(get_next(iter));
}}
// Reset Iterator
if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to reset Softmax iterator");
{fail}
}}
// Compute and accumulate exp(x-max(x)) exponent
double sum_exp_dev = 0.0;
do
{{
dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
*sm_ptr = (dtype_{sm}) exp(*x_ptr - max);
sum_exp_dev += *sm_ptr;
}} while(get_next(iter));
// Reset Iterator
if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to reset Softmax iterator");
{fail}
}}
// Divide by sum(exp(x-max(x)))
double inv_sum_exp_dev = 1.0 / sum_exp_dev;
do
{{
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
*sm_ptr *= inv_sum_exp_dev;
}} while(get_next(iter));
}}
// Softmax is applied across a specific axis
else {{
// Collect axis strides and remove it from iteration
npy_intp axis_size = PyArray_DIM({x}, axis);
npy_intp* axis_stride = NpyIter_GetAxisStrideArray(iter, axis);
if (axis_stride == NULL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain Softmax axis strides");
{fail}
}}
npy_intp x_axis_stride = axis_stride[0] / sizeof(dtype_{x});
npy_intp sm_axis_stride = axis_stride[1] / sizeof(dtype_{sm});
if (NpyIter_RemoveAxis(iter, axis) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to remove softmax axis from iterator");
{fail}
}}
// Iterate over remaining axes
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain softmax GetIterNext");
{fail}
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
do
{{
dtype_{x}* x_axis = (dtype_{x}*)data_ptr[0];
dtype_{sm}* sm_axis = (dtype_{sm}*)data_ptr[1];
// Find axis max
dtype_{x} max = x_axis[0];
for (npy_intp i = 1; i < axis_size; i++)
{{
dtype_{x} x_val = x_axis[i * x_axis_stride];
max = (x_val > max)? x_val : max;
}}
// Compute and accumulate exp(x-max(x)) exponent
dtype_{sm} sum_exp_dev = 0.0;
for (npy_intp i = 0; i < axis_size; i++)
{{
sm_axis[i * sm_axis_stride] = (dtype_{sm}) exp(x_axis[i * x_axis_stride] - max);
sum_exp_dev += sm_axis[i * sm_axis_stride];
}}
// Divide by sum(exp(x-max(x)))
dtype_{sm} inv_sum_exp_dev = 1.0 / sum_exp_dev;
for (npy_intp i = 0; i < axis_size; i++)
{{
sm_axis[i * sm_axis_stride] *= inv_sum_exp_dev;
}}
}} while(get_next(iter));
}}
NpyIter_Deallocate(iter);
"""
)
@staticmethod
def c_code_cache_version():
    # Version tag for Aesara's C-code cache: bump this whenever the
    # generated C implementation of this Op changes, so stale compiled
    # modules are not reused.
    return (4,)
softmax_legacy = Softmax(axis=-1) softmax_legacy = Softmax(axis=-1)
class LogSoftmax(COp):
    r"""LogSoftmax activation function.

    .. math::

        \varphi(\mathbf{x})_j =
            \mathbf{x}_j - \log \sum_{k=1}^K e^{\mathbf{x}_k}

    where :math:`K` is the total number of neurons in the layer.  The
    reduction runs along ``axis`` (or over the whole array when ``axis``
    is ``None``).
    """

    # Single-input / single-output Op.
    nin = 1
    nout = 1
    __props__ = ("axis",)

    def __init__(self, axis):
        # `axis` is either `None` (reduce over the entire array) or an int.
        if axis is not None and not isinstance(axis, int):
            raise TypeError("axis must be an integer or `None`")
        self.axis = axis

    def make_node(self, x):
        x = at.as_tensor_variable(x)
        # Reject out-of-range axes eagerly, at graph-construction time.
        if self.axis is not None and (self.axis >= x.ndim or self.axis < -x.ndim):
            raise ValueError(
                f"LogSoftmax axis(={self.axis}) out of bounds for {x.ndim}D array {x}"
            )
        # The output has the same type (and hence shape pattern) as the input.
        return Apply(self, [x], [x.type()])

    def perform(self, node, input_storage, output_storage):
        # Python fallback: delegate to SciPy's numerically stable routine.
        (x,) = input_storage
        (z,) = output_storage
        z[0] = scipy.special.log_softmax(x, axis=self.axis)

    def grad(self, inp, grads):
        # d(log_softmax(x)) applied to g is: g - sum(g, axis) * softmax(x)
        (x,) = inp
        sm = Softmax(axis=self.axis)(x)
        return [grads[0] - at_sum(grads[0], axis=self.axis, keepdims=True) * sm]

    def R_op(self, inputs, eval_points):
        # I think the Jacobian is symmetric so the R_op
        # is the same as the grad
        if None in eval_points:
            return [None]
        return self.grad(inputs, eval_points)

    def infer_shape(self, fgraph, node, shape):
        # Elementwise in shape: the output shape equals the input shape.
        return shape

    def c_headers(self, **kwargs):
        return ["<cmath>"]

    def c_code(self, node, name, inp, out, sub):
        """Generate C code computing ``x - log(sum(exp(x - max(x))))``.

        Two code paths are emitted: one that reduces over the whole
        (flattened) array when no axis is given, and one that iterates
        over all other axes and reduces along ``axis``.

        NOTE(review): in the per-axis branch the intermediate ``x - max``
        is cast to ``dtype_{x}`` while the whole-array branch casts to
        ``dtype_{sm}``; harmless here because the output is allocated
        with ``PyArray_TYPE(x)``, but worth unifying — confirm upstream.
        """
        (x,) = inp
        (sm,) = out
        # NPY_MAXDIMS encodes "no axis" (flattened reduction) on the C side.
        axis = self.axis if self.axis is not None else np.MAXDIMS
        fail = sub["fail"]
        return dedent(
            f"""
            PyArrayObject* op[2];
            npy_uint32 op_flags[2];
            npy_uint32 iter_flags;
            NpyIter* iter;
            NpyIter_IterNextFunc* get_next;
            char** data_ptr;
            int x_ndim = PyArray_NDIM({x});
            int axis = {axis};
            int iterate_axis = !(axis == NPY_MAXDIMS || x_ndim == 1);
            // Validate inputs
            if ((PyArray_TYPE({x}) != NPY_DOUBLE) &&
            (PyArray_TYPE({x}) != NPY_FLOAT))
            {{
            PyErr_SetString(PyExc_TypeError, "not a float");
            {fail}
            }}
            if (axis < 0) axis = x_ndim + axis;
            if ((axis < 0) || (iterate_axis && (axis > x_ndim)))
            {{
            PyErr_SetString(PyExc_ValueError, "invalid axis in LogSoftmax");
            {fail}
            }}
            // Allocate Output Array
            if (({sm}) == NULL || !(PyArray_CompareLists(PyArray_DIMS({sm}), PyArray_DIMS({x}), x_ndim)))
            {{
            Py_XDECREF({sm});
            {sm} = (PyArrayObject*)PyArray_SimpleNew(x_ndim, PyArray_DIMS({x}), PyArray_TYPE({x}));
            if(!{sm}) {{
            PyErr_SetString(PyExc_MemoryError, "failed to alloc LogSoftmax output");
            {fail}
            }}
            }}
            // Create numpy iterator
            op[0] = {x};
            op[1] = {sm};
            op_flags[0] = NPY_ITER_READONLY;
            op_flags[1] = NPY_ITER_READWRITE;
            iter_flags = (iterate_axis)? NPY_ITER_MULTI_INDEX : 0;
            iter = NpyIter_MultiNew(
            2,
            op,
            iter_flags,
            NPY_KEEPORDER,
            NPY_NO_CASTING,
            op_flags,
            NULL
            );
            if (iter == NULL)
            {{
            PyErr_SetString(PyExc_MemoryError, "failed to create LogSoftmax iterator");
            {fail}
            }}
            // LogSoftmax is applied across the entire array
            if (!iterate_axis)
            {{
            get_next = NpyIter_GetIterNext(iter, NULL);
            if (get_next == NULL)
            {{
            NpyIter_Deallocate(iter);
            PyErr_SetString(PyExc_RuntimeError, "Failed to obtain LogSoftmax GetIterNext");
            {fail}
            }}
            data_ptr = NpyIter_GetDataPtrArray(iter);
            // Find axis max
            dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
            dtype_{x} max = *x_ptr;
            if (get_next(iter))
            {{
            do
            {{
            dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
            max = (*x_ptr > max)? *x_ptr : max;
            }} while(get_next(iter));
            }}
            // Reset Iterator
            if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
            {{
            PyErr_SetString(PyExc_RuntimeError, "Failed to reset LogSoftmax iterator");
            {fail}
            }}
            // Compute xdev and sum(exp(xdev))
            dtype_{sm} sum_exp_xdev = 0.0;
            do
            {{
            dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
            dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
            *sm_ptr = (dtype_{sm})((*x_ptr) - max);
            sum_exp_xdev += exp(*sm_ptr);
            }} while(get_next(iter));
            // Reset Iterator
            if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
            {{
            PyErr_SetString(PyExc_RuntimeError, "Failed to reset LogSoftmax iterator");
            {fail}
            }}
            // Subtract log(sum(exp(xdev)))
            dtype_{sm} log_sum_exp_xdev = log(sum_exp_xdev);
            do
            {{
            dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
            *sm_ptr -= log_sum_exp_xdev;
            }} while(get_next(iter));
            }}
            // LogSoftmax is applied across a specific axis
            else {{
            // Collect axis strides and remove it from iteration
            npy_intp axis_size = PyArray_DIM({x}, axis);
            npy_intp* axis_stride = NpyIter_GetAxisStrideArray(iter, axis);
            if (axis_stride == NULL)
            {{
            PyErr_SetString(PyExc_RuntimeError, "Failed to obtain LogSoftmax axis strides");
            {fail}
            }}
            npy_intp x_axis_stride = axis_stride[0] / sizeof(dtype_{x});
            npy_intp sm_axis_stride = axis_stride[1] / sizeof(dtype_{sm});
            if (NpyIter_RemoveAxis(iter, axis) == NPY_FAIL)
            {{
            PyErr_SetString(PyExc_RuntimeError, "Failed to remove LogSoftmax axis from iterator");
            {fail}
            }}
            // Iterate over remaining axes
            get_next = NpyIter_GetIterNext(iter, NULL);
            if (get_next == NULL)
            {{
            NpyIter_Deallocate(iter);
            PyErr_SetString(PyExc_RuntimeError, "Failed to obtain LogSoftmax GetIterNext");
            {fail}
            }}
            data_ptr = NpyIter_GetDataPtrArray(iter);
            do
            {{
            dtype_{x}* x_axis = (dtype_{x}*)data_ptr[0];
            dtype_{sm}* sm_axis = (dtype_{sm}*)data_ptr[1];
            // Find axis max
            dtype_{x} max = x_axis[0];
            for (npy_intp i = 1; i < axis_size; i++)
            {{
            dtype_{x} x_val = x_axis[i * x_axis_stride];
            max = (x_val > max)? x_val : max;
            }}
            // Compute xdev and sum(exp(xdev))
            dtype_{sm} sum_exp_xdev = 0.0;
            for (npy_intp i = 0; i < axis_size; i++)
            {{
            sm_axis[i * sm_axis_stride] = (dtype_{x})(x_axis[i * x_axis_stride] - max);
            sum_exp_xdev += exp(sm_axis[i * sm_axis_stride]);
            }}
            // Subtract log(sum(exp(xdev))
            dtype_{sm} log_sum_exp_xdev = log(sum_exp_xdev);
            for (npy_intp i = 0; i < axis_size; i++)
            {{
            sm_axis[i * sm_axis_stride] -= log_sum_exp_xdev;
            }}
            }} while(get_next(iter));
            }}
            NpyIter_Deallocate(iter);
            """
        )

    @staticmethod
    def c_code_cache_version():
        # Bump whenever the generated C code above changes.
        return (1,)
# This is not registered in stabilize, as it cause some crossentropy # This is not registered in stabilize, as it cause some crossentropy
# optimization to not be inserted. # optimization to not be inserted.
@register_specialize("stabilize", "fast_compile") @register_specialize("stabilize", "fast_compile")
...@@ -1108,47 +392,6 @@ def local_logsoftmax_grad(fgraph, node): ...@@ -1108,47 +392,6 @@ def local_logsoftmax_grad(fgraph, node):
return [ret] return [ret]
# Sentinel distinguishing "caller did not pass an axis" from `axis=None`.
UNSET_AXIS = object()


def softmax(c, axis=UNSET_AXIS):
    """Build a `Softmax` graph for `c` along `axis`.

    Omitting `axis` currently falls back to ``-1`` with a
    `FutureWarning`; eventually the default will become ``None``.
    """
    if axis is UNSET_AXIS:
        warnings.warn(
            "Softmax now accepts an axis argument. For backwards-compatibility it defaults to -1 when not specified, "
            "but in the future the default will be `None`.\nTo suppress this warning specify axis explicitly.",
            FutureWarning,
        )
        axis = -1

    tensor_input = as_tensor_variable(c)
    if tensor_input.ndim == 1:
        # TODO: Create Specific warning type that can be suppressed?
        warnings.warn(
            "Softmax no longer converts a vector to a row matrix.",
            UserWarning,
        )
    return Softmax(axis=axis)(tensor_input)
def logsoftmax(c, axis=UNSET_AXIS):
    """Build a `LogSoftmax` graph for `c` along `axis`.

    Omitting `axis` currently falls back to ``-1`` with a
    `FutureWarning`; eventually the default will become ``None``.
    """
    if axis is UNSET_AXIS:
        warnings.warn(
            "logsoftmax now accepts an axis argument. For backwards-compatibility it defaults to -1 when not specified, "
            "but in the future the default will be `None`.\nTo suppress this warning specify axis explicitly.",
            FutureWarning,
        )
        axis = -1

    tensor_input = as_tensor_variable(c)
    if tensor_input.ndim == 1:
        # TODO: Create Specific warning type that can be suppressed?
        warnings.warn(
            "Softmax no longer converts a vector to a row matrix.",
            UserWarning,
        )
    return LogSoftmax(axis=axis)(tensor_input)
@register_specialize("fast_compile") @register_specialize("fast_compile")
@node_rewriter([softmax_legacy]) @node_rewriter([softmax_legacy])
def local_softmax_with_bias(fgraph, node): def local_softmax_with_bias(fgraph, node):
...@@ -2963,3 +2206,33 @@ def confusion_matrix(actual, pred): ...@@ -2963,3 +2206,33 @@ def confusion_matrix(actual, pred):
conf_mat = dot(oneHotA.T, oneHotP) conf_mat = dot(oneHotA.T, oneHotP)
return [conf_mat, order] return [conf_mat, order]
# ``(old_name, deprecation_message, replacement_object)`` triples.  The
# module-level ``__getattr__`` walks this list so that the old
# ``aesara.tensor.nnet.basic`` names keep working with a `DeprecationWarning`.
DEPRECATED_NAMES = [
    (
        "softmax",
        "`aesara.tensor.nnet.basic.softmax` has been moved to `aesara.tensor.math.softmax`.",
        softmax,
    ),
    (
        "logsoftmax",
        "`aesara.tensor.nnet.basic.logsoftmax` has been moved to `aesara.tensor.math.logsoftmax`.",
        log_softmax,
    ),
]
def __getattr__(name):
    """Intercept module-level attribute access of deprecated symbols.

    Adapted from https://stackoverflow.com/a/55139609/3006474.
    """
    from warnings import warn

    # First (and only) entry whose old name matches, else None.
    entry = next((item for item in DEPRECATED_NAMES if item[0] == name), None)
    if entry is None:
        raise AttributeError(f"module {__name__} has no attribute {name}")

    _, msg, old_object = entry
    warn(msg, DeprecationWarning, stacklevel=2)
    return old_object
...@@ -24,7 +24,7 @@ class Mlp: ...@@ -24,7 +24,7 @@ class Mlp:
wy = shared(self.rng.normal(0, 1, (nhiddens, noutputs))) wy = shared(self.rng.normal(0, 1, (nhiddens, noutputs)))
by = shared(np.zeros(noutputs), borrow=True) by = shared(np.zeros(noutputs), borrow=True)
y = at.nnet.softmax(at.dot(h, wy) + by) y = at.softmax(at.dot(h, wy) + by)
self.inputs = [x] self.inputs = [x]
self.outputs = [y] self.outputs = [y]
......
...@@ -6,10 +6,10 @@ from aesara.graph.fg import FunctionGraph ...@@ -6,10 +6,10 @@ from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value from aesara.graph.op import get_test_value
from aesara.tensor import elemwise as at_elemwise from aesara.tensor import elemwise as at_elemwise
from aesara.tensor import nnet as at_nnet from aesara.tensor import nnet as at_nnet
from aesara.tensor.math import SoftmaxGrad
from aesara.tensor.math import all as at_all from aesara.tensor.math import all as at_all
from aesara.tensor.math import prod from aesara.tensor.math import prod
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.nnet.basic import SoftmaxGrad
from aesara.tensor.type import matrix, tensor, vector from aesara.tensor.type import matrix, tensor, vector
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
......
...@@ -6,14 +6,25 @@ import pytest ...@@ -6,14 +6,25 @@ import pytest
import aesara.tensor as at import aesara.tensor as at
import aesara.tensor.inplace as ati import aesara.tensor.inplace as ati
import aesara.tensor.math as aem import aesara.tensor.math as aem
import aesara.tensor.nnet.basic as nnetb
from aesara import config from aesara import config
from aesara.compile.ops import deep_copy_op from aesara.compile.ops import deep_copy_op
from aesara.compile.sharedvalue import SharedVariable from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.basic import Constant from aesara.graph.basic import Constant
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.tensor import elemwise as at_elemwise from aesara.tensor import elemwise as at_elemwise
from aesara.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum from aesara.tensor.math import (
All,
Any,
LogSoftmax,
Max,
Mean,
Min,
Prod,
ProdWithoutZeros,
Softmax,
SoftmaxGrad,
Sum,
)
from tests.link.numba.test_basic import ( from tests.link.numba.test_basic import (
compare_numba_and_py, compare_numba_and_py,
my_multi_out, my_multi_out,
...@@ -377,7 +388,7 @@ def test_scalar_Elemwise_Clip(): ...@@ -377,7 +388,7 @@ def test_scalar_Elemwise_Clip():
], ],
) )
def test_SoftmaxGrad(dy, sm, axis, exc): def test_SoftmaxGrad(dy, sm, axis, exc):
g = nnetb.SoftmaxGrad(axis=axis)(dy, sm) g = SoftmaxGrad(axis=axis)(dy, sm)
g_fg = FunctionGraph(outputs=[g]) g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
...@@ -413,7 +424,7 @@ def test_SoftmaxGrad(dy, sm, axis, exc): ...@@ -413,7 +424,7 @@ def test_SoftmaxGrad(dy, sm, axis, exc):
], ],
) )
def test_Softmax(x, axis, exc): def test_Softmax(x, axis, exc):
g = nnetb.Softmax(axis=axis)(x) g = Softmax(axis=axis)(x)
g_fg = FunctionGraph(outputs=[g]) g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
...@@ -449,7 +460,7 @@ def test_Softmax(x, axis, exc): ...@@ -449,7 +460,7 @@ def test_Softmax(x, axis, exc):
], ],
) )
def test_LogSoftmax(x, axis, exc): def test_LogSoftmax(x, axis, exc):
g = nnetb.LogSoftmax(axis=axis)(x) g = LogSoftmax(axis=axis)(x)
g_fg = FunctionGraph(outputs=[g]) g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
......
...@@ -40,7 +40,7 @@ from aesara.scan.basic import scan ...@@ -40,7 +40,7 @@ from aesara.scan.basic import scan
from aesara.scan.op import Scan from aesara.scan.op import Scan
from aesara.scan.utils import until from aesara.scan.utils import until
from aesara.tensor.math import all as at_all from aesara.tensor.math import all as at_all
from aesara.tensor.math import dot, mean, sigmoid from aesara.tensor.math import dot, exp, mean, sigmoid
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tanh from aesara.tensor.math import tanh
from aesara.tensor.nnet import categorical_crossentropy from aesara.tensor.nnet import categorical_crossentropy
...@@ -69,7 +69,6 @@ from aesara.tensor.type import ( ...@@ -69,7 +69,6 @@ from aesara.tensor.type import (
vector, vector,
) )
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.tensor.nnet.test_basic import softmax_graph
if config.mode == "FAST_COMPILE": if config.mode == "FAST_COMPILE":
...@@ -85,6 +84,10 @@ else: ...@@ -85,6 +84,10 @@ else:
type_eps = {"float64": 1e-7, "float32": 3e-3} type_eps = {"float64": 1e-7, "float32": 3e-3}
def softmax_graph(c):
    """Reference softmax built from primitive ops (used for gradient checks).

    Computes ``exp(c) / sum(exp(c), axis=-1, keepdims=True)``.  The
    exponential is computed once and reused instead of being rebuilt for
    both the numerator and the denominator.
    """
    e = exp(c)
    return e / e.sum(axis=-1, keepdims=True)
class multiple_outputs_numeric_grad: class multiple_outputs_numeric_grad:
"""WRITEME""" """WRITEME"""
......
...@@ -24,13 +24,12 @@ from aesara.tensor.math import ( ...@@ -24,13 +24,12 @@ from aesara.tensor.math import (
sigmoid, sigmoid,
) )
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tanh, true_div from aesara.tensor.math import tanh
from aesara.tensor.nnet.basic import ( from aesara.tensor.nnet.basic import (
CrossentropyCategorical1Hot, CrossentropyCategorical1Hot,
CrossentropyCategorical1HotGrad, CrossentropyCategorical1HotGrad,
CrossentropySoftmax1HotWithBiasDx, CrossentropySoftmax1HotWithBiasDx,
CrossentropySoftmaxArgmax1HotWithBias, CrossentropySoftmaxArgmax1HotWithBias,
LogSoftmax,
Prepend_scalar_constant_to_each_row, Prepend_scalar_constant_to_each_row,
Prepend_scalar_to_each_row, Prepend_scalar_to_each_row,
Softmax, Softmax,
...@@ -46,7 +45,6 @@ from aesara.tensor.nnet.basic import ( ...@@ -46,7 +45,6 @@ from aesara.tensor.nnet.basic import (
crossentropy_softmax_argmax_1hot_with_bias, crossentropy_softmax_argmax_1hot_with_bias,
elu, elu,
h_softmax, h_softmax,
logsoftmax,
relu, relu,
selu, selu,
sigmoid_binary_crossentropy, sigmoid_binary_crossentropy,
...@@ -65,7 +63,6 @@ from aesara.tensor.type import ( ...@@ -65,7 +63,6 @@ from aesara.tensor.type import (
fvector, fvector,
ivector, ivector,
lvector, lvector,
matrices,
matrix, matrix,
scalar, scalar,
tensor3, tensor3,
...@@ -104,52 +101,6 @@ def valid_axis_tester(Op): ...@@ -104,52 +101,6 @@ def valid_axis_tester(Op):
Op(-4)(*x) Op(-4)(*x)
class TestSoftmax(utt.InferShapeTester):
    """Tests for the `Softmax` `Op` and the `softmax` helper."""

    @pytest.mark.parametrize("axis", [None, 0, 1, 2, 3, -1, -2])
    def test_perform(self, axis):
        # Forward pass must match `scipy.special.softmax` for every axis.
        x = tensor4("x")
        rng = np.random.default_rng(utt.fetch_seed())
        xv = rng.standard_normal((2, 3, 4, 5)).astype(config.floatX)

        f = aesara.function([x], softmax(x, axis=axis))
        assert np.allclose(f(xv), sp.softmax(xv, axis=axis))

    @pytest.mark.parametrize("column", [0, 1, 2, 3])
    @pytest.mark.parametrize("axis", [None, 0, 1])
    def test_grad(self, axis, column):
        # Numerically verify the gradient of a single output column.
        def f(a):
            return softmax(a, axis=axis)[:, column]

        rng = np.random.default_rng(utt.fetch_seed())
        utt.verify_grad(f, [rng.random((3, 4, 2))])

    def test_infer_shape(self):
        # Softmax output shape equals its input shape.
        admat = matrix()
        rng = np.random.default_rng(utt.fetch_seed())
        admat_val = rng.random((3, 4)).astype(config.floatX)
        self._compile_and_check(
            [admat], [Softmax(axis=-1)(admat)], [admat_val], Softmax
        )

    def test_vector_perform(self):
        # `axis=None` reduces over the whole (1-D) input.
        x = vector()
        f = aesara.function([x], softmax(x, axis=None))

        rng = np.random.default_rng(utt.fetch_seed())
        xv = rng.standard_normal((6,)).astype(config.floatX)
        assert np.allclose(f(xv), sp.softmax(xv))

    def test_vector_grad(self):
        def f(a):
            return softmax(a, axis=None)

        rng = np.random.default_rng(utt.fetch_seed())
        utt.verify_grad(f, [rng.random((4))])

    def test_valid_axis(self):
        # Shared validity checks for the `axis` argument.
        valid_axis_tester(Softmax)
class TestSoftmaxWithBias(utt.InferShapeTester): class TestSoftmaxWithBias(utt.InferShapeTester):
def test_basic(self): def test_basic(self):
def f(a, b): def f(a, b):
...@@ -217,160 +168,6 @@ class TestSoftmaxWithBias(utt.InferShapeTester): ...@@ -217,160 +168,6 @@ class TestSoftmaxWithBias(utt.InferShapeTester):
) )
class TestLogSoftmax(utt.InferShapeTester):
    """Tests for `LogSoftmax`, the `logsoftmax` helper and the
    ``log(softmax(x))`` stabilizing rewrites."""

    @pytest.mark.parametrize("column", [0, 1, 2, 3])
    @pytest.mark.parametrize("axis", [None, 0, 1])
    def test_matrix_grad(self, axis, column):
        # Numerically verify the gradient of a single output column.
        def f(a):
            return logsoftmax(a, axis=axis)[:, column]

        rng = np.random.default_rng(utt.fetch_seed())
        utt.verify_grad(f, [rng.random((3, 4))])

    def test_vector_perform(self):
        # `axis=None` reduces over the whole (1-D) input; compare
        # against `scipy.special.log_softmax`.
        x = vector()
        f = aesara.function([x], logsoftmax(x, axis=None))

        rng = np.random.default_rng(utt.fetch_seed())
        xv = rng.standard_normal((6,)).astype(config.floatX)
        assert np.allclose(f(xv), sp.log_softmax(xv))

    def test_vector_grad(self):
        def f(a):
            return logsoftmax(a, axis=None)

        rng = np.random.default_rng(utt.fetch_seed())
        utt.verify_grad(f, [rng.random((4,))])

    def test_matrix_perform_and_rewrite(self):
        # Demonstrate that the explicit log-softmax formulation is
        # numerically stable where log(softmax(x)) would over/underflow.
        m = config.mode
        m = aesara.compile.get_mode(m)
        m.check_isfinite = False
        x, y = matrices("xy")
        # regular softmax and crossentropy
        sm = softmax(x)
        cm = categorical_crossentropy(sm, y)

        # numerically stable log-softmax with crossentropy
        logsm = logsoftmax(x)
        sm2 = exp(logsm)  # just used to show equivalence with sm
        cm2 = -at_sum(y * logsm, axis=1)
        grad_node = grad(cm2.mean(), x)

        # inputs large enough to stress numerical stability
        rng = np.random.default_rng(utt.fetch_seed())
        a = np.exp(10 * rng.random((5, 10)).astype(config.floatX))
        b = np.eye(5, 10).astype(config.floatX)

        # show equivalence of softmax and exponentiated numerically stable
        # log-softmax
        f1 = aesara.function([x], [sm, sm2])
        sm_, sm2_ = f1(a)
        utt.assert_allclose(sm_, sm2_)

        # now show that the two versions result in the same crossentropy cost
        # this indicates that the forward function does provide some numerical
        # stability
        f2 = aesara.function([x, y], [cm, cm2], mode=m)
        cm_, cm2_ = f2(a, b)
        utt.assert_allclose(cm_, cm2_)

        # now, show that in the standard softmax case the gradients blow up
        # while in the log-softmax case they don't
        f3 = aesara.function([x, y], [grad_node])
        grad_ = f3(a, b)
        assert not np.any(np.isnan(grad_))

    @pytest.mark.parametrize("axis", [None, 0, -1])
    def test_local_logsoftmax_rewrite(self, axis):
        """Test the `Logsoftmax` substitution.

        Check that ``Log(Softmax(x))`` is substituted with ``Logsoftmax(x)``. Note that
        only the forward pass is checked (i.e., doesn't check the gradient)
        """
        x = matrix("x")
        sm = softmax(x, axis=axis)
        logsm = log(sm)
        f = aesara.function([x], logsm)
        # After compilation the graph output must be a `LogSoftmax` node.
        assert isinstance(f.maker.fgraph.outputs[0].owner.op, LogSoftmax)
        assert check_stack_trace(f, ops_to_check=LogSoftmax)

    @pytest.mark.parametrize("axis", [None, 0, -1])
    def test_local_logsoftmax_grad_rewrite(self, axis):
        """Test the `Logsoftmax`'s grad substitution.

        Check that ``Log(Softmax(x))``'s grad is substituted with ``Logsoftmax(x)``'s
        grad and that the new operation does not explode for big inputs.
        Note that only the grad is checked.
        """
        m = config.mode
        m = aesara.compile.get_mode(m)
        m.check_isfinite = False
        # some inputs that are large to make the gradient explode in the non
        # rewritten case
        rng = np.random.default_rng(utt.fetch_seed())
        a = np.exp(10 * rng.random((5, 10)).astype(config.floatX))

        def myfunc(x):
            sm = softmax(x, axis=axis)
            logsm = log(sm)
            return logsm

        # We set step to 0.1 because for big values we need a big epsilon
        utt.verify_grad(myfunc, [a], eps=0.1, mode=m)
        sa = aesara.shared(a)
        f = aesara.function([], myfunc(sa))
        assert check_stack_trace(f, ops_to_check="all")

    def test_logsoftmax_grad_true_div_elemwise(self):
        """
        Checks that the gradient of an expression similar to a ``log(softmax)`` but
        with a different elemwise operation than true_div is not rewritten.
        """
        x = matrix("x")
        y = log(softmax(x))
        g = grad(y.sum(), x)

        softmax_grad_node = g.owner
        assert softmax_grad_node.op == softmax_grad_legacy
        true_div_node = softmax_grad_node.inputs[0].owner
        assert true_div_node.op == true_div

        # We replace the elemwise true_div op by an elemwise add.
        new_g = softmax_grad_legacy(
            add(*true_div_node.inputs), softmax_grad_node.inputs[1]
        )

        fgraph = FunctionGraph([x], [new_g])
        optdb.query(OPT_FAST_RUN).rewrite(fgraph)

        # The grad rewrite must NOT have fired: the legacy grad op remains.
        assert softmax_grad_legacy in [n.op for n in fgraph.toposort()]

    def test_valid_axis(self):
        # Shared validity checks for the `axis` argument.
        valid_axis_tester(LogSoftmax)
class TestSoftmaxGrad(utt.InferShapeTester):
    """Tests for the `SoftmaxGrad` `Op`."""

    def test_infer_shape(self):
        # The gradient has the same shape as its two (matrix) inputs.
        admat = matrix()
        bdmat = matrix()
        rng = np.random.default_rng(utt.fetch_seed())
        admat_val = rng.random((3, 4)).astype(config.floatX)
        bdmat_val = rng.random((3, 4)).astype(config.floatX)
        self._compile_and_check(
            [admat, bdmat],
            [SoftmaxGrad(axis=-1)(admat, bdmat)],
            [admat_val, bdmat_val],
            SoftmaxGrad,
        )

    def test_valid_axis(self):
        # Shared validity checks for the `axis` argument.
        valid_axis_tester(SoftmaxGrad)
class TestCrossEntropySoftmax1Hot: class TestCrossEntropySoftmax1Hot:
def test_basic(self): def test_basic(self):
y_idx = [0, 1, 3] y_idx = [0, 1, 3]
...@@ -1202,27 +999,6 @@ def test_grad_softmax_grad(): ...@@ -1202,27 +999,6 @@ def test_grad_softmax_grad():
utt.verify_grad(f, [rng.random((3, 4))]) utt.verify_grad(f, [rng.random((3, 4))])
def test_stabilize_log_softmax():
    """``log(softmax(x))`` should be rewritten into the stable `LogSoftmax`."""
    mode = aesara.compile.mode.get_default_mode()
    mode = mode.including("local_log_softmax", "specialize")

    x = matrix()
    y = softmax(x)
    z = log(y)

    f = aesara.function([x], z, mode=mode)
    assert check_stack_trace(f, ops_to_check="all")

    # Check that the softmax has been rewritten
    for node in f.maker.fgraph.toposort():
        assert not isinstance(node.op, y.owner.op.__class__)

    # Call the function so debug mode can verify the rewritten version matches
    # the un-rewritten version.
    rng = np.random.default_rng(utt.fetch_seed())
    # `np.cast` was deprecated and removed in NumPy 2.0; use `astype` instead.
    f(rng.random((2, 3)).astype(config.floatX))
def test_relu(): def test_relu():
x = matrix("x") x = matrix("x")
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
......
...@@ -13,7 +13,7 @@ from aesara import pprint, shared ...@@ -13,7 +13,7 @@ from aesara import pprint, shared
from aesara.compile import optdb from aesara.compile import optdb
from aesara.compile.debugmode import DebugMode from aesara.compile.debugmode import DebugMode
from aesara.compile.function import function from aesara.compile.function import function
from aesara.compile.mode import Mode, get_default_mode, get_mode from aesara.compile.mode import OPT_FAST_RUN, Mode, get_default_mode, get_mode
from aesara.compile.ops import DeepCopyOp, deep_copy_op from aesara.compile.ops import DeepCopyOp, deep_copy_op
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, equal_computations from aesara.graph.basic import Apply, Constant, equal_computations
...@@ -33,7 +33,15 @@ from aesara.tensor.basic import Alloc, join, switch ...@@ -33,7 +33,15 @@ from aesara.tensor.basic import Alloc, join, switch
from aesara.tensor.blas import Dot22, Gemv from aesara.tensor.blas import Dot22, Gemv
from aesara.tensor.blas_c import CGemv from aesara.tensor.blas_c import CGemv
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import Dot, MaxAndArgmax, Prod, Sum, _conj from aesara.tensor.math import (
Dot,
LogSoftmax,
MaxAndArgmax,
Prod,
SoftmaxGrad,
Sum,
_conj,
)
from aesara.tensor.math import abs as at_abs from aesara.tensor.math import abs as at_abs
from aesara.tensor.math import add from aesara.tensor.math import add
from aesara.tensor.math import all as at_all from aesara.tensor.math import all as at_all
...@@ -76,7 +84,17 @@ from aesara.tensor.math import minimum, mul, neg, neq ...@@ -76,7 +84,17 @@ from aesara.tensor.math import minimum, mul, neg, neq
from aesara.tensor.math import pow as at_pow from aesara.tensor.math import pow as at_pow
from aesara.tensor.math import prod, rad2deg, reciprocal from aesara.tensor.math import prod, rad2deg, reciprocal
from aesara.tensor.math import round as at_round from aesara.tensor.math import round as at_round
from aesara.tensor.math import sgn, sigmoid, sin, sinh, softplus, sqr, sqrt, sub from aesara.tensor.math import (
sgn,
sigmoid,
sin,
sinh,
softmax,
softplus,
sqr,
sqrt,
sub,
)
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tan, tanh, true_div, xor from aesara.tensor.math import tan, tanh, true_div, xor
from aesara.tensor.rewriting.elemwise import local_dimshuffle_lift from aesara.tensor.rewriting.elemwise import local_dimshuffle_lift
...@@ -4578,6 +4596,76 @@ class TestSigmoidUtils: ...@@ -4578,6 +4596,76 @@ class TestSigmoidUtils:
assert is_1pexp(1 + 2 * exp_op(x), False) is None assert is_1pexp(1 + 2 * exp_op(x), False) is None
class TestLogSoftmaxRewrites:
    """Rewrite tests for the ``log(softmax(x))`` -> `LogSoftmax` substitutions."""

    @pytest.mark.parametrize("axis", [None, 0, -1])
    def test_local_logsoftmax_rewrite(self, axis):
        """Test the `Logsoftmax` substitution.

        Check that ``Log(Softmax(x))`` is substituted with ``Logsoftmax(x)``. Note that
        only the forward pass is checked (i.e., doesn't check the gradient)
        """

        x = matrix("x")
        sm = softmax(x, axis=axis)
        logsm = log(sm)
        f = function([x], logsm)
        # After compilation the graph output must be a `LogSoftmax` node.
        assert isinstance(f.maker.fgraph.outputs[0].owner.op, LogSoftmax)
        assert check_stack_trace(f, ops_to_check=LogSoftmax)

    @pytest.mark.parametrize("axis", [None, 0, -1])
    def test_local_logsoftmax_grad_rewrite(self, axis):
        """Test the `Logsoftmax`'s grad substitution.

        Check that ``Log(Softmax(x))``'s grad is substituted with ``Logsoftmax(x)``'s
        grad and that the new operation does not explode for big inputs.
        Note that only the grad is checked.
        """

        m = config.mode
        m = get_mode(m)
        m.check_isfinite = False
        # some inputs that are large to make the gradient explode in the non
        # rewritten case
        rng = np.random.default_rng(utt.fetch_seed())
        a = np.exp(10 * rng.random((5, 10)).astype(config.floatX))

        def myfunc(x):
            sm = softmax(x, axis=axis)
            logsm = log(sm)
            return logsm

        # We set step to 0.1 because for big values we need a big epsilon
        utt.verify_grad(myfunc, [a], eps=0.1, mode=m)
        sa = shared(a)
        f = function([], myfunc(sa))
        assert check_stack_trace(f, ops_to_check="all")

    def test_logsoftmax_grad_true_div_elemwise(self):
        """
        Checks that the gradient of an expression similar to a ``log(softmax)`` but
        with a different elemwise operation than true_div is not rewritten.
        """

        x = matrix("x")
        y = log(softmax(x))
        g = aesara.tensor.grad(y.sum(), x)

        softmax_grad_node = g.owner
        assert softmax_grad_node.op == SoftmaxGrad(axis=-1)
        true_div_node = softmax_grad_node.inputs[0].owner
        assert true_div_node.op == true_div

        # We replace the elemwise true_div op by an elemwise add.
        new_g = SoftmaxGrad(axis=-1)(
            add(*true_div_node.inputs), softmax_grad_node.inputs[1]
        )

        fgraph = FunctionGraph([x], [new_g])
        optdb.query(OPT_FAST_RUN).rewrite(fgraph)

        # The grad rewrite must NOT have fired: `SoftmaxGrad` remains.
        assert SoftmaxGrad(axis=-1) in [n.op for n in fgraph.toposort()]
def test_log1mexp_stabilization(): def test_log1mexp_stabilization():
mode = Mode("py").including("stabilize") mode = Mode("py").including("stabilize")
...@@ -4639,3 +4727,24 @@ def test_deprecations(): ...@@ -4639,3 +4727,24 @@ def test_deprecations():
"""Make sure we can import from deprecated modules.""" """Make sure we can import from deprecated modules."""
with pytest.deprecated_call(): with pytest.deprecated_call():
from aesara.tensor.math_opt import AlgebraicCanonizer # noqa: F401 F811 from aesara.tensor.math_opt import AlgebraicCanonizer # noqa: F401 F811
def test_log_softmax_stabilization():
    """``log(softmax(x))`` should be rewritten into the stable `LogSoftmax`."""
    mode = aesara.compile.mode.get_default_mode()
    mode = mode.including("local_log_softmax", "specialize")

    x = matrix()
    y = softmax(x)
    z = log(y)

    f = aesara.function([x], z, mode=mode)
    assert check_stack_trace(f, ops_to_check="all")

    # Check that the softmax has been rewritten
    for node in f.maker.fgraph.toposort():
        assert not isinstance(node.op, y.owner.op.__class__)

    # Call the function so debug mode can verify the rewritten version matches
    # the un-rewritten version.
    rng = np.random.default_rng(utt.fetch_seed())
    # `np.cast` was deprecated and removed in NumPy 2.0; use `astype` instead.
    f(rng.random((2, 3)).astype(config.floatX))
...@@ -8,7 +8,9 @@ from itertools import product ...@@ -8,7 +8,9 @@ from itertools import product
import numpy as np import numpy as np
import pytest import pytest
from numpy.testing import assert_array_equal from numpy.testing import assert_array_equal
from scipy.special import log_softmax as scipy_log_softmax
from scipy.special import logsumexp as scipy_logsumexp from scipy.special import logsumexp as scipy_logsumexp
from scipy.special import softmax as scipy_softmax
import aesara.scalar as aes import aesara.scalar as aes
from aesara.compile.debugmode import DebugMode from aesara.compile.debugmode import DebugMode
...@@ -34,11 +36,14 @@ from aesara.tensor.elemwise import CAReduce, Elemwise ...@@ -34,11 +36,14 @@ from aesara.tensor.elemwise import CAReduce, Elemwise
from aesara.tensor.math import ( from aesara.tensor.math import (
Argmax, Argmax,
Dot, Dot,
LogSoftmax,
MatMul, MatMul,
MaxAndArgmax, MaxAndArgmax,
Mean, Mean,
Prod, Prod,
ProdWithoutZeros, ProdWithoutZeros,
Softmax,
SoftmaxGrad,
Sum, Sum,
_allclose, _allclose,
_dot, _dot,
...@@ -79,6 +84,7 @@ from aesara.tensor.math import ( ...@@ -79,6 +84,7 @@ from aesara.tensor.math import (
log1p, log1p,
log2, log2,
log10, log10,
log_softmax,
logaddexp, logaddexp,
logsumexp, logsumexp,
matmul, matmul,
...@@ -104,6 +110,7 @@ from aesara.tensor.math import ( ...@@ -104,6 +110,7 @@ from aesara.tensor.math import (
sin, sin,
sinh, sinh,
smallest, smallest,
softmax,
sqr, sqr,
sqrt, sqrt,
sub, sub,
...@@ -3521,3 +3528,129 @@ class TestMatMul(utt.InferShapeTester): ...@@ -3521,3 +3528,129 @@ class TestMatMul(utt.InferShapeTester):
[x1, x2], [x1, x2],
self.op_class, self.op_class,
) )
class TestSoftmax(utt.InferShapeTester):
    """Tests for the `Softmax` `Op` and the `softmax` helper function."""

    @pytest.mark.parametrize("axis", [None, 0, 1, 2, 3, -1, -2])
    def test_perform(self, axis):
        # Compare the compiled graph against SciPy's reference
        # implementation on a 4D input, for every valid axis.
        x = tensor4("x")
        rng = np.random.default_rng(utt.fetch_seed())
        xv = rng.standard_normal((2, 3, 4, 5)).astype(config.floatX)

        f = function([x], softmax(x, axis=axis))
        assert np.allclose(f(xv), scipy_softmax(xv, axis=axis))

    @pytest.mark.parametrize("column", [0, 1, 2, 3])
    @pytest.mark.parametrize("axis", [None, 0, 1])
    def test_grad(self, axis, column):
        # Numerically verify the gradient through a single output column
        # so each column's contribution is exercised individually.
        def f(a):
            return softmax(a, axis=axis)[:, column]

        rng = np.random.default_rng(utt.fetch_seed())
        utt.verify_grad(f, [rng.random((3, 4, 2))])

    def test_infer_shape(self):
        # Shape inference must agree with the run-time output shape.
        admat = matrix()
        rng = np.random.default_rng(utt.fetch_seed())
        admat_val = rng.random((3, 4)).astype(config.floatX)
        self._compile_and_check(
            [admat], [Softmax(axis=-1)(admat)], [admat_val], Softmax
        )

    def test_vector_perform(self):
        # ``axis=None`` normalizes the whole (flattened) input, so a
        # vector is softmax-ed as a single distribution.
        x = vector()
        f = function([x], softmax(x, axis=None))

        rng = np.random.default_rng(utt.fetch_seed())
        xv = rng.standard_normal((6,)).astype(config.floatX)
        assert np.allclose(f(xv), scipy_softmax(xv))

    def test_vector_grad(self):
        def f(a):
            return softmax(a, axis=None)

        rng = np.random.default_rng(utt.fetch_seed())
        # NOTE: was ``rng.random((4))`` which passes the int 4, not a
        # shape tuple; ``(4,)`` is equivalent and matches the other tests.
        utt.verify_grad(f, [rng.random((4,))])

    def test_valid_axis(self):
        # A non-integer axis is rejected at `Op` construction time.
        with pytest.raises(TypeError):
            Softmax(1.5)

        # Was ``LogSoftmax.nin`` — use the class under test; both are the
        # same value, but the reference was a copy-paste inconsistency.
        x = [tensor3()] * Softmax.nin

        # For a 3D input only axes in ``[-3, 3)`` are valid.
        Softmax(2)(*x)
        Softmax(-3)(*x)

        with pytest.raises(ValueError):
            Softmax(3)(*x)
        with pytest.raises(ValueError):
            Softmax(-4)(*x)
class TestLogSoftmax(utt.InferShapeTester):
    """Tests for the `LogSoftmax` `Op` and the `log_softmax` helper."""

    @pytest.mark.parametrize("column", [0, 1, 2, 3])
    @pytest.mark.parametrize("axis", [None, 0, 1])
    def test_matrix_grad(self, axis, column):
        # Numerically verify the gradient through one output column.
        def fn(inp):
            return log_softmax(inp, axis=axis)[:, column]

        rng = np.random.default_rng(utt.fetch_seed())
        utt.verify_grad(fn, [rng.random((3, 4))])

    def test_vector_perform(self):
        # With ``axis=None`` a vector is normalized as a whole; the result
        # must match SciPy's reference implementation.
        inp = vector()
        log_softmax_fn = function([inp], log_softmax(inp, axis=None))

        rng = np.random.default_rng(utt.fetch_seed())
        test_val = rng.standard_normal((6,)).astype(config.floatX)
        assert np.allclose(log_softmax_fn(test_val), scipy_log_softmax(test_val))

    def test_vector_grad(self):
        def fn(inp):
            return log_softmax(inp, axis=None)

        rng = np.random.default_rng(utt.fetch_seed())
        utt.verify_grad(fn, [rng.random((4,))])

    def test_valid_axis(self):
        # Non-integer axes raise `TypeError` at construction time.
        with pytest.raises(TypeError):
            LogSoftmax(1.5)

        inputs = [tensor3()] * LogSoftmax.nin

        # For a 3D input only axes in ``[-3, 3)`` are accepted.
        LogSoftmax(2)(*inputs)
        LogSoftmax(-3)(*inputs)

        with pytest.raises(ValueError):
            LogSoftmax(3)(*inputs)
        with pytest.raises(ValueError):
            LogSoftmax(-4)(*inputs)
class TestSoftmaxGrad(utt.InferShapeTester):
    """Tests for the `SoftmaxGrad` `Op`."""

    def test_infer_shape(self):
        # Shape inference for the two-input gradient `Op` must agree with
        # the shapes produced at run time.
        mat_a = matrix()
        mat_b = matrix()
        rng = np.random.default_rng(utt.fetch_seed())
        mat_a_val = rng.random((3, 4)).astype(config.floatX)
        mat_b_val = rng.random((3, 4)).astype(config.floatX)
        self._compile_and_check(
            [mat_a, mat_b],
            [SoftmaxGrad(axis=-1)(mat_a, mat_b)],
            [mat_a_val, mat_b_val],
            SoftmaxGrad,
        )

    def test_valid_axis(self):
        # Non-integer axes are rejected when the `Op` is built.
        with pytest.raises(TypeError):
            SoftmaxGrad(1.5)

        inputs = [tensor3()] * SoftmaxGrad.nin

        # For 3D inputs only axes in ``[-3, 3)`` are accepted.
        SoftmaxGrad(2)(*inputs)
        SoftmaxGrad(-3)(*inputs)

        with pytest.raises(ValueError):
            SoftmaxGrad(3)(*inputs)
        with pytest.raises(ValueError):
            SoftmaxGrad(-4)(*inputs)
...@@ -385,7 +385,7 @@ class TestRopLop(RopLopChecker): ...@@ -385,7 +385,7 @@ class TestRopLop(RopLopChecker):
self.check_mat_rop_lop(self.mx.sum(axis=1), (self.mat_in_shape[0],)) self.check_mat_rop_lop(self.mx.sum(axis=1), (self.mat_in_shape[0],))
def test_softmax(self): def test_softmax(self):
self.check_rop_lop(aesara.tensor.nnet.softmax(self.x), self.in_shape) self.check_rop_lop(aesara.tensor.math.softmax(self.x), self.in_shape)
def test_alloc(self): def test_alloc(self):
# Alloc of the sum of x into a vector # Alloc of the sum of x into a vector
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论