Commit 58cb5c30 authored by Ricardo, committed by Brandon T. Willard

Add axis to Softmax and SoftmaxGrad Ops

Parent c6c85acb
......@@ -198,8 +198,10 @@ def jax_funcify_Identity(op, **kwargs):
@jax_funcify.register(Softmax)
def jax_funcify_Softmax(op, **kwargs):
axis = op.axis
def softmax(x):
return jax.nn.softmax(x)
return jax.nn.softmax(x, axis=axis)
return softmax
......
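For context on the JAX backend change above: the funcify'd closure now just forwards the op's axis to jax.nn.softmax. A minimal standalone sketch (plain JAX, independent of Aesara) of what that axis argument does:

import jax
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)
# axis=-1 normalizes each row, axis=0 normalizes each column
print(jax.nn.softmax(x, axis=-1).sum(axis=-1))  # -> [1. 1.]
print(jax.nn.softmax(x, axis=0).sum(axis=0))    # -> [1. 1. 1.]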
......@@ -400,17 +400,24 @@ def numba_funcify_Softmax(op, node, **kwargs):
x_at = node.inputs[0]
x_dtype = x_at.type.numpy_dtype
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
axis = op.axis
# np.max(x, axis=1)
reduce_max = create_axis_reducer(np.maximum, -np.inf, 1, x_at.ndim, x_dtype)
# np.sum(x, axis=1)
reduce_sum = create_axis_reducer(np.add, 0.0, 1, x_at.ndim, x_dtype)
if axis is not None:
reduce_max = create_axis_reducer(
np.maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
)
reduce_sum = create_axis_reducer(
np.add, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
)
else:
reduce_max = np.max
reduce_sum = np.sum
@numba.njit
def softmax(x):
z = np.expand_dims(reduce_max(x), -1)
z = reduce_max(x)
e_x = np.exp(x - z)
w = np.expand_dims(reduce_sum(e_x), -1)
w = reduce_sum(e_x)
sm = e_x / w
return sm
......
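The Numba path above relies on Aesara's internal create_axis_reducer helper with keepdims=True; a rough NumPy-only sketch of the resulting computation (an illustration of the pattern, not the generated Numba code):

import numpy as np

def softmax_sketch(x, axis):
    # keepdims=True keeps the reduced axis so the subtraction/division broadcast,
    # mirroring reduce_max/reduce_sum above; axis=None reduces over everything.
    z = np.max(x, axis=axis, keepdims=True) if axis is not None else np.max(x)
    e_x = np.exp(x - z)
    w = np.sum(e_x, axis=axis, keepdims=True) if axis is not None else np.sum(e_x)
    return e_x / w

x = np.random.rand(2, 3, 4)
assert np.allclose(softmax_sketch(x, 1).sum(axis=1), 1.0)
assert np.isclose(softmax_sketch(x, None).sum(), 1.0)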
......@@ -35,9 +35,9 @@ from aesara.tensor.nnet.basic import (
selu,
sigmoid_binary_crossentropy,
softmax,
softmax_grad,
softmax_grad_legacy,
softmax_graph,
softmax_op,
softmax_legacy,
softmax_simplifier,
softmax_with_bias,
softsign,
......
......@@ -14,8 +14,10 @@ revisited later when all the intermediate part are on the GPU.
"""
import warnings
from textwrap import dedent
import numpy as np
import scipy.special
import aesara
from aesara import scalar as aes
......@@ -140,7 +142,7 @@ class SoftmaxWithBias(COp):
if isinstance(g_sm.type, DisconnectedType):
return [DisconnectedType()(), DisconnectedType()()]
dx = softmax_grad(g_sm, outputs[0])
dx = softmax_grad_legacy(g_sm, outputs[0])
db = aet_sum(dx, axis=0)
return dx, db
......@@ -339,36 +341,39 @@ class SoftmaxGrad(COp):
nin = 2
nout = 1
__props__ = ()
__props__ = ("axis",)
def __init__(self, axis):
if axis is not None and not isinstance(axis, int):
raise TypeError("axis must be an integer or `None`")
self.axis = axis
def make_node(self, dy, sm):
dy = aet.as_tensor_variable(dy)
sm = aet.as_tensor_variable(sm)
if dy.type.ndim not in (1, 2) or dy.type.dtype not in float_dtypes:
raise ValueError("dy must be 1-d or 2-d tensor of floats. Got ", dy.type)
if dy.ndim == 1:
dy = shape_padleft(dy, n_ones=1)
if sm.ndim == 1:
sm = shape_padleft(sm, n_ones=1)
if self.axis is not None and (self.axis >= sm.ndim or self.axis < -sm.ndim):
raise ValueError(
f"SoftmaxGrad axis(={self.axis}) out of bounds for {sm.ndim}D array {sm}"
)
return Apply(self, [dy, sm], [sm.type()])
def perform(self, node, input_storage, output_storage):
dy, sm = input_storage
dx = np.zeros_like(sm)
# dx[i,j] = - (\sum_k dy[i,k] sm[i,k]) sm[i,j] + dy[i,j] sm[i,j]
for i in range(sm.shape[0]):
dy_times_sm_i = dy[i] * sm[i]
dx[i] = dy_times_sm_i - sum(dy_times_sm_i) * sm[i]
dy_times_sm = dy * sm
dx = dy_times_sm - np.sum(dy_times_sm, axis=self.axis, keepdims=True) * sm
output_storage[0][0] = dx
def grad(self, inp, grads):
dy, sm = inp
(g,) = grads
tmp = g + neg(aet_sum(g * sm, axis=1).dimshuffle((0, "x")))
tmp = g + neg(aet_sum(g * sm, axis=self.axis, keepdims=True))
g_dy = tmp * sm
tmp2 = aet_sum(dy * sm, axis=1).dimshuffle((0, "x"))
tmp2 = aet_sum(dy * sm, axis=self.axis, keepdims=True)
g_sm = tmp * dy - g * tmp2
return g_dy, g_sm
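A quick NumPy check (not part of the diff) that the vectorized perform above reproduces the old per-row loop when the reduction runs over the last axis:

import numpy as np

rng = np.random.default_rng(0)
dy = rng.random((3, 4))
sm = rng.random((3, 4))

# new vectorized form, axis=-1
dx_vec = dy * sm - np.sum(dy * sm, axis=-1, keepdims=True) * sm

# old per-row loop
dx_loop = np.zeros_like(sm)
for i in range(sm.shape[0]):
    dy_times_sm_i = dy[i] * sm[i]
    dx_loop[i] = dy_times_sm_i - np.sum(dy_times_sm_i) * sm[i]

assert np.allclose(dx_vec, dx_loop)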
......@@ -377,79 +382,184 @@ class SoftmaxGrad(COp):
return [shape[1]]
def c_code_cache_version(self):
return (3,)
return (4,)
def c_code(self, node, name, inp, out, sub):
dy, sm = inp
(dx,) = out
return """
if ((PyArray_TYPE(%(dy)s) != NPY_DOUBLE) &&
(PyArray_TYPE(%(dy)s) != NPY_FLOAT))
{
PyErr_SetString(PyExc_TypeError,
"types should be float or float64");
%(fail)s;
}
if ((PyArray_TYPE(%(sm)s) != NPY_DOUBLE) &&
(PyArray_TYPE(%(sm)s) != NPY_FLOAT))
{
PyErr_SetString(PyExc_TypeError,
"types should be float or float64");
%(fail)s;
}
if ((PyArray_NDIM(%(dy)s) != 2)
|| (PyArray_NDIM(%(sm)s) != 2))
{
PyErr_SetString(PyExc_ValueError, "rank error");
%(fail)s;
}
if (PyArray_DIMS(%(dy)s)[0] != PyArray_DIMS(%(sm)s)[0])
{
PyErr_SetString(PyExc_ValueError, "dy.shape[0] != sm.shape[0]");
%(fail)s;
}
if ((NULL == %(dx)s)
|| (PyArray_DIMS(%(dx)s)[0] != PyArray_DIMS(%(sm)s)[0])
|| (PyArray_DIMS(%(dx)s)[1] != PyArray_DIMS(%(sm)s)[1]))
{
Py_XDECREF(%(dx)s);
%(dx)s = (PyArrayObject*) PyArray_SimpleNew(2,
PyArray_DIMS(%(sm)s),
PyArray_TYPE(%(sm)s));
if (!%(dx)s)
{
PyErr_SetString(PyExc_MemoryError,
"failed to alloc dx output");
%(fail)s;
}
}
for (size_t i = 0; i < PyArray_DIMS(%(dx)s)[0]; ++i)
{
const dtype_%(dy)s* __restrict__ dy_i = (dtype_%(dy)s*) (PyArray_BYTES(%(dy)s) + PyArray_STRIDES(%(dy)s)[0] * i);
npy_intp Sdy = PyArray_STRIDES(%(dy)s)[1]/sizeof(dtype_%(dy)s);
const dtype_%(sm)s* __restrict__ sm_i = (dtype_%(sm)s*) (PyArray_BYTES(%(sm)s) + PyArray_STRIDES(%(sm)s)[0] * i);
npy_intp Ssm = PyArray_STRIDES(%(sm)s)[1]/sizeof(dtype_%(sm)s);
dtype_%(dx) s* __restrict__ dx_i = (dtype_%(dx)s*) (PyArray_BYTES(%(dx)s) + PyArray_STRIDES(%(dx)s)[0] * i);
npy_intp Sdx = PyArray_STRIDES(%(dx)s)[1]/sizeof(dtype_%(dx)s);
double sum_dy_times_sm = 0.;
for (size_t j = 0; j < PyArray_DIMS(%(dx)s)[1]; ++j)
{
dx_i[j * Sdx] = dy_i[j * Sdy] * sm_i[j * Ssm];
sum_dy_times_sm += dx_i[j * Sdx];
}
for (size_t j = 0; j < PyArray_DIMS(%(dx)s)[1]; ++j)
{
dx_i[j * Sdx] -= sum_dy_times_sm * sm_i[j * Ssm];
}
}
""" % dict(
locals(), **sub
axis = self.axis if self.axis is not None else np.MAXDIMS
fail = sub["fail"]
return dedent(
f"""
PyArrayObject* op[3];
npy_uint32 op_flags[3];
npy_uint32 iter_flags;
NpyIter* iter;
NpyIter_IterNextFunc* get_next;
char** data_ptr;
int sm_ndim = PyArray_NDIM({sm});
int axis = {axis};
int iterate_axis = !(axis == NPY_MAXDIMS || sm_ndim == 1);
// Validate inputs
if ((PyArray_TYPE({dy}) != NPY_DOUBLE) &&
(PyArray_TYPE({dy}) != NPY_FLOAT))
{{
PyErr_SetString(PyExc_TypeError, "types should be float or float64");
{fail};
}}
if ((PyArray_TYPE({sm}) != NPY_DOUBLE) &&
(PyArray_TYPE({sm}) != NPY_FLOAT))
{{
PyErr_SetString(PyExc_TypeError, "types should be float or float64");
{fail};
}}
if (axis < 0) axis = sm_ndim + axis;
if ((axis < 0) || (iterate_axis && (axis > sm_ndim)))
{{
PyErr_SetString(PyExc_ValueError, "invalid axis in SoftmaxGrad");
{fail};
}}
if (({dx} == NULL)
|| !(PyArray_CompareLists(PyArray_DIMS({dx}), PyArray_DIMS({sm}), sm_ndim)))
{{
Py_XDECREF({dx});
{dx} = (PyArrayObject*)PyArray_SimpleNew(sm_ndim,
PyArray_DIMS({sm}),
PyArray_TYPE({sm}));
if (!{dx})
{{
PyErr_SetString(PyExc_MemoryError, "failed to alloc SoftMaxGrad dx output");
{fail};
}}
}}
// Create numpy iterator
op[0] = {dy};
op[1] = {sm};
op[2] = {dx};
op_flags[0] = NPY_ITER_READONLY;
op_flags[1] = NPY_ITER_READONLY;
op_flags[2] = NPY_ITER_READWRITE;
iter_flags = (iterate_axis)? NPY_ITER_MULTI_INDEX : 0;
iter = NpyIter_MultiNew(
3,
op,
iter_flags,
NPY_KEEPORDER,
NPY_NO_CASTING,
op_flags,
NULL
);
if (iter == NULL)
{{
PyErr_SetString(PyExc_MemoryError, "failed to create softmax iterator");
{fail};
}}
// SoftmaxGrad is applied across the entire array
if (!iterate_axis)
{{
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain SoftMaxGrad GetIterNext");
{fail};
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
// Compute and accumulate dy * sm
dtype_{dx} sum_dy_times_sm = 0.0;
do
{{
dtype_{dy}* dy_ptr = (dtype_{dy}*)data_ptr[0];
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
dtype_{dx}* dx_ptr = (dtype_{dx}*)data_ptr[2];
*dx_ptr = (dtype_{dx})((*dy_ptr) * (*sm_ptr));
sum_dy_times_sm += *dx_ptr;
}} while(get_next(iter));
// Reset Iterator
if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to reset softmax iterator");
{fail};
}}
// Subtract sum(dy*sm) * sm
do
{{
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
dtype_{dx}* dx_ptr = (dtype_{dx}*)data_ptr[2];
*dx_ptr -= sum_dy_times_sm * ((dtype_{dx})(*sm_ptr));
}} while(get_next(iter));
}}
// SoftmaxGrad is applied across a specific axis
else {{
// Collect axis strides and remove it from iteration
npy_intp axis_size = PyArray_DIM({sm}, axis);
npy_intp* axis_stride = NpyIter_GetAxisStrideArray(iter, axis);
if (axis_stride == NULL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain softmax axis strides");
{fail};
}}
npy_intp dy_axis_stride = axis_stride[0] / sizeof(dtype_{dy});
npy_intp sm_axis_stride = axis_stride[1] / sizeof(dtype_{sm});
npy_intp dx_axis_stride = axis_stride[2] / sizeof(dtype_{dx});
if (NpyIter_RemoveAxis(iter, axis) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to remove SoftmaxGrad axis from iterator");
{fail};
}}
// Iterate over remaining axes
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain SoftmaxGrad GetIterNext");
{fail};
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
do
{{
dtype_{dy}* dy_axis = (dtype_{dy}*)data_ptr[0];
dtype_{sm}* sm_axis = (dtype_{sm}*)data_ptr[1];
dtype_{dx}* dx_axis = (dtype_{dx}*)data_ptr[2];
// Compute and accumulate dy * sm
dtype_{dx} sum_dy_times_sm = 0.0;
for (npy_intp i = 0; i < axis_size; i++)
{{
dx_axis[i * dx_axis_stride] = (dtype_{dx})(dy_axis[i * dy_axis_stride] * sm_axis[i * sm_axis_stride]);
sum_dy_times_sm += dx_axis[i * dx_axis_stride];
}}
// Subtract sum(dy*sm) * sm
for (npy_intp i = 0; i < axis_size; i++)
{{
dx_axis[i * dx_axis_stride] -= sum_dy_times_sm * (dtype_{dx})(sm_axis[i * sm_axis_stride]);
}}
}} while(get_next(iter));
}}
NpyIter_Deallocate(iter);
"""
)
softmax_grad = SoftmaxGrad()
softmax_grad_legacy = SoftmaxGrad(axis=-1)
class Softmax(COp):
......@@ -464,34 +574,32 @@ class Softmax(COp):
nin = 1
nout = 1
__props__ = ()
__props__ = ("axis",)
def __init__(self, axis):
if axis is not None and not isinstance(axis, int):
raise TypeError("axis must be an integer or `None`")
self.axis = axis
def make_node(self, x):
x = aet.as_tensor_variable(x)
if x.type.ndim not in (1, 2) or x.type.dtype not in float_dtypes:
raise ValueError(f"x must be 1-d or 2-d tensor of floats. Got {x.type}")
if x.ndim == 1:
warnings.warn(
"If x is a vector, Softmax will not automatically pad x "
"anymore in next releases. If you need it, please do it manually. The "
"vector case is gonna be supported soon and the output will be a vector.",
category=PendingDeprecationWarning,
stacklevel=4,
if self.axis is not None and (self.axis >= x.ndim or self.axis < -x.ndim):
raise ValueError(
f"Softmax axis(={self.axis}) out of bounds for {x.ndim}D array {x}"
)
x = shape_padleft(x, n_ones=1)
return Apply(self, [x], [x.type()])
def perform(self, node, input_storage, output_storage):
(x,) = input_storage
e_x = np.exp(x - x.max(axis=1)[:, None])
sm = e_x / e_x.sum(axis=1)[:, None]
output_storage[0][0] = sm
(z,) = output_storage
z[0] = scipy.special.softmax(x, axis=self.axis)
def L_op(self, inp, outputs, grads):
(x,) = inp
(g_sm,) = grads
return [softmax_grad(g_sm, outputs[0])]
return [SoftmaxGrad(axis=self.axis)(g_sm, outputs[0])]
def R_op(self, inputs, eval_points):
# I think the Jacobian is symmetric so the R_op
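The Python perform above now delegates to scipy.special.softmax; a small sketch of its axis semantics, which the new tests also rely on (assumed standard SciPy behavior):

import numpy as np
import scipy.special as sp

x = np.arange(6.0).reshape(2, 3)
assert np.isclose(sp.softmax(x, axis=None).sum(), 1.0)      # whole array sums to 1
assert np.allclose(sp.softmax(x, axis=1).sum(axis=1), 1.0)  # each row sums to 1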
......@@ -506,151 +614,202 @@ class Softmax(COp):
def c_headers(self, **kwargs):
return ["<iostream>", "<cmath>"]
@staticmethod
def c_code_template(dtype):
# this implementation was lifted from
# /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx
def c_code(self, node, name, inp, out, sub):
(x,) = inp
(sm,) = out
axis = self.axis if self.axis is not None else np.MAXDIMS
fail = sub["fail"]
# dtype = node.inputs[0].type.dtype_specs()[1]
# TODO: put this into a templated function, in the support code
# TODO: declare the max of each row as an Op output
# TODO: set error messages for failures in this code
# TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1]
init_decl = """
npy_intp* Nx = PyArray_DIMS(%(x)s);
npy_intp Sx1 = 0;
npy_intp Ssm1 = 0;
if (PyArray_NDIM(%(x)s) != 2)
{
PyErr_SetString(PyExc_ValueError, "not a 2d tensor");
%(fail)s;
}
if ((PyArray_TYPE(%(x)s) != NPY_DOUBLE) &&
(PyArray_TYPE(%(x)s) != NPY_FLOAT))
{
PyErr_SetString(PyExc_TypeError, "not a float");
%(fail)s;
}
if ((NULL == %(sm)s)
|| (PyArray_DIMS(%(sm)s)[0] != PyArray_DIMS(%(x)s)[0])
|| (PyArray_DIMS(%(sm)s)[1] != PyArray_DIMS(%(x)s)[1]))
{
Py_XDECREF(%(sm)s);
%(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s),
PyArray_TYPE(%(x)s));
if(!%(sm)s) {
PyErr_SetString(PyExc_MemoryError,
"failed to alloc sm output");
%(fail)s
}
}
Sx1 = PyArray_STRIDES(%(x)s)[1]/sizeof(dtype_%(x)s);
Ssm1 = PyArray_STRIDES(%(sm)s)[1]/sizeof(dtype_%(sm)s);
"""
begin_row_loop = """
for (size_t i = 0; i < Nx[0]; ++i)
{
size_t j;
double sum = 0.0;
const dtype_%(x)s* __restrict__ x_i = (dtype_%(x)s*)(PyArray_BYTES(%(x)s) + PyArray_STRIDES(%(x)s)[0] * i);
dtype_%(sm) s* __restrict__ sm_i = (dtype_%(sm)s*)(PyArray_BYTES(%(sm)s) + PyArray_STRIDES(%(sm)s)[0] * i);
dtype_%(sm)s row_max = x_i[0];
//std::cout << "0 " << row_max << "\\n";
// Get the maximum value of the row
for (j = 1; j < Nx[1]; ++j)
{
dtype_%(sm)s row_ij = x_i[j * Sx1] ;
//std::cout << "1 " << row_ij << "\\n";
row_max = (row_ij > row_max) ? row_ij : row_max;
}
"""
inside_row_loop = """
for (j = 0; j < Nx[1]; ++j)
{
dtype_%(sm)s row_ij = x_i[j * Sx1] ;
//std::cout << "2 " << j << " " << row_ij << " " << row_max << "\\n";
dtype_%(sm)s sm_ij = exp(row_ij - row_max);
//std::cout << "3 " << j << " " << sm_ij << "\\n";
sum += sm_ij;
sm_i[j * Ssm1] = sm_ij;
}
//cblas_dscal(x.N, 1.0 / sum, &mat_at(s,i,0), s.n);
double sum_inv = 1.0 / sum;
for (j = 0; j < Nx[1]; ++j)
{
sm_i[j * Ssm1] *= sum_inv;
}
"""
# Get the vectorized version of exp if it exist
try:
vec_exp = aesara.scalar.exp.c_code_contiguous_raw(
dtype, "Nx[1]", "sm_i", "sm_i"
)
inside_row_loop_contig = (
"""
for (j = 0; j < Nx[1]; ++j)
{
sm_i[j * Ssm1] = x_i[j * Sx1] - row_max;
}
%(vec_exp)s;
for (j = 0; j < Nx[1]; ++j)
{
sum += sm_i[j * Ssm1];
}
//cblas_dscal(x.N, 1.0 / sum, &mat_at(s,i,0), s.n);
double sum_inv = 1.0 / sum;
for (j = 0; j < Nx[1]; ++j)
{
sm_i[j * Ssm1] *= sum_inv;
}
"""
% locals()
)
inside_row_loop = (
"""
if(Ssm1 == 1){
%(inside_row_loop_contig)s
}else{
%(inside_row_loop)s
}
return dedent(
f"""
PyArrayObject* op[2];
npy_uint32 op_flags[2];
npy_uint32 iter_flags;
NpyIter* iter;
NpyIter_IterNextFunc* get_next;
char** data_ptr;
int x_ndim = PyArray_NDIM({x});
int axis = {axis};
int iterate_axis = !(axis == NPY_MAXDIMS || x_ndim == 1);
// Validate inputs
if ((PyArray_TYPE({x}) != NPY_DOUBLE) &&
(PyArray_TYPE({x}) != NPY_FLOAT))
{{
PyErr_SetString(PyExc_TypeError, "not a float");
{fail}
}}
if (axis < 0) axis = x_ndim + axis;
if ((axis < 0) || (iterate_axis && (axis > x_ndim)))
{{
PyErr_SetString(PyExc_ValueError, "invalid axis in Softmax");
{fail}
}}
// Allocate Output Array
if (({sm}) == NULL || !(PyArray_CompareLists(PyArray_DIMS({sm}), PyArray_DIMS({x}), x_ndim)))
{{
Py_XDECREF({sm});
{sm} = (PyArrayObject*)PyArray_SimpleNew(x_ndim, PyArray_DIMS({x}), PyArray_TYPE({x}));
if(!{sm}) {{
PyErr_SetString(PyExc_MemoryError, "failed to alloc Softmax output");
{fail}
}}
}}
// Create numpy iterator
op[0] = {x};
op[1] = {sm};
op_flags[0] = NPY_ITER_READONLY;
op_flags[1] = NPY_ITER_READWRITE;
iter_flags = (iterate_axis)? NPY_ITER_MULTI_INDEX : 0;
iter = NpyIter_MultiNew(
2,
op,
iter_flags,
NPY_KEEPORDER,
NPY_NO_CASTING,
op_flags,
NULL
);
if (iter == NULL)
{{
PyErr_SetString(PyExc_MemoryError, "failed to create Softmax iterator");
{fail}
}}
// Softmax is applied across the entire array
if (!iterate_axis)
{{
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain Softmax GetIterNext");
{fail}
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
// Find axis max
dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
dtype_{x} max = *x_ptr;
if (get_next(iter))
{{
do
{{
dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
max = (*x_ptr > max)? *x_ptr : max;
}} while(get_next(iter));
}}
// Reset Iterator
if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to reset Softmax iterator");
{fail}
}}
// Compute and accumulate exp(x-max(x)) exponent
double sum_exp_dev = 0.0;
do
{{
dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
*sm_ptr = (dtype_{sm}) exp(*x_ptr - max);
sum_exp_dev += *sm_ptr;
}} while(get_next(iter));
// Reset Iterator
if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to reset Softmax iterator");
{fail}
}}
// Divide by sum(exp(x-max(x)))
double inv_sum_exp_dev = 1.0 / sum_exp_dev;
do
{{
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
*sm_ptr *= inv_sum_exp_dev;
}} while(get_next(iter));
}}
// Softmax is applied across a specific axis
else {{
// Collect axis strides and remove it from iteration
npy_intp axis_size = PyArray_DIM({x}, axis);
npy_intp* axis_stride = NpyIter_GetAxisStrideArray(iter, axis);
if (axis_stride == NULL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain Softmax axis strides");
{fail}
}}
npy_intp x_axis_stride = axis_stride[0] / sizeof(dtype_{x});
npy_intp sm_axis_stride = axis_stride[1] / sizeof(dtype_{sm});
if (NpyIter_RemoveAxis(iter, axis) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to remove softmax axis from iterator");
{fail}
}}
// Iterate over remaining axes
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain softmax GetIterNext");
{fail}
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
do
{{
dtype_{x}* x_axis = (dtype_{x}*)data_ptr[0];
dtype_{sm}* sm_axis = (dtype_{sm}*)data_ptr[1];
// Find axis max
dtype_{x} max = x_axis[0];
for (npy_intp i = 1; i < axis_size; i++)
{{
dtype_{x} x_val = x_axis[i * x_axis_stride];
max = (x_val > max)? x_val : max;
}}
// Compute and accumulate exp(x-max(x)) exponent
dtype_{sm} sum_exp_dev = 0.0;
for (npy_intp i = 0; i < axis_size; i++)
{{
sm_axis[i * sm_axis_stride] = (dtype_{sm}) exp(x_axis[i * x_axis_stride] - max);
sum_exp_dev += sm_axis[i * sm_axis_stride];
}}
// Divide by sum(exp(x-max(x)))
dtype_{sm} inv_sum_exp_dev = 1.0 / sum_exp_dev;
for (npy_intp i = 0; i < axis_size; i++)
{{
sm_axis[i * sm_axis_stride] *= inv_sum_exp_dev;
}}
}} while(get_next(iter));
}}
NpyIter_Deallocate(iter);
"""
% locals()
)
except aesara.graph.utils.MethodNotDefined:
pass
end_row_loop = """
}
"""
return (init_decl, begin_row_loop, inside_row_loop, end_row_loop)
def c_code(self, node, name, inp, out, sub):
(x,) = inp
(sm,) = out
code_template = "".join(
self.c_code_template(node.inputs[0].type.dtype_specs()[1])
)
return code_template % dict(locals(), **sub)
@staticmethod
def c_code_cache_version():
return (3,)
return (4,)
softmax_op = Softmax()
softmax_legacy = Softmax(axis=-1)
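Conceptually, the NpyIter-based C code above removes `axis` from the iterator and applies the stable softmax along that axis for every remaining index; a NumPy analogue of the same idea (a sketch, not the generated C):

import numpy as np
import scipy.special

x = np.random.rand(2, 3, 4)
# iterate over all axes except axis=1 and normalize each 1-D slice
out = np.apply_along_axis(scipy.special.softmax, 1, x)
assert np.allclose(out, scipy.special.softmax(x, axis=1))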
class LogSoftmax(COp):
......@@ -689,7 +848,7 @@ class LogSoftmax(COp):
def grad(self, inp, grads):
(x,) = inp
sm = softmax_op(x)
sm = softmax_legacy(x)
return [grads[0] - aet_sum(grads[0], axis=1, keepdims=True) * sm]
def R_op(self, inputs, eval_points):
......@@ -816,7 +975,8 @@ def local_logsoftmax(fgraph, node):
and isinstance(node.op.scalar_op, aes.Log)
and len(node.inputs) == 1
and node.inputs[0].owner is not None
and isinstance(node.inputs[0].owner.op, Softmax)
and node.inputs[0].owner.op == softmax_legacy
and node.inputs[0].ndim == 2
):
inVars = node.inputs[0].owner.inputs[0]
new_op = LogSoftmax()
......@@ -843,7 +1003,8 @@ def local_logsoftmax_grad(fgraph, node):
and node.inputs[0].owner.op == true_div
and len(node.inputs[0].owner.inputs) >= 2
and node.inputs[0].owner.inputs[1].owner is not None
and node.inputs[0].owner.inputs[1].owner.op == softmax_op
and node.inputs[0].owner.inputs[1].owner.op == softmax_legacy
and node.inputs[0].owner.inputs[1].ndim == 2
and node.inputs[1] == node.inputs[0].owner.inputs[1]
and not (
# skip if it will be optimized by
......@@ -871,13 +1032,26 @@ def softmax_graph(c):
return exp(c) / exp(c).sum(axis=-1, keepdims=True)
def softmax(c):
UNSET_AXIS = object()
def softmax(c, axis=UNSET_AXIS):
if axis is UNSET_AXIS:
warnings.warn(
"Softmax now accepts an axis argument. For backwards-compatibility it defaults to -1 when not specified, "
"but in the future the default will be `None`.\nTo suppress this warning specify axis explicitly.",
FutureWarning,
)
axis = -1
c = as_tensor_variable(c)
if c.broadcastable[-1]:
if c.ndim == 1:
# TODO: Create Specific warning type that can be suppressed?
warnings.warn(
"The softmax is applied on a dimension of shape 1, which does not have a semantic meaning."
"Softmax no longer converts a vector to a row matrix.",
UserWarning,
)
return softmax_op(c)
return Softmax(axis=axis)(c)
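A usage note on the helper above (a sketch, not part of the diff): passing axis explicitly avoids the FutureWarning raised by the backwards-compatible default.

import numpy as np
import aesara
import aesara.tensor as aet
from aesara.tensor.nnet import softmax

x = aet.matrix("x")
# explicit axis, so no FutureWarning; axis=None normalizes over the whole tensor
f = aesara.function([x], softmax(x, axis=None))
xv = np.arange(6.0).reshape(2, 3).astype(aesara.config.floatX)
assert np.isclose(f(xv).sum(), 1.0)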
def logsoftmax(c):
......@@ -885,13 +1059,13 @@ def logsoftmax(c):
@register_specialize("fast_compile_gpu")
@local_optimizer([softmax_op])
@local_optimizer([softmax_legacy])
def local_softmax_with_bias(fgraph, node):
"""
Try to turn softmax(sum_of_stuff) -> softmax_w_bias(matrix, bias).
"""
if node.op == softmax_op:
if node.op == softmax_legacy and node.outputs[0].ndim == 2:
(x,) = node.inputs
if x.owner and x.owner.op == add:
vectors = []
......@@ -980,7 +1154,7 @@ def softmax_simplifier(numerators, denominators):
matching_denom = denominator
break
if matching_denom:
softmax = softmax_op(x)
softmax = softmax_legacy(x)
copy_stack_trace(numerator, softmax)
numerators.remove(numerator)
denominators.remove(matching_denom)
......@@ -1611,7 +1785,7 @@ def crossentropy_to_crossentropy_with_softmax(fgraph):
if node.op == crossentropy_categorical_1hot:
(nll,) = node.outputs
sm, one_of_n = node.inputs
if sm.owner and sm.owner.op == softmax_op:
if sm.owner and sm.owner.op == softmax_legacy and sm.ndim == 2:
(x,) = sm.owner.inputs
(
new_nll,
......@@ -1658,9 +1832,9 @@ optdb.register(
@register_specialize(
"fast_compile_gpu", "local_crossentropy_to_crossentropy_with_softmax_grad"
) # old name
@local_optimizer([softmax_grad])
@local_optimizer([softmax_grad_legacy])
def local_softmax_grad_to_crossentropy_with_softmax_grad(fgraph, node):
if node.op == softmax_grad:
if node.op == softmax_grad_legacy and node.inputs[1].ndim == 2:
g_coding_dist, coding_dist = node.inputs
if (
g_coding_dist.owner
......@@ -1686,13 +1860,16 @@ def local_argmax_pushdown(fgraph, node):
x = node.inputs[0]
axis = node.op.get_params(node)
# TODO: Make a list/set of monotonic ops...
if x.owner and x.owner.op in (
softmax_op,
softplus,
exp,
log,
tanh,
sigmoid,
if x.owner and (
x.owner.op
in (
softplus,
exp,
log,
tanh,
sigmoid,
)
or isinstance(x.owner.op, Softmax)
):
(pre_x,) = x.owner.inputs
ret = max_and_argmax(pre_x, axis)
......@@ -1786,7 +1963,12 @@ def local_advanced_indexing_crossentropy_onehot(fgraph, node):
except Exception:
pass
if sm is not None and sm.owner and sm.owner.op in (softmax_op, softmax_with_bias):
if (
sm is not None
and sm.owner
and sm.owner.op in (softmax_legacy, softmax_with_bias)
and sm.ndim == 2
):
sm_w_bias = local_softmax_with_bias.transform(fgraph, sm.owner)
if sm_w_bias:
assert sm_w_bias[0].owner.op == softmax_with_bias
......@@ -1807,9 +1989,9 @@ def local_advanced_indexing_crossentropy_onehot(fgraph, node):
@register_specialize("fast_compile_gpu")
@local_optimizer([softmax_grad])
@local_optimizer([softmax_grad_legacy])
def local_advanced_indexing_crossentropy_onehot_grad(fgraph, node):
if not (node.op == softmax_grad):
if not (node.op == softmax_grad_legacy and node.inputs[1].ndim == 2):
return
sm = None
......@@ -1821,7 +2003,8 @@ def local_advanced_indexing_crossentropy_onehot_grad(fgraph, node):
if (
(sm is not None)
and sm.owner
and (sm.owner.op in (softmax_op, softmax_with_bias))
and (sm.owner.op in (softmax_legacy, softmax_with_bias))
and sm.ndim == 2
):
sm_w_bias = local_softmax_with_bias.transform(fgraph, sm.owner)
if sm_w_bias:
......
......@@ -32,7 +32,7 @@ from aesara.tensor.math import (
sqrt,
)
from aesara.tensor.math import sum as aet_sum
from aesara.tensor.nnet import batchnorm, conv2d, softmax, softmax_op
from aesara.tensor.nnet import batchnorm, conv2d, softmax, softmax_legacy
from aesara.tensor.nnet.abstract_conv import (
get_conv_gradinputs_shape,
get_conv_output_shape,
......@@ -1456,7 +1456,7 @@ class TestSoftMax(test_nnet.TestSoftMax):
def test_softmax_f16(self):
x = matrix("x", "float16")
x_gpu = tensor4("x_gpu", "float16")
f_z = softmax_op
f_z = softmax_legacy
f_gpu = dnn.GpuDnnSoftmax("accurate", "channel")
def cmp(n, m, f, f_gpu):
......@@ -1480,7 +1480,7 @@ class TestSoftMax(test_nnet.TestSoftMax):
x = matrix("x")
x_gpu = tensor4("x_gpu")
f_z = softmax_op
f_z = softmax_legacy
f_gpu = dnn.GpuDnnSoftmax("accurate", "channel")
# Verify the grad operation
......
......@@ -210,7 +210,7 @@ def softmax_unittest_template(dtypeInput):
z = aesara.tensor.nnet.softmax(x)
f = aesara.function([x], z, mode=mode_without_gpu)
f_gpu = aesara.function([x], z, mode=mode_wo_cudnn)
assert f.maker.fgraph.toposort()[-1].op == aesara.tensor.nnet.softmax_op
assert f.maker.fgraph.toposort()[-1].op == aesara.tensor.nnet.softmax_legacy
assert isinstance(f_gpu.maker.fgraph.toposort()[-2].op, GpuSoftmax)
def cmp(n, m):
......@@ -300,7 +300,7 @@ class TestSoftMax:
def test_softmax(self):
x = fmatrix("x")
z = aesara.tensor.nnet.softmax_op
z = aesara.tensor.nnet.softmax_legacy
f, f_gpu = self._test_softmax(x, x, z, z, self._cmp)
......@@ -308,7 +308,7 @@ class TestSoftMax:
def test_softmax_shape_0(self):
x = fmatrix("x")
z = aesara.tensor.nnet.softmax_op
z = aesara.tensor.nnet.softmax_legacy
f, f_gpu = self._test_softmax(x, x, z, z, self._cmp)
# Aesara can handle that case, but cudnn can't
......
......@@ -969,11 +969,16 @@ def test_nnet():
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = aet_nnet.softmax(x)
out = aet_nnet.logsoftmax(x)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = aet_nnet.logsoftmax(x)
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax(axis):
x = matrix("x")
x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = aet_nnet.softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
......
......@@ -1894,20 +1894,27 @@ def test_Dot(x, y, exc):
@pytest.mark.parametrize(
"x, exc",
"x, axis, exc",
[
(
set_test_value(aet.vector(), rng.random(size=(2,)).astype(config.floatX)),
None,
None,
),
(
set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
None,
None,
),
(
set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
0,
None,
),
],
)
def test_Softmax(x, exc):
g = nnetb.Softmax()(x)
def test_Softmax(x, axis, exc):
g = nnetb.Softmax(axis=axis)(x)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
......
from contextlib import ExitStack as does_not_raise
import numpy as np
import pytest
import scipy.special as sp
import aesara
import aesara.tensor as aet
......@@ -49,13 +52,13 @@ from aesara.tensor.nnet.basic import (
selu,
sigmoid_binary_crossentropy,
softmax,
softmax_grad,
softmax_grad_legacy,
softmax_graph,
softmax_op,
softmax_legacy,
softmax_with_bias,
softsign,
)
from aesara.tensor.shape import shape_padleft, specify_shape
from aesara.tensor.shape import shape_padleft
from aesara.tensor.subtensor import AdvancedSubtensor
from aesara.tensor.type import (
dmatrix,
......@@ -67,6 +70,7 @@ from aesara.tensor.type import (
matrices,
matrix,
scalar,
tensor3,
tensor4,
vector,
vectors,
......@@ -80,46 +84,64 @@ from tests.tensor.utils import (
)
class TestSoftmax(utt.InferShapeTester):
def test_basic(self):
def f(a):
return softmax_op(a)[:, 0]
def valid_axis_tester(Op):
with pytest.raises(TypeError):
Op(1.5)
utt.verify_grad(f, [np.random.random((3, 4))])
x = [tensor3()] * Op.nin
with does_not_raise():
Op(2)(*x)
def f(a):
return softmax_op(a)[:, 1]
with pytest.raises(ValueError):
Op(3)(*x)
utt.verify_grad(f, [np.random.random((3, 4))])
with does_not_raise():
Op(-3)(*x)
def f(a):
return softmax_op(a)[:, 2]
with pytest.raises(ValueError):
Op(-4)(*x)
utt.verify_grad(f, [np.random.random((3, 4))])
class TestSoftmax(utt.InferShapeTester):
@pytest.mark.parametrize("axis", [None, 0, 1, 2, 3, -1, -2])
def test_perform(self, axis):
x = tensor4("x")
xv = np.random.randn(2, 3, 4, 5).astype(config.floatX)
f = aesara.function([x], softmax(x, axis=axis))
assert np.allclose(f(xv), sp.softmax(xv, axis=axis))
@pytest.mark.parametrize("column", [0, 1, 2, 3])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_grad(self, axis, column):
def f(a):
return softmax_op(a)[:, 3]
return softmax(a, axis=axis)[:, column]
utt.verify_grad(f, [np.random.random((3, 4))])
utt.verify_grad(f, [np.random.random((3, 4, 2))])
def test_infer_shape(self):
admat = matrix()
admat_val = np.random.random((3, 4)).astype(config.floatX)
self._compile_and_check([admat], [Softmax()(admat)], [admat_val], Softmax)
self._compile_and_check(
[admat], [Softmax(axis=-1)(admat)], [admat_val], Softmax
)
def test_vector(self):
def test_vector_perform(self):
x = vector()
f = aesara.function([x], softmax_op(x))
f = aesara.function([x], softmax(x, axis=None))
xv = np.random.randn(6).astype(config.floatX)
assert np.allclose(f(xv), np.exp(xv) / np.exp(xv).sum())
assert np.allclose(f(xv), sp.softmax(xv))
def test_vector_grad(self):
def f(a):
return softmax_op(a)
return softmax(a, axis=None)
utt.verify_grad(f, [np.random.random((4))])
def test_valid_axis(self):
valid_axis_tester(Softmax)
class TestSoftmaxWithBias(utt.InferShapeTester):
def test_basic(self):
......@@ -154,10 +176,10 @@ class TestSoftmaxWithBias(utt.InferShapeTester):
W = aesara.shared(value=initial_W, name="W")
vbias = aesara.shared(value=0.1, name="vbias") # 0.01
hid = vector("hid")
f = aesara.function([hid], softmax_op(dot(hid, W.T) + vbias))
f = aesara.function([hid], softmax_legacy(dot(hid, W.T) + vbias))
ops = [node.op for node in f.maker.fgraph.toposort()]
assert softmax_with_bias not in ops
assert softmax_op in ops
assert softmax_legacy in ops
f([0, 1, 0])
# print f.maker.fgraph.toposort()
......@@ -315,17 +337,19 @@ class TestLogSoftmax(utt.InferShapeTester):
g = grad(y.sum(), x)
softmax_grad_node = g.owner
assert softmax_grad_node.op == softmax_grad
assert softmax_grad_node.op == softmax_grad_legacy
true_div_node = softmax_grad_node.inputs[0].owner
assert true_div_node.op == true_div
# We replace the elemwise true_div op by an elemwise add.
new_g = softmax_grad(add(*true_div_node.inputs), softmax_grad_node.inputs[1])
new_g = softmax_grad_legacy(
add(*true_div_node.inputs), softmax_grad_node.inputs[1]
)
fgraph = FunctionGraph([x], [new_g])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
assert softmax_grad in [n.op for n in fgraph.toposort()]
assert softmax_grad_legacy in [n.op for n in fgraph.toposort()]
class TestSoftmaxGrad(utt.InferShapeTester):
......@@ -336,11 +360,14 @@ class TestSoftmaxGrad(utt.InferShapeTester):
bdmat_val = np.random.random((3, 4)).astype(config.floatX)
self._compile_and_check(
[admat, bdmat],
[SoftmaxGrad()(admat, bdmat)],
[SoftmaxGrad(axis=-1)(admat, bdmat)],
[admat_val, bdmat_val],
SoftmaxGrad,
)
def test_valid_axis(self):
valid_axis_tester(SoftmaxGrad)
class TestCrossEntropySoftmax1Hot:
def test_basic(self):
......@@ -611,17 +638,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
op = crossentropy_categorical_1hot
# xe = op(x, one_of_n)
fgraph = FunctionGraph([x, one_of_n], [op(softmax_op(x), one_of_n)])
assert fgraph.outputs[0].owner.op == op
optdb.query(OPT_FAST_RUN).optimize(fgraph)
assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
def test_softmax_optimizations_vector(self):
x = vector("x")
one_of_n = lvector("one_of_n")
op = crossentropy_categorical_1hot
fgraph = FunctionGraph([x, one_of_n], [op(softmax_op(x), one_of_n)])
fgraph = FunctionGraph([x, one_of_n], [op(softmax_legacy(x), one_of_n)])
assert fgraph.outputs[0].owner.op == op
optdb.query(OPT_FAST_RUN).optimize(fgraph)
......@@ -633,7 +650,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
one_of_n = lvector("one_of_n")
op = crossentropy_categorical_1hot
fgraph = FunctionGraph([x, b, one_of_n], [op(softmax_op(x + b), one_of_n)])
fgraph = FunctionGraph([x, b, one_of_n], [op(softmax_legacy(x + b), one_of_n)])
assert fgraph.outputs[0].owner.op == op
optdb.query(OPT_FAST_RUN).optimize(fgraph)
......@@ -649,7 +666,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
op = crossentropy_categorical_1hot
fgraph = FunctionGraph(
[x, b, c, one_of_n], [op(softmax_op(add(x, b, c)), one_of_n)]
[x, b, c, one_of_n], [op(softmax_legacy(add(x, b, c)), one_of_n)]
)
assert fgraph.outputs[0].owner.op == op
......@@ -658,29 +675,17 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert len(fgraph.toposort()) == 2
assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
def test_softmax_optimizations_w_bias_vector(self):
x = vector("x")
b = vector("b")
one_of_n = lvector("one_of_n")
op = crossentropy_categorical_1hot
fgraph = FunctionGraph([x, b, one_of_n], [op(softmax_op(x + b), one_of_n)])
assert fgraph.outputs[0].owner.op == op
optdb.query(OPT_FAST_RUN).optimize(fgraph)
assert len(fgraph.toposort()) == 2
assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
def test_softmax_grad_optimizations(self):
x = matrix("x")
one_of_n = lvector("one_of_n")
op = crossentropy_categorical_1hot
xe = op(softmax_op(x), one_of_n)
xe = op(softmax_legacy(x), one_of_n)
sum_xe = aet_sum(xe)
g_x = grad(sum_xe, x)
fgraph = FunctionGraph([x, one_of_n], [g_x])
assert check_stack_trace(
fgraph, ops_to_check=[crossentropy_softmax_1hot_with_bias_dx, softmax_op]
fgraph,
ops_to_check=[crossentropy_softmax_1hot_with_bias_dx, softmax_legacy],
)
optdb.query(OPT_FAST_RUN).optimize(fgraph)
......@@ -688,25 +693,8 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
ops = {node.op for node in fgraph.toposort()}
assert crossentropy_softmax_argmax_1hot_with_bias not in ops
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_op in ops
assert softmax_grad not in ops
def test_softmax_grad_optimizations_vector(self):
x = vector("x")
one_of_n = lvector("one_of_n")
op = crossentropy_categorical_1hot
xe = op(softmax_op(x), one_of_n)
sum_xe = aet_sum(xe)
g_x = grad(sum_xe, x)
fgraph = FunctionGraph([x, one_of_n], [g_x])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
ops = {node.op for node in fgraph.toposort()}
assert crossentropy_softmax_argmax_1hot_with_bias not in ops
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_op in ops
assert softmax_grad not in ops
assert softmax_legacy in ops
assert softmax_grad_legacy not in ops
def test_get_rid_of_advanced_indexing_version_of_xent(self):
x = matrix("x")
......@@ -737,8 +725,8 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 2
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_op in ops
assert softmax_grad not in ops
assert softmax_legacy in ops
assert softmax_grad_legacy not in ops
# Test that a biased softmax is optimized correctly
bias_expressions = [
......@@ -763,7 +751,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert len(ops) == 2
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_with_bias in ops
assert softmax_grad not in ops
assert softmax_grad_legacy not in ops
# Test that using "mean" instead of sum works, too
mean_expressions = [
......@@ -791,8 +779,8 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
# there's an extra dimshuffle in there
# but I can't think of a good rule to get rid of it
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_op in ops
assert softmax_grad not in ops
assert softmax_legacy in ops
assert softmax_grad_legacy not in ops
mean_bias_expressions = [
mean(-log(softmax(x + b)[aet.arange(y.shape[0]), y])),
......@@ -818,7 +806,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert len(ops) == 5
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_with_bias in ops
assert softmax_grad not in ops
assert softmax_grad_legacy not in ops
def test_xent_thing_int32(self):
x = matrix("x")
......@@ -847,141 +835,8 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 3
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_op in ops
assert softmax_grad not in ops
def test_optimize_xent_vector(self):
x = vector("x")
y = lvector("y")
# Test that a biased softmax is optimized correctly
bias_expressions = [
aet_sum(-log(softmax(x)[aet.arange(y.shape[0]), y])),
-aet_sum(log(softmax(x)[aet.arange(y.shape[0]), y])),
]
for expr in bias_expressions:
fgraph = FunctionGraph([x, y], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 5
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]
fgraph = FunctionGraph([x, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 4
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_op in ops
assert softmax_grad not in ops
def test_optimize_xent_vector2(self):
x = vector("x")
b = vector("b")
y = lvector("y")
# Test that a biased softmax is optimized correctly
bias_expressions = [
aet_sum(-log(softmax(x + b)[aet.arange(y.shape[0]), y])),
-aet_sum(log(softmax(b + x)[aet.arange(y.shape[0]), y])),
-aet_sum(log(softmax(x + b))[aet.arange(y.shape[0]), y]),
aet_sum(-log(softmax(b + x))[aet.arange(y.shape[0]), y]),
]
for expr in bias_expressions:
fgraph = FunctionGraph([x, b, y], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
ops = [node.op for node in fgraph.toposort()]
# [big_op, sum, dim_shuffle]
assert len(ops) == 3
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]
fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) <= 6
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_with_bias in ops
assert softmax_grad not in ops
def test_optimize_xent_vector3(self):
# Same as test_optimize_xent_vector2, but y is the result of
# a "flatten", and it used to make the constant-folding
# of arange(y.shape[0]) happen before the xent optimization
x = vector("x")
b = vector("b")
y_ = lvector("y_")
y = y_.flatten()
# Test that a biased softmax is optimized correctly
bias_expressions = [
aet_sum(-log(softmax(x + b)[aet.arange(y.shape[0]), y])),
-aet_sum(log(softmax(b + x)[aet.arange(y.shape[0]), y])),
-aet_sum(log(softmax(x + b))[aet.arange(y.shape[0]), y]),
aet_sum(-log(softmax(b + x))[aet.arange(y.shape[0]), y]),
]
for expr in bias_expressions:
fgraph = FunctionGraph([x, b, y_], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
ops = [node.op for node in fgraph.toposort()]
# [big_op, sum, dim_shuffle, flatten]
assert len(ops) <= 4
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]
fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) <= 6
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_with_bias in ops
assert softmax_grad not in ops
def test_optimize_xent_vector4(self):
# Same as test_optimize_xent_vector2, but y is the result of a
# "specify_shape" that indicates its length is 1, so the
# constant-folding of arange(y.shape[0]) happen before the xent
# optimization
x = vector("x")
b = vector("b")
y_ = lvector("y_")
y = specify_shape(y_, (1,))
# Test that a biased softmax is optimized correctly
bias_expressions = [
aet_sum(-log(softmax(x + b)[aet.arange(y.shape[0]), y])),
-aet_sum(log(softmax(b + x)[aet.arange(y.shape[0]), y])),
-aet_sum(log(softmax(x + b))[aet.arange(y.shape[0]), y]),
aet_sum(-log(softmax(b + x))[aet.arange(y.shape[0]), y]),
]
for expr in bias_expressions:
fgraph = FunctionGraph([x, b, y_], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
ops = [node.op for node in fgraph.toposort()]
# [big_op, sum, dim_shuffle, specify_shape]
assert len(ops) <= 4
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]
fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) <= 6
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_with_bias in ops
assert softmax_grad not in ops
assert softmax_legacy in ops
assert softmax_grad_legacy not in ops
def test_crossentropy_softmax_1hot_with_bias_dxcale_cost(self):
x = matrix("x")
......@@ -996,9 +851,9 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for node in func.maker.fgraph.toposort():
if node.op == crossentropy_softmax_1hot_with_bias_dx:
has_cx1hotdx = True
if node.op == softmax_op:
if node.op == softmax_legacy:
has_softmax = True
if node.op == softmax_grad:
if node.op == softmax_grad_legacy:
has_softmaxdx = True
assert has_cx1hotdx
......@@ -1033,7 +888,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
ops = {node.op for node in fgraph.toposort()}
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert softmax_op not in ops
assert softmax_legacy not in ops
# Verify the gradient wrt x
fgraph = FunctionGraph([x, y, a], [grad(expr, x)])
......@@ -1043,8 +898,8 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
ops = {node.op for node in fgraph.toposort()}
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_op in ops
assert softmax_grad not in ops
assert softmax_legacy in ops
assert softmax_grad_legacy not in ops
# Verify the gradient when providing output gradient
fgraph = FunctionGraph(
......@@ -1056,13 +911,13 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
ops = {node.op for node in fgraph.toposort()}
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_op in ops
assert softmax_grad not in ops
assert softmax_legacy in ops
assert softmax_grad_legacy not in ops
def test_argmax_pushdown():
x = matrix()
for sm in [softmax_graph, softmax_op]:
for sm in [softmax_graph, softmax_legacy]:
# test that the max_and_argmax is pushed down if the max is not used
out = max_and_argmax(sm(exp(tanh(sigmoid(x)))), axis=-1)[1]
fgraph = FunctionGraph([x], [out])
......@@ -1188,12 +1043,12 @@ class TestSoftmaxOpt:
# test that function contains softmax and no div.
f = aesara.function([c], p_y, mode=self.mode)
assert check_stack_trace(f, ops_to_check=softmax_op)
assert check_stack_trace(f, ops_to_check=softmax_legacy)
f_ops = [n.op for n in f.maker.fgraph.toposort()]
assert len(f_ops) == 1
assert softmax_op in f_ops
assert softmax_legacy in f_ops
f(self.rng.random((3, 4)).astype(config.floatX))
......@@ -1204,12 +1059,12 @@ class TestSoftmaxOpt:
# test that function contains softmax and no div.
f = aesara.function([c], p_y, mode=self.mode)
assert check_stack_trace(f, ops_to_check=softmax_op)
assert check_stack_trace(f, ops_to_check=softmax_legacy)
f_ops = [n.op for n in f.maker.fgraph.toposort()]
assert len(f_ops) == 1
assert softmax_op in f_ops
assert softmax_legacy in f_ops
f(self.rng.random((3, 4)).astype(config.floatX))
......@@ -1226,8 +1081,8 @@ class TestSoftmaxOpt:
g_ops = [n.op for n in g.maker.fgraph.toposort()]
assert len(g_ops) == 2
assert softmax_op in g_ops
assert softmax_grad in g_ops
assert softmax_legacy in g_ops
assert softmax_grad_legacy in g_ops
g(self.rng.random((3, 4)), self.rng.uniform(0.5, 1, (3, 4)))
......@@ -1272,7 +1127,7 @@ def test_grad_softmax_grad():
x = aesara.shared(rng.normal(size=(3, 4)))
def f(inputs):
y = softmax_op(x)
y = softmax_legacy(x)
return aesara.grad(None, x, known_grads={y: inputs})
utt.verify_grad(f, [rng.random((3, 4))])
......
......@@ -384,8 +384,7 @@ class TestRopLop(RopLopChecker):
self.check_mat_rop_lop(self.mx.sum(axis=1), (self.mat_in_shape[0],))
def test_softmax(self):
# Softmax adds an extra dimension !
self.check_rop_lop(aesara.tensor.nnet.softmax(self.x)[0], self.in_shape[0])
self.check_rop_lop(aesara.tensor.nnet.softmax(self.x), self.in_shape)
def test_alloc(self):
# Alloc of the sum of x into a vector
......