提交 58cb5c30 authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Add axis to Softmax and SoftmaxGrad Ops

上级 c6c85acb
...@@ -198,8 +198,10 @@ def jax_funcify_Identity(op, **kwargs): ...@@ -198,8 +198,10 @@ def jax_funcify_Identity(op, **kwargs):
@jax_funcify.register(Softmax) @jax_funcify.register(Softmax)
def jax_funcify_Softmax(op, **kwargs): def jax_funcify_Softmax(op, **kwargs):
axis = op.axis
def softmax(x): def softmax(x):
return jax.nn.softmax(x) return jax.nn.softmax(x, axis=axis)
return softmax return softmax
......
...@@ -400,17 +400,24 @@ def numba_funcify_Softmax(op, node, **kwargs): ...@@ -400,17 +400,24 @@ def numba_funcify_Softmax(op, node, **kwargs):
x_at = node.inputs[0] x_at = node.inputs[0]
x_dtype = x_at.type.numpy_dtype x_dtype = x_at.type.numpy_dtype
x_dtype = numba.np.numpy_support.from_dtype(x_dtype) x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
axis = op.axis
# np.max(x, axis=1) if axis is not None:
reduce_max = create_axis_reducer(np.maximum, -np.inf, 1, x_at.ndim, x_dtype) reduce_max = create_axis_reducer(
# np.sum(x, axis=1) np.maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
reduce_sum = create_axis_reducer(np.add, 0.0, 1, x_at.ndim, x_dtype) )
reduce_sum = create_axis_reducer(
np.add, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
)
else:
reduce_max = np.max
reduce_sum = np.sum
@numba.njit @numba.njit
def softmax(x): def softmax(x):
z = np.expand_dims(reduce_max(x), -1) z = reduce_max(x)
e_x = np.exp(x - z) e_x = np.exp(x - z)
w = np.expand_dims(reduce_sum(e_x), -1) w = reduce_sum(e_x)
sm = e_x / w sm = e_x / w
return sm return sm
......
...@@ -35,9 +35,9 @@ from aesara.tensor.nnet.basic import ( ...@@ -35,9 +35,9 @@ from aesara.tensor.nnet.basic import (
selu, selu,
sigmoid_binary_crossentropy, sigmoid_binary_crossentropy,
softmax, softmax,
softmax_grad, softmax_grad_legacy,
softmax_graph, softmax_graph,
softmax_op, softmax_legacy,
softmax_simplifier, softmax_simplifier,
softmax_with_bias, softmax_with_bias,
softsign, softsign,
......
...@@ -14,8 +14,10 @@ revisited later when all the intermediate part are on the GPU. ...@@ -14,8 +14,10 @@ revisited later when all the intermediate part are on the GPU.
""" """
import warnings import warnings
from textwrap import dedent
import numpy as np import numpy as np
import scipy.special
import aesara import aesara
from aesara import scalar as aes from aesara import scalar as aes
...@@ -140,7 +142,7 @@ class SoftmaxWithBias(COp): ...@@ -140,7 +142,7 @@ class SoftmaxWithBias(COp):
if isinstance(g_sm.type, DisconnectedType): if isinstance(g_sm.type, DisconnectedType):
return [DisconnectedType()(), DisconnectedType()()] return [DisconnectedType()(), DisconnectedType()()]
dx = softmax_grad(g_sm, outputs[0]) dx = softmax_grad_legacy(g_sm, outputs[0])
db = aet_sum(dx, axis=0) db = aet_sum(dx, axis=0)
return dx, db return dx, db
...@@ -339,36 +341,39 @@ class SoftmaxGrad(COp): ...@@ -339,36 +341,39 @@ class SoftmaxGrad(COp):
nin = 2 nin = 2
nout = 1 nout = 1
__props__ = () __props__ = ("axis",)
def __init__(self, axis):
if axis is not None and not isinstance(axis, int):
raise TypeError("axis must be an integer or `None`")
self.axis = axis
def make_node(self, dy, sm): def make_node(self, dy, sm):
dy = aet.as_tensor_variable(dy) dy = aet.as_tensor_variable(dy)
sm = aet.as_tensor_variable(sm) sm = aet.as_tensor_variable(sm)
if dy.type.ndim not in (1, 2) or dy.type.dtype not in float_dtypes:
raise ValueError("dy must be 1-d or 2-d tensor of floats. Got ", dy.type) if self.axis is not None and (self.axis >= sm.ndim or self.axis < -sm.ndim):
if dy.ndim == 1: raise ValueError(
dy = shape_padleft(dy, n_ones=1) f"SoftmaxGrad axis(={self.axis}) out of bounds for {sm.ndim}D array {sm}"
if sm.ndim == 1: )
sm = shape_padleft(sm, n_ones=1)
return Apply(self, [dy, sm], [sm.type()]) return Apply(self, [dy, sm], [sm.type()])
def perform(self, node, input_storage, output_storage): def perform(self, node, input_storage, output_storage):
dy, sm = input_storage dy, sm = input_storage
dx = np.zeros_like(sm)
# dx[i,j] = - (\sum_k dy[i,k] sm[i,k]) sm[i,j] + dy[i,j] sm[i,j] dy_times_sm = dy * sm
for i in range(sm.shape[0]): dx = dy_times_sm - np.sum(dy_times_sm, axis=self.axis, keepdims=True) * sm
dy_times_sm_i = dy[i] * sm[i]
dx[i] = dy_times_sm_i - sum(dy_times_sm_i) * sm[i]
output_storage[0][0] = dx output_storage[0][0] = dx
def grad(self, inp, grads): def grad(self, inp, grads):
dy, sm = inp dy, sm = inp
(g,) = grads (g,) = grads
tmp = g + neg(aet_sum(g * sm, axis=1).dimshuffle((0, "x"))) tmp = g + neg(aet_sum(g * sm, axis=self.axis, keepdims=True))
g_dy = tmp * sm g_dy = tmp * sm
tmp2 = aet_sum(dy * sm, axis=1).dimshuffle((0, "x")) tmp2 = aet_sum(dy * sm, axis=self.axis, keepdims=True)
g_sm = tmp * dy - g * tmp2 g_sm = tmp * dy - g * tmp2
return g_dy, g_sm return g_dy, g_sm
...@@ -377,79 +382,184 @@ class SoftmaxGrad(COp): ...@@ -377,79 +382,184 @@ class SoftmaxGrad(COp):
return [shape[1]] return [shape[1]]
def c_code_cache_version(self): def c_code_cache_version(self):
return (3,) return (4,)
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
dy, sm = inp dy, sm = inp
(dx,) = out (dx,) = out
return """ axis = self.axis if self.axis is not None else np.MAXDIMS
if ((PyArray_TYPE(%(dy)s) != NPY_DOUBLE) && fail = sub["fail"]
(PyArray_TYPE(%(dy)s) != NPY_FLOAT))
{ return dedent(
PyErr_SetString(PyExc_TypeError, f"""
"types should be float or float64"); PyArrayObject* op[3];
%(fail)s; npy_uint32 op_flags[3];
} npy_uint32 iter_flags;
if ((PyArray_TYPE(%(sm)s) != NPY_DOUBLE) && NpyIter* iter;
(PyArray_TYPE(%(sm)s) != NPY_FLOAT)) NpyIter_IterNextFunc* get_next;
{ char** data_ptr;
PyErr_SetString(PyExc_TypeError,
"types should be float or float64"); int sm_ndim = PyArray_NDIM({sm});
%(fail)s; int axis = {axis};
} int iterate_axis = !(axis == NPY_MAXDIMS || sm_ndim == 1);
if ((PyArray_NDIM(%(dy)s) != 2)
|| (PyArray_NDIM(%(sm)s) != 2)) // Validate inputs
{ if ((PyArray_TYPE({dy}) != NPY_DOUBLE) &&
PyErr_SetString(PyExc_ValueError, "rank error"); (PyArray_TYPE({dy}) != NPY_FLOAT))
%(fail)s; {{
} PyErr_SetString(PyExc_TypeError, "types should be float or float64");
if (PyArray_DIMS(%(dy)s)[0] != PyArray_DIMS(%(sm)s)[0]) {fail};
{ }}
PyErr_SetString(PyExc_ValueError, "dy.shape[0] != sm.shape[0]"); if ((PyArray_TYPE({sm}) != NPY_DOUBLE) &&
%(fail)s; (PyArray_TYPE({sm}) != NPY_FLOAT))
} {{
if ((NULL == %(dx)s) PyErr_SetString(PyExc_TypeError, "types should be float or float64");
|| (PyArray_DIMS(%(dx)s)[0] != PyArray_DIMS(%(sm)s)[0]) {fail};
|| (PyArray_DIMS(%(dx)s)[1] != PyArray_DIMS(%(sm)s)[1])) }}
{
Py_XDECREF(%(dx)s); if (axis < 0) axis = sm_ndim + axis;
%(dx)s = (PyArrayObject*) PyArray_SimpleNew(2, if ((axis < 0) || (iterate_axis && (axis > sm_ndim)))
PyArray_DIMS(%(sm)s), {{
PyArray_TYPE(%(sm)s)); PyErr_SetString(PyExc_ValueError, "invalid axis in SoftmaxGrad");
if (!%(dx)s) {fail};
{ }}
PyErr_SetString(PyExc_MemoryError,
"failed to alloc dx output"); if (({dx} == NULL)
%(fail)s; || !(PyArray_CompareLists(PyArray_DIMS({dx}), PyArray_DIMS({sm}), sm_ndim)))
} {{
} Py_XDECREF({dx});
{dx} = (PyArrayObject*)PyArray_SimpleNew(sm_ndim,
for (size_t i = 0; i < PyArray_DIMS(%(dx)s)[0]; ++i) PyArray_DIMS({sm}),
{ PyArray_TYPE({sm}));
const dtype_%(dy)s* __restrict__ dy_i = (dtype_%(dy)s*) (PyArray_BYTES(%(dy)s) + PyArray_STRIDES(%(dy)s)[0] * i); if (!{dx})
npy_intp Sdy = PyArray_STRIDES(%(dy)s)[1]/sizeof(dtype_%(dy)s); {{
const dtype_%(sm)s* __restrict__ sm_i = (dtype_%(sm)s*) (PyArray_BYTES(%(sm)s) + PyArray_STRIDES(%(sm)s)[0] * i); PyErr_SetString(PyExc_MemoryError, "failed to alloc SoftMaxGrad dx output");
npy_intp Ssm = PyArray_STRIDES(%(sm)s)[1]/sizeof(dtype_%(sm)s); {fail};
dtype_%(dx) s* __restrict__ dx_i = (dtype_%(dx)s*) (PyArray_BYTES(%(dx)s) + PyArray_STRIDES(%(dx)s)[0] * i); }}
npy_intp Sdx = PyArray_STRIDES(%(dx)s)[1]/sizeof(dtype_%(dx)s); }}
double sum_dy_times_sm = 0.; // Create numpy iterator
for (size_t j = 0; j < PyArray_DIMS(%(dx)s)[1]; ++j) op[0] = {dy};
{ op[1] = {sm};
dx_i[j * Sdx] = dy_i[j * Sdy] * sm_i[j * Ssm]; op[2] = {dx};
sum_dy_times_sm += dx_i[j * Sdx]; op_flags[0] = NPY_ITER_READONLY;
} op_flags[1] = NPY_ITER_READONLY;
for (size_t j = 0; j < PyArray_DIMS(%(dx)s)[1]; ++j) op_flags[2] = NPY_ITER_READWRITE;
{ iter_flags = (iterate_axis)? NPY_ITER_MULTI_INDEX : 0;
dx_i[j * Sdx] -= sum_dy_times_sm * sm_i[j * Ssm]; iter = NpyIter_MultiNew(
} 3,
} op,
""" % dict( iter_flags,
locals(), **sub NPY_KEEPORDER,
NPY_NO_CASTING,
op_flags,
NULL
);
if (iter == NULL)
{{
PyErr_SetString(PyExc_MemoryError, "failed to create softmax iterator");
{fail};
}}
// SoftmaxGrad is applied across the entire array
if (!iterate_axis)
{{
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain SoftMaxGrad GetIterNext");
{fail};
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
// Compute and accumulate dy * sm
dtype_{dx} sum_dy_times_sm = 0.0;
do
{{
dtype_{dy}* dy_ptr = (dtype_{dy}*)data_ptr[0];
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
dtype_{dx}* dx_ptr = (dtype_{dx}*)data_ptr[2];
*dx_ptr = (dtype_{dx})((*dy_ptr) * (*sm_ptr));
sum_dy_times_sm += *dx_ptr;
}} while(get_next(iter));
// Reset Iterator
if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to reset softmax iterator");
{fail};
}}
// Subtract sum(dy*sm) * sm
do
{{
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
dtype_{dx}* dx_ptr = (dtype_{dx}*)data_ptr[2];
*dx_ptr -= sum_dy_times_sm * ((dtype_{dx})(*sm_ptr));
}} while(get_next(iter));
}}
// SoftmaxGrad is applied across a specific axis
else {{
// Collect axis strides and remove it from iteration
npy_intp axis_size = PyArray_DIM({sm}, axis);
npy_intp* axis_stride = NpyIter_GetAxisStrideArray(iter, axis);
if (axis_stride == NULL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain softmax axis strides");
{fail};
}}
npy_intp dy_axis_stride = axis_stride[0] / sizeof(dtype_{dy});
npy_intp sm_axis_stride = axis_stride[1] / sizeof(dtype_{sm});
npy_intp dx_axis_stride = axis_stride[2] / sizeof(dtype_{dx});
if (NpyIter_RemoveAxis(iter, axis) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to remove SoftmaxGrad axis from iterator");
{fail};
}}
// Iterate over remaining axes
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain SoftamGrad GetIterNext");
{fail};
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
do
{{
dtype_{dy}* dy_axis = (dtype_{dy}*)data_ptr[0];
dtype_{sm}* sm_axis = (dtype_{sm}*)data_ptr[1];
dtype_{dx}* dx_axis = (dtype_{dx}*)data_ptr[2];
// Compute and accumulate dy * sm
dtype_{dx} sum_dy_times_sm = 0.0;
for (npy_intp i = 0; i < axis_size; i++)
{{
dx_axis[i * dx_axis_stride] = (dtype_{dx})(dy_axis[i * dy_axis_stride] * sm_axis[i * sm_axis_stride]);
sum_dy_times_sm += dx_axis[i * dx_axis_stride];
}}
// Subtract sum(dy*sm) * sm
for (npy_intp i = 0; i < axis_size; i++)
{{
dx_axis[i * dx_axis_stride] -= sum_dy_times_sm * (dtype_{dx})(sm_axis[i * sm_axis_stride]);
}}
}} while(get_next(iter));
}}
NpyIter_Deallocate(iter);
"""
) )
softmax_grad = SoftmaxGrad() softmax_grad_legacy = SoftmaxGrad(axis=-1)
class Softmax(COp): class Softmax(COp):
...@@ -464,34 +574,32 @@ class Softmax(COp): ...@@ -464,34 +574,32 @@ class Softmax(COp):
nin = 1 nin = 1
nout = 1 nout = 1
__props__ = () __props__ = ("axis",)
def __init__(self, axis):
if axis is not None and not isinstance(axis, int):
raise TypeError("axis must be an integer or `None`")
self.axis = axis
def make_node(self, x): def make_node(self, x):
x = aet.as_tensor_variable(x) x = aet.as_tensor_variable(x)
if x.type.ndim not in (1, 2) or x.type.dtype not in float_dtypes:
raise ValueError(f"x must be 1-d or 2-d tensor of floats. Got {x.type}") if self.axis is not None and (self.axis >= x.ndim or self.axis < -x.ndim):
if x.ndim == 1: raise ValueError(
warnings.warn( f"Softmax axis(={self.axis}) out of bounds for {x.ndim}D array {x}"
"If x is a vector, Softmax will not automatically pad x "
"anymore in next releases. If you need it, please do it manually. The "
"vector case is gonna be supported soon and the output will be a vector.",
category=PendingDeprecationWarning,
stacklevel=4,
) )
x = shape_padleft(x, n_ones=1)
return Apply(self, [x], [x.type()]) return Apply(self, [x], [x.type()])
def perform(self, node, input_storage, output_storage): def perform(self, node, input_storage, output_storage):
(x,) = input_storage (x,) = input_storage
e_x = np.exp(x - x.max(axis=1)[:, None]) (z,) = output_storage
sm = e_x / e_x.sum(axis=1)[:, None] z[0] = scipy.special.softmax(x, axis=self.axis)
output_storage[0][0] = sm
def L_op(self, inp, outputs, grads): def L_op(self, inp, outputs, grads):
(x,) = inp (x,) = inp
(g_sm,) = grads (g_sm,) = grads
return [softmax_grad(g_sm, outputs[0])] return [SoftmaxGrad(axis=self.axis)(g_sm, outputs[0])]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
# I think the Jacobian is symmetric so the R_op # I think the Jacobian is symmetric so the R_op
...@@ -506,151 +614,202 @@ class Softmax(COp): ...@@ -506,151 +614,202 @@ class Softmax(COp):
def c_headers(self, **kwargs): def c_headers(self, **kwargs):
return ["<iostream>", "<cmath>"] return ["<iostream>", "<cmath>"]
@staticmethod def c_code(self, node, name, inp, out, sub):
def c_code_template(dtype): (x,) = inp
# this implementation was lifted from (sm,) = out
# /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx axis = self.axis if self.axis is not None else np.MAXDIMS
fail = sub["fail"]
# dtype = node.inputs[0].type.dtype_specs()[1]
# TODO: put this into a templated function, in the support code # TODO: put this into a templated function, in the support code
# TODO: declare the max of each row as an Op output # TODO: declare the max of each row as an Op output
# TODO: set error messages for failures in this code
# TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1] # TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1]
init_decl = """ return dedent(
npy_intp* Nx = PyArray_DIMS(%(x)s); f"""
npy_intp Sx1 = 0; PyArrayObject* op[2];
npy_intp Ssm1 = 0; npy_uint32 op_flags[2];
npy_uint32 iter_flags;
if (PyArray_NDIM(%(x)s) != 2) NpyIter* iter;
{ NpyIter_IterNextFunc* get_next;
PyErr_SetString(PyExc_ValueError, "not a 2d tensor"); char** data_ptr;
%(fail)s;
} int x_ndim = PyArray_NDIM({x});
if ((PyArray_TYPE(%(x)s) != NPY_DOUBLE) && int axis = {axis};
(PyArray_TYPE(%(x)s) != NPY_FLOAT)) int iterate_axis = !(axis == NPY_MAXDIMS || x_ndim == 1);
{
// Validate inputs
if ((PyArray_TYPE({x}) != NPY_DOUBLE) &&
(PyArray_TYPE({x}) != NPY_FLOAT))
{{
PyErr_SetString(PyExc_TypeError, "not a float"); PyErr_SetString(PyExc_TypeError, "not a float");
%(fail)s; {fail}
} }}
if ((NULL == %(sm)s) if (axis < 0) axis = x_ndim + axis;
|| (PyArray_DIMS(%(sm)s)[0] != PyArray_DIMS(%(x)s)[0]) if ((axis < 0) || (iterate_axis && (axis > x_ndim)))
|| (PyArray_DIMS(%(sm)s)[1] != PyArray_DIMS(%(x)s)[1])) {{
{ PyErr_SetString(PyExc_ValueError, "invalid axis in Softmax");
Py_XDECREF(%(sm)s); {fail}
%(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s), }}
PyArray_TYPE(%(x)s));
if(!%(sm)s) { // Allocate Output Array
PyErr_SetString(PyExc_MemoryError, if (({sm}) == NULL || !(PyArray_CompareLists(PyArray_DIMS({sm}), PyArray_DIMS({x}), x_ndim)))
"failed to alloc sm output"); {{
%(fail)s Py_XDECREF({sm});
} {sm} = (PyArrayObject*)PyArray_SimpleNew(x_ndim, PyArray_DIMS({x}), PyArray_TYPE({x}));
} if(!{sm}) {{
Sx1 = PyArray_STRIDES(%(x)s)[1]/sizeof(dtype_%(x)s); PyErr_SetString(PyExc_MemoryError, "failed to alloc Softmax output");
Ssm1 = PyArray_STRIDES(%(sm)s)[1]/sizeof(dtype_%(sm)s); {fail}
}}
}}
// Create numpy iterator
op[0] = {x};
op[1] = {sm};
op_flags[0] = NPY_ITER_READONLY;
op_flags[1] = NPY_ITER_READWRITE;
iter_flags = (iterate_axis)? NPY_ITER_MULTI_INDEX : 0;
iter = NpyIter_MultiNew(
2,
op,
iter_flags,
NPY_KEEPORDER,
NPY_NO_CASTING,
op_flags,
NULL
);
if (iter == NULL)
{{
PyErr_SetString(PyExc_MemoryError, "failed to create Softmax iterator");
{fail}
}}
// Softmax is applied across the entire array
if (!iterate_axis)
{{
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain Softmax GetIterNext");
{fail}
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
// Find axis max
dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
dtype_{x} max = *x_ptr;
if (get_next(iter))
{{
do
{{
dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
max = (*x_ptr > max)? *x_ptr : max;
}} while(get_next(iter));
}}
// Reset Iterator
if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to reset Softmax iterator");
{fail}
}}
// Compute and accumulate exp(x-max(x)) exponent
double sum_exp_dev = 0.0;
do
{{
dtype_{x}* x_ptr = (dtype_{x}*)data_ptr[0];
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
*sm_ptr = (dtype_{sm}) exp(*x_ptr - max);
sum_exp_dev += *sm_ptr;
}} while(get_next(iter));
// Reset Iterator
if (NpyIter_GotoIterIndex(iter, 0) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to reset Softmax iterator");
{fail}
}}
// Divide by sum(exp(x-max(x)))
double inv_sum_exp_dev = 1.0 / sum_exp_dev;
do
{{
dtype_{sm}* sm_ptr = (dtype_{sm}*)data_ptr[1];
*sm_ptr *= inv_sum_exp_dev;
}} while(get_next(iter));
}}
// Softmax is applied across a specific axis
else {{
// Collect axis strides and remove it from iteration
npy_intp axis_size = PyArray_DIM({x}, axis);
npy_intp* axis_stride = NpyIter_GetAxisStrideArray(iter, axis);
if (axis_stride == NULL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain Softmax axis strides");
{fail}
}}
npy_intp x_axis_stride = axis_stride[0] / sizeof(dtype_{x});
npy_intp sm_axis_stride = axis_stride[1] / sizeof(dtype_{sm});
if (NpyIter_RemoveAxis(iter, axis) == NPY_FAIL)
{{
PyErr_SetString(PyExc_RuntimeError, "Failed to remove softmax axis from iterator");
{fail}
}}
// Iterate over remaining axes
get_next = NpyIter_GetIterNext(iter, NULL);
if (get_next == NULL)
{{
NpyIter_Deallocate(iter);
PyErr_SetString(PyExc_RuntimeError, "Failed to obtain softmax GetIterNext");
{fail}
}}
data_ptr = NpyIter_GetDataPtrArray(iter);
do
{{
dtype_{x}* x_axis = (dtype_{x}*)data_ptr[0];
dtype_{sm}* sm_axis = (dtype_{sm}*)data_ptr[1];
// Find axis max
dtype_{x} max = x_axis[0];
for (npy_intp i = 1; i < axis_size; i++)
{{
dtype_{x} x_val = x_axis[i * x_axis_stride];
max = (x_val > max)? x_val : max;
}}
// Compute and accumulate exp(x-max(x)) exponent
dtype_{sm} sum_exp_dev = 0.0;
for (npy_intp i = 0; i < axis_size; i++)
{{
sm_axis[i * sm_axis_stride] = (dtype_{sm}) exp(x_axis[i * x_axis_stride] - max);
sum_exp_dev += sm_axis[i * sm_axis_stride];
}}
// Divide by sum(exp(x-max(x)))
dtype_{sm} inv_sum_exp_dev = 1.0 / sum_exp_dev;
for (npy_intp i = 0; i < axis_size; i++)
{{
sm_axis[i * sm_axis_stride] *= inv_sum_exp_dev;
}}
}} while(get_next(iter));
}}
NpyIter_Deallocate(iter);
""" """
begin_row_loop = """
for (size_t i = 0; i < Nx[0]; ++i)
{
size_t j;
double sum = 0.0;
const dtype_%(x)s* __restrict__ x_i = (dtype_%(x)s*)(PyArray_BYTES(%(x)s) + PyArray_STRIDES(%(x)s)[0] * i);
dtype_%(sm) s* __restrict__ sm_i = (dtype_%(sm)s*)(PyArray_BYTES(%(sm)s) + PyArray_STRIDES(%(sm)s)[0] * i);
dtype_%(sm)s row_max = x_i[0];
//std::cout << "0 " << row_max << "\\n";
// Get the maximum value of the row
for (j = 1; j < Nx[1]; ++j)
{
dtype_%(sm)s row_ij = x_i[j * Sx1] ;
//std::cout << "1 " << row_ij << "\\n";
row_max = (row_ij > row_max) ? row_ij : row_max;
}
"""
inside_row_loop = """
for (j = 0; j < Nx[1]; ++j)
{
dtype_%(sm)s row_ij = x_i[j * Sx1] ;
//std::cout << "2 " << j << " " << row_ij << " " << row_max << "\\n";
dtype_%(sm)s sm_ij = exp(row_ij - row_max);
//std::cout << "3 " << j << " " << sm_ij << "\\n";
sum += sm_ij;
sm_i[j * Ssm1] = sm_ij;
}
//cblas_dscal(x.N, 1.0 / sum, &mat_at(s,i,0), s.n);
double sum_inv = 1.0 / sum;
for (j = 0; j < Nx[1]; ++j)
{
sm_i[j * Ssm1] *= sum_inv;
}
"""
# Get the vectorized version of exp if it exist
try:
vec_exp = aesara.scalar.exp.c_code_contiguous_raw(
dtype, "Nx[1]", "sm_i", "sm_i"
) )
inside_row_loop_contig = (
"""
for (j = 0; j < Nx[1]; ++j)
{
sm_i[j * Ssm1] = x_i[j * Sx1] - row_max;
}
%(vec_exp)s;
for (j = 0; j < Nx[1]; ++j)
{
sum += sm_i[j * Ssm1];
}
//cblas_dscal(x.N, 1.0 / sum, &mat_at(s,i,0), s.n);
double sum_inv = 1.0 / sum;
for (j = 0; j < Nx[1]; ++j)
{
sm_i[j * Ssm1] *= sum_inv;
}
"""
% locals()
)
inside_row_loop = (
"""
if(Ssm1 == 1){
%(inside_row_loop_contig)s
}else{
%(inside_row_loop)s
}
"""
% locals()
)
except aesara.graph.utils.MethodNotDefined:
pass
end_row_loop = """
}
"""
return (init_decl, begin_row_loop, inside_row_loop, end_row_loop)
def c_code(self, node, name, inp, out, sub):
(x,) = inp
(sm,) = out
code_template = "".join(
self.c_code_template(node.inputs[0].type.dtype_specs()[1])
)
return code_template % dict(locals(), **sub)
@staticmethod @staticmethod
def c_code_cache_version(): def c_code_cache_version():
return (3,) return (4,)
softmax_op = Softmax() softmax_legacy = Softmax(axis=-1)
class LogSoftmax(COp): class LogSoftmax(COp):
...@@ -689,7 +848,7 @@ class LogSoftmax(COp): ...@@ -689,7 +848,7 @@ class LogSoftmax(COp):
def grad(self, inp, grads): def grad(self, inp, grads):
(x,) = inp (x,) = inp
sm = softmax_op(x) sm = softmax_legacy(x)
return [grads[0] - aet_sum(grads[0], axis=1, keepdims=True) * sm] return [grads[0] - aet_sum(grads[0], axis=1, keepdims=True) * sm]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
...@@ -816,7 +975,8 @@ def local_logsoftmax(fgraph, node): ...@@ -816,7 +975,8 @@ def local_logsoftmax(fgraph, node):
and isinstance(node.op.scalar_op, aes.Log) and isinstance(node.op.scalar_op, aes.Log)
and len(node.inputs) == 1 and len(node.inputs) == 1
and node.inputs[0].owner is not None and node.inputs[0].owner is not None
and isinstance(node.inputs[0].owner.op, Softmax) and node.inputs[0].owner.op == softmax_legacy
and node.inputs[0].ndim == 2
): ):
inVars = node.inputs[0].owner.inputs[0] inVars = node.inputs[0].owner.inputs[0]
new_op = LogSoftmax() new_op = LogSoftmax()
...@@ -843,7 +1003,8 @@ def local_logsoftmax_grad(fgraph, node): ...@@ -843,7 +1003,8 @@ def local_logsoftmax_grad(fgraph, node):
and node.inputs[0].owner.op == true_div and node.inputs[0].owner.op == true_div
and len(node.inputs[0].owner.inputs) >= 2 and len(node.inputs[0].owner.inputs) >= 2
and node.inputs[0].owner.inputs[1].owner is not None and node.inputs[0].owner.inputs[1].owner is not None
and node.inputs[0].owner.inputs[1].owner.op == softmax_op and node.inputs[0].owner.inputs[1].owner.op == softmax_legacy
and node.inputs[0].owner.inputs[1].ndim == 2
and node.inputs[1] == node.inputs[0].owner.inputs[1] and node.inputs[1] == node.inputs[0].owner.inputs[1]
and not ( and not (
# skip if it will be optimized by # skip if it will be optimized by
...@@ -871,13 +1032,26 @@ def softmax_graph(c): ...@@ -871,13 +1032,26 @@ def softmax_graph(c):
return exp(c) / exp(c).sum(axis=-1, keepdims=True) return exp(c) / exp(c).sum(axis=-1, keepdims=True)
def softmax(c): UNSET_AXIS = object()
def softmax(c, axis=UNSET_AXIS):
if axis is UNSET_AXIS:
warnings.warn(
"Softmax now accepts an axis argument. For backwards-compatibility it defaults to -1 when not specified, "
"but in the future the default will be `None`.\nTo suppress this warning specify axis explicitly.",
FutureWarning,
)
axis = -1
c = as_tensor_variable(c) c = as_tensor_variable(c)
if c.broadcastable[-1]: if c.ndim == 1:
# TODO: Create Specific warning type that can be suppressed?
warnings.warn( warnings.warn(
"The softmax is applied on a dimension of shape 1, which does not have a semantic meaning." "Softmax no longer converts a vector to a row matrix.",
UserWarning,
) )
return softmax_op(c) return Softmax(axis=axis)(c)
def logsoftmax(c): def logsoftmax(c):
...@@ -885,13 +1059,13 @@ def logsoftmax(c): ...@@ -885,13 +1059,13 @@ def logsoftmax(c):
@register_specialize("fast_compile_gpu") @register_specialize("fast_compile_gpu")
@local_optimizer([softmax_op]) @local_optimizer([softmax_legacy])
def local_softmax_with_bias(fgraph, node): def local_softmax_with_bias(fgraph, node):
""" """
Try to turn softmax(sum_of_stuff) -> softmax_w_bias(matrix, bias). Try to turn softmax(sum_of_stuff) -> softmax_w_bias(matrix, bias).
""" """
if node.op == softmax_op: if node.op == softmax_legacy and node.outputs[0].ndim == 2:
(x,) = node.inputs (x,) = node.inputs
if x.owner and x.owner.op == add: if x.owner and x.owner.op == add:
vectors = [] vectors = []
...@@ -980,7 +1154,7 @@ def softmax_simplifier(numerators, denominators): ...@@ -980,7 +1154,7 @@ def softmax_simplifier(numerators, denominators):
matching_denom = denominator matching_denom = denominator
break break
if matching_denom: if matching_denom:
softmax = softmax_op(x) softmax = softmax_legacy(x)
copy_stack_trace(numerator, softmax) copy_stack_trace(numerator, softmax)
numerators.remove(numerator) numerators.remove(numerator)
denominators.remove(matching_denom) denominators.remove(matching_denom)
...@@ -1611,7 +1785,7 @@ def crossentropy_to_crossentropy_with_softmax(fgraph): ...@@ -1611,7 +1785,7 @@ def crossentropy_to_crossentropy_with_softmax(fgraph):
if node.op == crossentropy_categorical_1hot: if node.op == crossentropy_categorical_1hot:
(nll,) = node.outputs (nll,) = node.outputs
sm, one_of_n = node.inputs sm, one_of_n = node.inputs
if sm.owner and sm.owner.op == softmax_op: if sm.owner and sm.owner.op == softmax_legacy and sm.ndim == 2:
(x,) = sm.owner.inputs (x,) = sm.owner.inputs
( (
new_nll, new_nll,
...@@ -1658,9 +1832,9 @@ optdb.register( ...@@ -1658,9 +1832,9 @@ optdb.register(
@register_specialize( @register_specialize(
"fast_compile_gpu", "local_crossentropy_to_crossentropy_with_softmax_grad" "fast_compile_gpu", "local_crossentropy_to_crossentropy_with_softmax_grad"
) # old name ) # old name
@local_optimizer([softmax_grad]) @local_optimizer([softmax_grad_legacy])
def local_softmax_grad_to_crossentropy_with_softmax_grad(fgraph, node): def local_softmax_grad_to_crossentropy_with_softmax_grad(fgraph, node):
if node.op == softmax_grad: if node.op == softmax_grad_legacy and node.inputs[1].ndim == 2:
g_coding_dist, coding_dist = node.inputs g_coding_dist, coding_dist = node.inputs
if ( if (
g_coding_dist.owner g_coding_dist.owner
...@@ -1686,13 +1860,16 @@ def local_argmax_pushdown(fgraph, node): ...@@ -1686,13 +1860,16 @@ def local_argmax_pushdown(fgraph, node):
x = node.inputs[0] x = node.inputs[0]
axis = node.op.get_params(node) axis = node.op.get_params(node)
# TODO: Make a list/set of monotonic ops... # TODO: Make a list/set of monotonic ops...
if x.owner and x.owner.op in ( if x.owner and (
softmax_op, x.owner.op
in (
softplus, softplus,
exp, exp,
log, log,
tanh, tanh,
sigmoid, sigmoid,
)
or isinstance(x.owner.op, Softmax)
): ):
(pre_x,) = x.owner.inputs (pre_x,) = x.owner.inputs
ret = max_and_argmax(pre_x, axis) ret = max_and_argmax(pre_x, axis)
...@@ -1786,7 +1963,12 @@ def local_advanced_indexing_crossentropy_onehot(fgraph, node): ...@@ -1786,7 +1963,12 @@ def local_advanced_indexing_crossentropy_onehot(fgraph, node):
except Exception: except Exception:
pass pass
if sm is not None and sm.owner and sm.owner.op in (softmax_op, softmax_with_bias): if (
sm is not None
and sm.owner
and sm.owner.op in (softmax_legacy, softmax_with_bias)
and sm.ndim == 2
):
sm_w_bias = local_softmax_with_bias.transform(fgraph, sm.owner) sm_w_bias = local_softmax_with_bias.transform(fgraph, sm.owner)
if sm_w_bias: if sm_w_bias:
assert sm_w_bias[0].owner.op == softmax_with_bias assert sm_w_bias[0].owner.op == softmax_with_bias
...@@ -1807,9 +1989,9 @@ def local_advanced_indexing_crossentropy_onehot(fgraph, node): ...@@ -1807,9 +1989,9 @@ def local_advanced_indexing_crossentropy_onehot(fgraph, node):
@register_specialize("fast_compile_gpu") @register_specialize("fast_compile_gpu")
@local_optimizer([softmax_grad]) @local_optimizer([softmax_grad_legacy])
def local_advanced_indexing_crossentropy_onehot_grad(fgraph, node): def local_advanced_indexing_crossentropy_onehot_grad(fgraph, node):
if not (node.op == softmax_grad): if not (node.op == softmax_grad_legacy and node.inputs[1].ndim == 2):
return return
sm = None sm = None
...@@ -1821,7 +2003,8 @@ def local_advanced_indexing_crossentropy_onehot_grad(fgraph, node): ...@@ -1821,7 +2003,8 @@ def local_advanced_indexing_crossentropy_onehot_grad(fgraph, node):
if ( if (
(sm is not None) (sm is not None)
and sm.owner and sm.owner
and (sm.owner.op in (softmax_op, softmax_with_bias)) and (sm.owner.op in (softmax_legacy, softmax_with_bias))
and sm.ndim == 2
): ):
sm_w_bias = local_softmax_with_bias.transform(fgraph, sm.owner) sm_w_bias = local_softmax_with_bias.transform(fgraph, sm.owner)
if sm_w_bias: if sm_w_bias:
......
...@@ -32,7 +32,7 @@ from aesara.tensor.math import ( ...@@ -32,7 +32,7 @@ from aesara.tensor.math import (
sqrt, sqrt,
) )
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
from aesara.tensor.nnet import batchnorm, conv2d, softmax, softmax_op from aesara.tensor.nnet import batchnorm, conv2d, softmax, softmax_legacy
from aesara.tensor.nnet.abstract_conv import ( from aesara.tensor.nnet.abstract_conv import (
get_conv_gradinputs_shape, get_conv_gradinputs_shape,
get_conv_output_shape, get_conv_output_shape,
...@@ -1456,7 +1456,7 @@ class TestSoftMax(test_nnet.TestSoftMax): ...@@ -1456,7 +1456,7 @@ class TestSoftMax(test_nnet.TestSoftMax):
def test_softmax_f16(self): def test_softmax_f16(self):
x = matrix("x", "float16") x = matrix("x", "float16")
x_gpu = tensor4("x_gpu", "float16") x_gpu = tensor4("x_gpu", "float16")
f_z = softmax_op f_z = softmax_legacy
f_gpu = dnn.GpuDnnSoftmax("accurate", "channel") f_gpu = dnn.GpuDnnSoftmax("accurate", "channel")
def cmp(n, m, f, f_gpu): def cmp(n, m, f, f_gpu):
...@@ -1480,7 +1480,7 @@ class TestSoftMax(test_nnet.TestSoftMax): ...@@ -1480,7 +1480,7 @@ class TestSoftMax(test_nnet.TestSoftMax):
x = matrix("x") x = matrix("x")
x_gpu = tensor4("x_gpu") x_gpu = tensor4("x_gpu")
f_z = softmax_op f_z = softmax_legacy
f_gpu = dnn.GpuDnnSoftmax("accurate", "channel") f_gpu = dnn.GpuDnnSoftmax("accurate", "channel")
# Verify the grad operation # Verify the grad operation
......
...@@ -210,7 +210,7 @@ def softmax_unittest_template(dtypeInput): ...@@ -210,7 +210,7 @@ def softmax_unittest_template(dtypeInput):
z = aesara.tensor.nnet.softmax(x) z = aesara.tensor.nnet.softmax(x)
f = aesara.function([x], z, mode=mode_without_gpu) f = aesara.function([x], z, mode=mode_without_gpu)
f_gpu = aesara.function([x], z, mode=mode_wo_cudnn) f_gpu = aesara.function([x], z, mode=mode_wo_cudnn)
assert f.maker.fgraph.toposort()[-1].op == aesara.tensor.nnet.softmax_op assert f.maker.fgraph.toposort()[-1].op == aesara.tensor.nnet.softmax_legacy
assert isinstance(f_gpu.maker.fgraph.toposort()[-2].op, GpuSoftmax) assert isinstance(f_gpu.maker.fgraph.toposort()[-2].op, GpuSoftmax)
def cmp(n, m): def cmp(n, m):
...@@ -300,7 +300,7 @@ class TestSoftMax: ...@@ -300,7 +300,7 @@ class TestSoftMax:
def test_softmax(self): def test_softmax(self):
x = fmatrix("x") x = fmatrix("x")
z = aesara.tensor.nnet.softmax_op z = aesara.tensor.nnet.softmax_legacy
f, f_gpu = self._test_softmax(x, x, z, z, self._cmp) f, f_gpu = self._test_softmax(x, x, z, z, self._cmp)
...@@ -308,7 +308,7 @@ class TestSoftMax: ...@@ -308,7 +308,7 @@ class TestSoftMax:
def test_softmax_shape_0(self): def test_softmax_shape_0(self):
x = fmatrix("x") x = fmatrix("x")
z = aesara.tensor.nnet.softmax_op z = aesara.tensor.nnet.softmax_legacy
f, f_gpu = self._test_softmax(x, x, z, z, self._cmp) f, f_gpu = self._test_softmax(x, x, z, z, self._cmp)
# Aesara can handle that case, but cudnn can't # Aesara can handle that case, but cudnn can't
......
...@@ -969,11 +969,16 @@ def test_nnet(): ...@@ -969,11 +969,16 @@ def test_nnet():
fgraph = FunctionGraph([x], [out]) fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = aet_nnet.softmax(x) out = aet_nnet.logsoftmax(x)
fgraph = FunctionGraph([x], [out]) fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = aet_nnet.logsoftmax(x)
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax(axis):
x = matrix("x")
x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = aet_nnet.softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out]) fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
......
...@@ -1894,20 +1894,27 @@ def test_Dot(x, y, exc): ...@@ -1894,20 +1894,27 @@ def test_Dot(x, y, exc):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, exc", "x, axis, exc",
[ [
( (
set_test_value(aet.vector(), rng.random(size=(2,)).astype(config.floatX)), set_test_value(aet.vector(), rng.random(size=(2,)).astype(config.floatX)),
None, None,
None,
),
(
set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
None,
None,
), ),
( (
set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), set_test_value(aet.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
0,
None, None,
), ),
], ],
) )
def test_Softmax(x, exc): def test_Softmax(x, axis, exc):
g = nnetb.Softmax()(x) g = nnetb.Softmax(axis=axis)(x)
g_fg = FunctionGraph(outputs=[g]) g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
......
from contextlib import ExitStack as does_not_raise
import numpy as np import numpy as np
import pytest import pytest
import scipy.special as sp
import aesara import aesara
import aesara.tensor as aet import aesara.tensor as aet
...@@ -49,13 +52,13 @@ from aesara.tensor.nnet.basic import ( ...@@ -49,13 +52,13 @@ from aesara.tensor.nnet.basic import (
selu, selu,
sigmoid_binary_crossentropy, sigmoid_binary_crossentropy,
softmax, softmax,
softmax_grad, softmax_grad_legacy,
softmax_graph, softmax_graph,
softmax_op, softmax_legacy,
softmax_with_bias, softmax_with_bias,
softsign, softsign,
) )
from aesara.tensor.shape import shape_padleft, specify_shape from aesara.tensor.shape import shape_padleft
from aesara.tensor.subtensor import AdvancedSubtensor from aesara.tensor.subtensor import AdvancedSubtensor
from aesara.tensor.type import ( from aesara.tensor.type import (
dmatrix, dmatrix,
...@@ -67,6 +70,7 @@ from aesara.tensor.type import ( ...@@ -67,6 +70,7 @@ from aesara.tensor.type import (
matrices, matrices,
matrix, matrix,
scalar, scalar,
tensor3,
tensor4, tensor4,
vector, vector,
vectors, vectors,
...@@ -80,46 +84,64 @@ from tests.tensor.utils import ( ...@@ -80,46 +84,64 @@ from tests.tensor.utils import (
) )
class TestSoftmax(utt.InferShapeTester): def valid_axis_tester(Op):
def test_basic(self): with pytest.raises(TypeError):
def f(a): Op(1.5)
return softmax_op(a)[:, 0]
utt.verify_grad(f, [np.random.random((3, 4))]) x = [tensor3()] * Op.nin
with does_not_raise():
Op(2)(*x)
def f(a): with pytest.raises(ValueError):
return softmax_op(a)[:, 1] Op(3)(*x)
utt.verify_grad(f, [np.random.random((3, 4))]) with does_not_raise():
Op(-3)(*x)
def f(a): with pytest.raises(ValueError):
return softmax_op(a)[:, 2] Op(-4)(*x)
utt.verify_grad(f, [np.random.random((3, 4))])
class TestSoftmax(utt.InferShapeTester):
@pytest.mark.parametrize("axis", [None, 0, 1, 2, 3, -1, -2])
def test_perform(self, axis):
x = tensor4("x")
xv = np.random.randn(2, 3, 4, 5).astype(config.floatX)
f = aesara.function([x], softmax(x, axis=axis))
assert np.allclose(f(xv), sp.softmax(xv, axis=axis))
@pytest.mark.parametrize("column", [0, 1, 2, 3])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_grad(self, axis, column):
def f(a): def f(a):
return softmax_op(a)[:, 3] return softmax(a, axis=axis)[:, column]
utt.verify_grad(f, [np.random.random((3, 4))]) utt.verify_grad(f, [np.random.random((3, 4, 2))])
def test_infer_shape(self): def test_infer_shape(self):
admat = matrix() admat = matrix()
admat_val = np.random.random((3, 4)).astype(config.floatX) admat_val = np.random.random((3, 4)).astype(config.floatX)
self._compile_and_check([admat], [Softmax()(admat)], [admat_val], Softmax) self._compile_and_check(
[admat], [Softmax(axis=-1)(admat)], [admat_val], Softmax
)
def test_vector(self): def test_vector_perform(self):
x = vector() x = vector()
f = aesara.function([x], softmax_op(x)) f = aesara.function([x], softmax(x, axis=None))
xv = np.random.randn(6).astype(config.floatX) xv = np.random.randn(6).astype(config.floatX)
assert np.allclose(f(xv), np.exp(xv) / np.exp(xv).sum()) assert np.allclose(f(xv), sp.softmax(xv))
def test_vector_grad(self): def test_vector_grad(self):
def f(a): def f(a):
return softmax_op(a) return softmax(a, axis=None)
utt.verify_grad(f, [np.random.random((4))]) utt.verify_grad(f, [np.random.random((4))])
def test_valid_axis(self):
valid_axis_tester(Softmax)
class TestSoftmaxWithBias(utt.InferShapeTester): class TestSoftmaxWithBias(utt.InferShapeTester):
def test_basic(self): def test_basic(self):
...@@ -154,10 +176,10 @@ class TestSoftmaxWithBias(utt.InferShapeTester): ...@@ -154,10 +176,10 @@ class TestSoftmaxWithBias(utt.InferShapeTester):
W = aesara.shared(value=initial_W, name="W") W = aesara.shared(value=initial_W, name="W")
vbias = aesara.shared(value=0.1, name="vbias") # 0.01 vbias = aesara.shared(value=0.1, name="vbias") # 0.01
hid = vector("hid") hid = vector("hid")
f = aesara.function([hid], softmax_op(dot(hid, W.T) + vbias)) f = aesara.function([hid], softmax_legacy(dot(hid, W.T) + vbias))
ops = [node.op for node in f.maker.fgraph.toposort()] ops = [node.op for node in f.maker.fgraph.toposort()]
assert softmax_with_bias not in ops assert softmax_with_bias not in ops
assert softmax_op in ops assert softmax_legacy in ops
f([0, 1, 0]) f([0, 1, 0])
# print f.maker.fgraph.toposort() # print f.maker.fgraph.toposort()
...@@ -315,17 +337,19 @@ class TestLogSoftmax(utt.InferShapeTester): ...@@ -315,17 +337,19 @@ class TestLogSoftmax(utt.InferShapeTester):
g = grad(y.sum(), x) g = grad(y.sum(), x)
softmax_grad_node = g.owner softmax_grad_node = g.owner
assert softmax_grad_node.op == softmax_grad assert softmax_grad_node.op == softmax_grad_legacy
true_div_node = softmax_grad_node.inputs[0].owner true_div_node = softmax_grad_node.inputs[0].owner
assert true_div_node.op == true_div assert true_div_node.op == true_div
# We replace the elemwise true_div op by an elemwise add. # We replace the elemwise true_div op by an elemwise add.
new_g = softmax_grad(add(*true_div_node.inputs), softmax_grad_node.inputs[1]) new_g = softmax_grad_legacy(
add(*true_div_node.inputs), softmax_grad_node.inputs[1]
)
fgraph = FunctionGraph([x], [new_g]) fgraph = FunctionGraph([x], [new_g])
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).optimize(fgraph)
assert softmax_grad in [n.op for n in fgraph.toposort()] assert softmax_grad_legacy in [n.op for n in fgraph.toposort()]
class TestSoftmaxGrad(utt.InferShapeTester): class TestSoftmaxGrad(utt.InferShapeTester):
...@@ -336,11 +360,14 @@ class TestSoftmaxGrad(utt.InferShapeTester): ...@@ -336,11 +360,14 @@ class TestSoftmaxGrad(utt.InferShapeTester):
bdmat_val = np.random.random((3, 4)).astype(config.floatX) bdmat_val = np.random.random((3, 4)).astype(config.floatX)
self._compile_and_check( self._compile_and_check(
[admat, bdmat], [admat, bdmat],
[SoftmaxGrad()(admat, bdmat)], [SoftmaxGrad(axis=-1)(admat, bdmat)],
[admat_val, bdmat_val], [admat_val, bdmat_val],
SoftmaxGrad, SoftmaxGrad,
) )
def test_valid_axis(self):
valid_axis_tester(SoftmaxGrad)
class TestCrossEntropySoftmax1Hot: class TestCrossEntropySoftmax1Hot:
def test_basic(self): def test_basic(self):
...@@ -611,17 +638,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -611,17 +638,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
op = crossentropy_categorical_1hot op = crossentropy_categorical_1hot
# xe = op(x, one_of_n) # xe = op(x, one_of_n)
fgraph = FunctionGraph([x, one_of_n], [op(softmax_op(x), one_of_n)]) fgraph = FunctionGraph([x, one_of_n], [op(softmax_legacy(x), one_of_n)])
assert fgraph.outputs[0].owner.op == op
optdb.query(OPT_FAST_RUN).optimize(fgraph)
assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
def test_softmax_optimizations_vector(self):
x = vector("x")
one_of_n = lvector("one_of_n")
op = crossentropy_categorical_1hot
fgraph = FunctionGraph([x, one_of_n], [op(softmax_op(x), one_of_n)])
assert fgraph.outputs[0].owner.op == op assert fgraph.outputs[0].owner.op == op
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).optimize(fgraph)
...@@ -633,7 +650,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -633,7 +650,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
one_of_n = lvector("one_of_n") one_of_n = lvector("one_of_n")
op = crossentropy_categorical_1hot op = crossentropy_categorical_1hot
fgraph = FunctionGraph([x, b, one_of_n], [op(softmax_op(x + b), one_of_n)]) fgraph = FunctionGraph([x, b, one_of_n], [op(softmax_legacy(x + b), one_of_n)])
assert fgraph.outputs[0].owner.op == op assert fgraph.outputs[0].owner.op == op
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).optimize(fgraph)
...@@ -649,7 +666,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -649,7 +666,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
op = crossentropy_categorical_1hot op = crossentropy_categorical_1hot
fgraph = FunctionGraph( fgraph = FunctionGraph(
[x, b, c, one_of_n], [op(softmax_op(add(x, b, c)), one_of_n)] [x, b, c, one_of_n], [op(softmax_legacy(add(x, b, c)), one_of_n)]
) )
assert fgraph.outputs[0].owner.op == op assert fgraph.outputs[0].owner.op == op
...@@ -658,29 +675,17 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -658,29 +675,17 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert len(fgraph.toposort()) == 2 assert len(fgraph.toposort()) == 2
assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
def test_softmax_optimizations_w_bias_vector(self):
x = vector("x")
b = vector("b")
one_of_n = lvector("one_of_n")
op = crossentropy_categorical_1hot
fgraph = FunctionGraph([x, b, one_of_n], [op(softmax_op(x + b), one_of_n)])
assert fgraph.outputs[0].owner.op == op
optdb.query(OPT_FAST_RUN).optimize(fgraph)
assert len(fgraph.toposort()) == 2
assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias
def test_softmax_grad_optimizations(self): def test_softmax_grad_optimizations(self):
x = matrix("x") x = matrix("x")
one_of_n = lvector("one_of_n") one_of_n = lvector("one_of_n")
op = crossentropy_categorical_1hot op = crossentropy_categorical_1hot
xe = op(softmax_op(x), one_of_n) xe = op(softmax_legacy(x), one_of_n)
sum_xe = aet_sum(xe) sum_xe = aet_sum(xe)
g_x = grad(sum_xe, x) g_x = grad(sum_xe, x)
fgraph = FunctionGraph([x, one_of_n], [g_x]) fgraph = FunctionGraph([x, one_of_n], [g_x])
assert check_stack_trace( assert check_stack_trace(
fgraph, ops_to_check=[crossentropy_softmax_1hot_with_bias_dx, softmax_op] fgraph,
ops_to_check=[crossentropy_softmax_1hot_with_bias_dx, softmax_legacy],
) )
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).optimize(fgraph)
...@@ -688,25 +693,8 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -688,25 +693,8 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
ops = {node.op for node in fgraph.toposort()} ops = {node.op for node in fgraph.toposort()}
assert crossentropy_softmax_argmax_1hot_with_bias not in ops assert crossentropy_softmax_argmax_1hot_with_bias not in ops
assert crossentropy_softmax_1hot_with_bias_dx in ops assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_op in ops assert softmax_legacy in ops
assert softmax_grad not in ops assert softmax_grad_legacy not in ops
def test_softmax_grad_optimizations_vector(self):
x = vector("x")
one_of_n = lvector("one_of_n")
op = crossentropy_categorical_1hot
xe = op(softmax_op(x), one_of_n)
sum_xe = aet_sum(xe)
g_x = grad(sum_xe, x)
fgraph = FunctionGraph([x, one_of_n], [g_x])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
ops = {node.op for node in fgraph.toposort()}
assert crossentropy_softmax_argmax_1hot_with_bias not in ops
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_op in ops
assert softmax_grad not in ops
def test_get_rid_of_advanced_indexing_version_of_xent(self): def test_get_rid_of_advanced_indexing_version_of_xent(self):
x = matrix("x") x = matrix("x")
...@@ -737,8 +725,8 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -737,8 +725,8 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 2 assert len(ops) == 2
assert crossentropy_softmax_1hot_with_bias_dx in ops assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_op in ops assert softmax_legacy in ops
assert softmax_grad not in ops assert softmax_grad_legacy not in ops
# Test that a biased softmax is optimized correctly # Test that a biased softmax is optimized correctly
bias_expressions = [ bias_expressions = [
...@@ -763,7 +751,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -763,7 +751,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert len(ops) == 2 assert len(ops) == 2
assert crossentropy_softmax_1hot_with_bias_dx in ops assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_with_bias in ops assert softmax_with_bias in ops
assert softmax_grad not in ops assert softmax_grad_legacy not in ops
# Test that using "mean" instead of sum works, too # Test that using "mean" instead of sum works, too
mean_expressions = [ mean_expressions = [
...@@ -791,8 +779,8 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -791,8 +779,8 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
# there's an extra dimshuffle in there # there's an extra dimshuffle in there
# but I can't think of a good rule to get rid of it # but I can't think of a good rule to get rid of it
assert crossentropy_softmax_1hot_with_bias_dx in ops assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_op in ops assert softmax_legacy in ops
assert softmax_grad not in ops assert softmax_grad_legacy not in ops
mean_bias_expressions = [ mean_bias_expressions = [
mean(-log(softmax(x + b)[aet.arange(y.shape[0]), y])), mean(-log(softmax(x + b)[aet.arange(y.shape[0]), y])),
...@@ -818,7 +806,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -818,7 +806,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert len(ops) == 5 assert len(ops) == 5
assert crossentropy_softmax_1hot_with_bias_dx in ops assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_with_bias in ops assert softmax_with_bias in ops
assert softmax_grad not in ops assert softmax_grad_legacy not in ops
def test_xent_thing_int32(self): def test_xent_thing_int32(self):
x = matrix("x") x = matrix("x")
...@@ -847,141 +835,8 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -847,141 +835,8 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 3 assert len(ops) == 3
assert crossentropy_softmax_1hot_with_bias_dx in ops assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_op in ops assert softmax_legacy in ops
assert softmax_grad not in ops assert softmax_grad_legacy not in ops
def test_optimize_xent_vector(self):
x = vector("x")
y = lvector("y")
# Test that a biased softmax is optimized correctly
bias_expressions = [
aet_sum(-log(softmax(x)[aet.arange(y.shape[0]), y])),
-aet_sum(log(softmax(x)[aet.arange(y.shape[0]), y])),
]
for expr in bias_expressions:
fgraph = FunctionGraph([x, y], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 5
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]
fgraph = FunctionGraph([x, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) == 4
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_op in ops
assert softmax_grad not in ops
def test_optimize_xent_vector2(self):
x = vector("x")
b = vector("b")
y = lvector("y")
# Test that a biased softmax is optimized correctly
bias_expressions = [
aet_sum(-log(softmax(x + b)[aet.arange(y.shape[0]), y])),
-aet_sum(log(softmax(b + x)[aet.arange(y.shape[0]), y])),
-aet_sum(log(softmax(x + b))[aet.arange(y.shape[0]), y]),
aet_sum(-log(softmax(b + x))[aet.arange(y.shape[0]), y]),
]
for expr in bias_expressions:
fgraph = FunctionGraph([x, b, y], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
ops = [node.op for node in fgraph.toposort()]
# [big_op, sum, dim_shuffle]
assert len(ops) == 3
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]
fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) <= 6
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_with_bias in ops
assert softmax_grad not in ops
def test_optimize_xent_vector3(self):
# Same as test_optimize_xent_vector2, but y is the result of
# a "flatten", and it used to make the constant-folding
# of arange(y.shape[0]) happen before the xent optimization
x = vector("x")
b = vector("b")
y_ = lvector("y_")
y = y_.flatten()
# Test that a biased softmax is optimized correctly
bias_expressions = [
aet_sum(-log(softmax(x + b)[aet.arange(y.shape[0]), y])),
-aet_sum(log(softmax(b + x)[aet.arange(y.shape[0]), y])),
-aet_sum(log(softmax(x + b))[aet.arange(y.shape[0]), y]),
aet_sum(-log(softmax(b + x))[aet.arange(y.shape[0]), y]),
]
for expr in bias_expressions:
fgraph = FunctionGraph([x, b, y_], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
ops = [node.op for node in fgraph.toposort()]
# [big_op, sum, dim_shuffle, flatten]
assert len(ops) <= 4
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]
fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) <= 6
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_with_bias in ops
assert softmax_grad not in ops
def test_optimize_xent_vector4(self):
# Same as test_optimize_xent_vector2, but y is the result of a
# "specify_shape" that indicates its length is 1, so the
# constant-folding of arange(y.shape[0]) happen before the xent
# optimization
x = vector("x")
b = vector("b")
y_ = lvector("y_")
y = specify_shape(y_, (1,))
# Test that a biased softmax is optimized correctly
bias_expressions = [
aet_sum(-log(softmax(x + b)[aet.arange(y.shape[0]), y])),
-aet_sum(log(softmax(b + x)[aet.arange(y.shape[0]), y])),
-aet_sum(log(softmax(x + b))[aet.arange(y.shape[0]), y]),
aet_sum(-log(softmax(b + x))[aet.arange(y.shape[0]), y]),
]
for expr in bias_expressions:
fgraph = FunctionGraph([x, b, y_], [expr])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
ops = [node.op for node in fgraph.toposort()]
# [big_op, sum, dim_shuffle, specify_shape]
assert len(ops) <= 4
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]
fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) <= 6
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_with_bias in ops
assert softmax_grad not in ops
def test_crossentropy_softmax_1hot_with_bias_dxcale_cost(self): def test_crossentropy_softmax_1hot_with_bias_dxcale_cost(self):
x = matrix("x") x = matrix("x")
...@@ -996,9 +851,9 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -996,9 +851,9 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
for node in func.maker.fgraph.toposort(): for node in func.maker.fgraph.toposort():
if node.op == crossentropy_softmax_1hot_with_bias_dx: if node.op == crossentropy_softmax_1hot_with_bias_dx:
has_cx1hotdx = True has_cx1hotdx = True
if node.op == softmax_op: if node.op == softmax_legacy:
has_softmax = True has_softmax = True
if node.op == softmax_grad: if node.op == softmax_grad_legacy:
has_softmaxdx = True has_softmaxdx = True
assert has_cx1hotdx assert has_cx1hotdx
...@@ -1033,7 +888,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -1033,7 +888,7 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
ops = {node.op for node in fgraph.toposort()} ops = {node.op for node in fgraph.toposort()}
assert crossentropy_softmax_argmax_1hot_with_bias in ops assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert softmax_op not in ops assert softmax_legacy not in ops
# Verify the gradient wrt x # Verify the gradient wrt x
fgraph = FunctionGraph([x, y, a], [grad(expr, x)]) fgraph = FunctionGraph([x, y, a], [grad(expr, x)])
...@@ -1043,8 +898,8 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -1043,8 +898,8 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
ops = {node.op for node in fgraph.toposort()} ops = {node.op for node in fgraph.toposort()}
assert crossentropy_softmax_1hot_with_bias_dx in ops assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_op in ops assert softmax_legacy in ops
assert softmax_grad not in ops assert softmax_grad_legacy not in ops
# Verify the gradient when providing output gradient # Verify the gradient when providing output gradient
fgraph = FunctionGraph( fgraph = FunctionGraph(
...@@ -1056,13 +911,13 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -1056,13 +911,13 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
ops = {node.op for node in fgraph.toposort()} ops = {node.op for node in fgraph.toposort()}
assert crossentropy_softmax_1hot_with_bias_dx in ops assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_op in ops assert softmax_legacy in ops
assert softmax_grad not in ops assert softmax_grad_legacy not in ops
def test_argmax_pushdown(): def test_argmax_pushdown():
x = matrix() x = matrix()
for sm in [softmax_graph, softmax_op]: for sm in [softmax_graph, softmax_legacy]:
# test that the max_and_argmax is pushed down if the max is not used # test that the max_and_argmax is pushed down if the max is not used
out = max_and_argmax(sm(exp(tanh(sigmoid(x)))), axis=-1)[1] out = max_and_argmax(sm(exp(tanh(sigmoid(x)))), axis=-1)[1]
fgraph = FunctionGraph([x], [out]) fgraph = FunctionGraph([x], [out])
...@@ -1188,12 +1043,12 @@ class TestSoftmaxOpt: ...@@ -1188,12 +1043,12 @@ class TestSoftmaxOpt:
# test that function contains softmax and no div. # test that function contains softmax and no div.
f = aesara.function([c], p_y, mode=self.mode) f = aesara.function([c], p_y, mode=self.mode)
assert check_stack_trace(f, ops_to_check=softmax_op) assert check_stack_trace(f, ops_to_check=softmax_legacy)
f_ops = [n.op for n in f.maker.fgraph.toposort()] f_ops = [n.op for n in f.maker.fgraph.toposort()]
assert len(f_ops) == 1 assert len(f_ops) == 1
assert softmax_op in f_ops assert softmax_legacy in f_ops
f(self.rng.random((3, 4)).astype(config.floatX)) f(self.rng.random((3, 4)).astype(config.floatX))
...@@ -1204,12 +1059,12 @@ class TestSoftmaxOpt: ...@@ -1204,12 +1059,12 @@ class TestSoftmaxOpt:
# test that function contains softmax and no div. # test that function contains softmax and no div.
f = aesara.function([c], p_y, mode=self.mode) f = aesara.function([c], p_y, mode=self.mode)
assert check_stack_trace(f, ops_to_check=softmax_op) assert check_stack_trace(f, ops_to_check=softmax_legacy)
f_ops = [n.op for n in f.maker.fgraph.toposort()] f_ops = [n.op for n in f.maker.fgraph.toposort()]
assert len(f_ops) == 1 assert len(f_ops) == 1
assert softmax_op in f_ops assert softmax_legacy in f_ops
f(self.rng.random((3, 4)).astype(config.floatX)) f(self.rng.random((3, 4)).astype(config.floatX))
...@@ -1226,8 +1081,8 @@ class TestSoftmaxOpt: ...@@ -1226,8 +1081,8 @@ class TestSoftmaxOpt:
g_ops = [n.op for n in g.maker.fgraph.toposort()] g_ops = [n.op for n in g.maker.fgraph.toposort()]
assert len(g_ops) == 2 assert len(g_ops) == 2
assert softmax_op in g_ops assert softmax_legacy in g_ops
assert softmax_grad in g_ops assert softmax_grad_legacy in g_ops
g(self.rng.random((3, 4)), self.rng.uniform(0.5, 1, (3, 4))) g(self.rng.random((3, 4)), self.rng.uniform(0.5, 1, (3, 4)))
...@@ -1272,7 +1127,7 @@ def test_grad_softmax_grad(): ...@@ -1272,7 +1127,7 @@ def test_grad_softmax_grad():
x = aesara.shared(rng.normal(size=(3, 4))) x = aesara.shared(rng.normal(size=(3, 4)))
def f(inputs): def f(inputs):
y = softmax_op(x) y = softmax_legacy(x)
return aesara.grad(None, x, known_grads={y: inputs}) return aesara.grad(None, x, known_grads={y: inputs})
utt.verify_grad(f, [rng.random((3, 4))]) utt.verify_grad(f, [rng.random((3, 4))])
......
...@@ -384,8 +384,7 @@ class TestRopLop(RopLopChecker): ...@@ -384,8 +384,7 @@ class TestRopLop(RopLopChecker):
self.check_mat_rop_lop(self.mx.sum(axis=1), (self.mat_in_shape[0],)) self.check_mat_rop_lop(self.mx.sum(axis=1), (self.mat_in_shape[0],))
def test_softmax(self): def test_softmax(self):
# Softmax adds an extra dimnesion ! self.check_rop_lop(aesara.tensor.nnet.softmax(self.x), self.in_shape)
self.check_rop_lop(aesara.tensor.nnet.softmax(self.x)[0], self.in_shape[0])
def test_alloc(self): def test_alloc(self):
# Alloc of the sum of x into a vector # Alloc of the sum of x into a vector
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论