提交 76609200,作者:Harm de Vries

added test for grad/shape

上级 bb1827ef
......@@ -1299,6 +1299,8 @@ class MaxAndArgmax(Op):
axis = None
elif isinstance(axis, (int, numpy.integer)):
axis = [int(axis)]
elif isinstance(axis, numpy.ndarray) and axis.ndim == 0:
axis = [int(axis)]
elif isinstance(axis, Variable):
if NoneConst.equals(axis):
axis = None
......@@ -1322,23 +1324,25 @@ class MaxAndArgmax(Op):
axis.sort()
# Verify that axes are valid
all_axes = set()
all_axes = []
if isinstance(axis, list):
for ax in axis:
if ax < 0 or ax >= x.type.ndim:
raise ValueError(
'Invalid axis: %s (the number of dimensions of the '
'input is: %s)' % (ax, x.type.ndim))
all_axes.add(ax)
if ax not in all_axes:
all_axes.append(ax)
else:
all_axes = list(range(x.ndim))
if axis is None or axis == list(range(x.type.ndim)):
axis = NoneConst.clone()
else:
axis = _as_tensor_variable(axis)
#assert axis.ndim == 0
axis = _as_tensor_variable(all_axes)
assert axis.ndim == 1
inputs = [x, axis]
# We keep the original broadcastable flags for dimensions on which
# we do not perform the max / argmax.
broadcastable = [b for i, b in enumerate(x.type.broadcastable)
......@@ -1350,77 +1354,82 @@ class MaxAndArgmax(Op):
def perform(self, node, inp, outs):
# Python implementation of the op: stores the maximum of `x` over `axes` in
# `max[0]` and the corresponding argmax, as int64, in `max_idx[0]`.
# NOTE(review): this span looks like a rendered diff in which removed and
# added lines are interleaved (several statements appear twice and the
# indentation is gone) -- confirm which variant is current before reuse.
x, axes = inp
max, max_idx = outs
max[0] = theano._asarray(numpy.max(x, tuple(axes)),
# A None `axes` is expanded to every dimension of `x`; anything else is
# converted to a tuple so it can be handed to numpy.max.
if axes is None:
axes = tuple(range(x.ndim))
else:
axes = tuple(axes)
max[0] = theano._asarray(numpy.max(x, axes),
dtype=node.outputs[0].dtype)
# Numpy does not support multiple axes for argmax
# Work around,
# Work around
# NOTE(review): when every axis is reduced the list below is empty, and
# numpy.array([]) defaults to float64, so the concatenated permutation would
# not be an integer array -- verify numpy.transpose accepts it, or pass an
# explicit integer dtype.
keep_axes = numpy.array([i for i in range(x.ndim) if i not in axes])
# Not reduced axes in front
# Not-reduced axes in front
transposed_x = numpy.transpose(x, numpy.concatenate((keep_axes, axes)))
# This variant reshapes `x` itself, so the transpose above is not applied to
# the data that argmax sees.
reshaped_x = x.reshape(transposed_x.shape[:len(keep_axes)] + (-1,))
max_idx[0] = theano._asarray(numpy.argmax(reshaped_x, -1), dtype='int64')
# This variant reshapes the transposed array: the kept axes stay in front and
# all reduced axes are flattened into one trailing axis.
reshaped_x = transposed_x.reshape(transposed_x.shape[:len(keep_axes)] + (-1,))
def c_code(self, node, name, inp, out, sub):
# Builds and returns the C source for this op as a string. The `%(...)s`
# placeholders in both templates are filled from local variables through
# `% locals()`, so the local names below are part of the template contract.
x, axis = inp
max, argmax = out
fail = sub["fail"]
# A None axis input is translated to NPY_MAXDIMS in the generated C
# (presumably numpy's "reduce over everything" sentinel -- confirm).
if NoneConst.equals(node.inputs[1]):
axis_code = "axis = NPY_MAXDIMS;"
else:
assert node.inputs[1].ndim == 1
# Fall back to perform() if there are multiple axes
# NOTE(review): `.data` assumes node.inputs[1] is a constant; confirm that
# make_node guarantees this.
if len(node.inputs[1].data) > 1: raise NotImplementedError()
# Only the first axis entry is read; the generated check rejects values
# outside [-ndim, ndim - 1].
axis_code = """
axis = ((dtype_%(axis)s*)PyArray_DATA(%(axis)s))[0];
if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){
PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, bad axis argument");
%(fail)s
}
""" % locals()
# Main template: clears both outputs, computes max and argmax along `axis`,
# converts either result with PyArray_FromAny when it is not an exact
# ndarray, and casts argmax to int64 when needed.
# NOTE(review): the PyArray_FromAny results overwrite the previous pointers
# without releasing them, and the cast-failure path leaves both outputs
# set -- check for reference leaks.
ret = """
int axis;
Py_CLEAR(%(max)s);
Py_CLEAR(%(argmax)s);//todo pass them as out parameter.
%(axis_code)s
%(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL);
if(%(max)s == NULL){
PyErr_SetString(PyExc_ValueError,
"MaxAndArgmax, max failed");
%(fail)s;
}
if(!PyArray_CheckExact(%(max)s)){
%(max)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(max)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
if(%(max)s == NULL){
%(fail)s;
}
}
%(argmax)s = (PyArrayObject*)PyArray_ArgMax(%(x)s, axis, NULL);
if(%(argmax)s == NULL){
PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, argmax failed");
Py_CLEAR(%(max)s);
%(fail)s;
}
if(!PyArray_CheckExact(%(argmax)s)){
%(argmax)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(argmax)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
if(%(argmax)s == NULL){
%(fail)s;
}
}
if(PyArray_TYPE(%(argmax)s) != NPY_INT64){
PyObject * tmp = PyArray_Cast(%(argmax)s, NPY_INT64);
if (NULL == tmp){
%(fail)s;
}
Py_DECREF(%(argmax)s);
%(argmax)s = (PyArrayObject*)tmp;
}
"""
return ret % locals()
def c_code_cache_version(self):
# Version tuple for the generated C code; presumably it has to change
# whenever the text returned by c_code() changes -- confirm against the
# compilation cache's rules.
return (3,)
max_idx[0] = theano._asarray(numpy.argmax(reshaped_x, axis=-1), dtype='int64')
#def c_code(self, node, name, inp, out, sub):
#x, axis = inp
#max, argmax = out
#fail = sub["fail"]
#if NoneConst.equals(node.inputs[1]):
#axis_code = "axis = NPY_MAXDIMS;"
#else:
#assert node.inputs[1].ndim == 1
## Fall back to perform() if there are multiple axes
#if len(node.inputs[1].data) > 1: raise NotImplementedError()
#axis_code = """
#axis = ((dtype_%(axis)s*)PyArray_DATA(%(axis)s))[0];
#if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){
#PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, bad axis argument");
#%(fail)s
#}
#""" % locals()
#ret = """
#int axis;
#Py_CLEAR(%(max)s);
#Py_CLEAR(%(argmax)s);//todo pass them as out parameter.
#%(axis_code)s
#%(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL);
#if(%(max)s == NULL){
#PyErr_SetString(PyExc_ValueError,
#"MaxAndArgmax, max failed");
#%(fail)s;
#}
#if(!PyArray_CheckExact(%(max)s)){
#%(max)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(max)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
#if(%(max)s == NULL){
#%(fail)s;
#}
#}
#%(argmax)s = (PyArrayObject*)PyArray_ArgMax(%(x)s, axis, NULL);
#if(%(argmax)s == NULL){
#PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, argmax failed");
#Py_CLEAR(%(max)s);
#%(fail)s;
#}
#if(!PyArray_CheckExact(%(argmax)s)){
#%(argmax)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(argmax)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
#if(%(argmax)s == NULL){
#%(fail)s;
#}
#}
#if(PyArray_TYPE(%(argmax)s) != NPY_INT64){
#PyObject * tmp = PyArray_Cast(%(argmax)s, NPY_INT64);
#if (NULL == tmp){
#%(fail)s;
#}
#Py_DECREF(%(argmax)s);
#%(argmax)s = (PyArrayObject*)tmp;
#}
#"""
#return ret % locals()
#def c_code_cache_version(self):
#return (3,)
def infer_shape(self, node, shapes):
ishape, axis_shape = shapes
......@@ -1489,7 +1498,6 @@ class MaxAndArgmax(Op):
else:
axis_ = axis
xmax = max(x, axis_)
# Raise the g_max and xmax to the same number of dim as the input.
pattern = []
......@@ -1638,7 +1646,6 @@ def argmax(x, axis=None, keepdims=False):
# In python (using MaxAndArgmax.perform()) this leads to a wasteful
# implementation that goes through the data twice instead of once
# but when Argmax.c_impl() is in place, it should be fine.
argout = max_and_argmax(x, axis)[1]
if keepdims:
......
......@@ -2951,10 +2951,15 @@ class T_max_and_argmax(unittest.TestCase):
# Test 4d inner dimensions
data = rand(2, 3, 4, 5)
for i in [0, 1, 2, 3]:
safe_verify_grad(lambda v: max_and_argmax(v, axis=[i])[0], [data])
safe_verify_grad(lambda v: max_and_argmax(v, axis=[i])[1], [data])
# Test grad with multiple axes
for i in [[0, 1], [0, 0]]:
safe_verify_grad(lambda v: max_and_argmax(v, axis=i)[0], [data])
safe_verify_grad(lambda v: max_and_argmax(v, axis=i)[1], [data])
def test_preserve_broadcastable(self):
"""
......@@ -2965,11 +2970,15 @@ class T_max_and_argmax(unittest.TestCase):
assert y.type.broadcastable == (True, True, False, True)
def test_multiple_axes(self):
# Checks max_and_argmax over axes [1, -1] of a (3, 2, 4) arange tensor, and
# that the symbolic shape of the max output matches numpy's.
# NOTE(review): the setup appears twice (a symbolic `data` first, then a
# numpy `data` wrapped as `x`); this looks like removed and added diff lines
# shown together -- confirm which pair is current.
data = as_tensor_variable(numpy.arange(24).reshape(3, 2, 4))
v, i = eval_outputs(max_and_argmax(data, [1, -1]))
data = numpy.arange(24).reshape(3, 2, 4)
x = as_tensor_variable(data)
v, i = eval_outputs(max_and_argmax(x, [1, -1]))
# Each data[k] holds 8 consecutive integers, so its maximum is the last one
# (8k + 7) and sits at flat index 7 of the two reduced axes.
assert numpy.all(v == numpy.array([7, 15, 23]))
assert numpy.all(i == numpy.array([7, 7, 7]))
# The evaluated symbolic shape of the max output must equal the shape numpy
# produces for the same reduction.
v = eval_outputs(max_and_argmax(x, [1, -1])[0].shape)
assert tuple(v) == numpy.max(data, (1, -1)).shape
class T_argmin_argmax(unittest.TestCase):
def setUp(self):
......
Markdown 格式
0%
您已添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论