提交 76609200 authored 作者: Harm de Vries's avatar Harm de Vries

added test for grad/shape

上级 bb1827ef
...@@ -1299,6 +1299,8 @@ class MaxAndArgmax(Op): ...@@ -1299,6 +1299,8 @@ class MaxAndArgmax(Op):
axis = None axis = None
elif isinstance(axis, (int, numpy.integer)): elif isinstance(axis, (int, numpy.integer)):
axis = [int(axis)] axis = [int(axis)]
elif isinstance(axis, numpy.ndarray) and axis.ndim == 0:
axis = [int(axis)]
elif isinstance(axis, Variable): elif isinstance(axis, Variable):
if NoneConst.equals(axis): if NoneConst.equals(axis):
axis = None axis = None
...@@ -1322,23 +1324,25 @@ class MaxAndArgmax(Op): ...@@ -1322,23 +1324,25 @@ class MaxAndArgmax(Op):
axis.sort() axis.sort()
# Verify that axes are valid # Verify that axes are valid
all_axes = set() all_axes = []
if isinstance(axis, list): if isinstance(axis, list):
for ax in axis: for ax in axis:
if ax < 0 or ax >= x.type.ndim: if ax < 0 or ax >= x.type.ndim:
raise ValueError( raise ValueError(
'Invalid axis: %s (the number of dimensions of the ' 'Invalid axis: %s (the number of dimensions of the '
'input is: %s)' % (ax, x.type.ndim)) 'input is: %s)' % (ax, x.type.ndim))
all_axes.add(ax) if ax not in all_axes:
all_axes.append(ax)
else: else:
all_axes = list(range(x.ndim)) all_axes = list(range(x.ndim))
if axis is None or axis == list(range(x.type.ndim)): if axis is None or axis == list(range(x.type.ndim)):
axis = NoneConst.clone() axis = NoneConst.clone()
else: else:
axis = _as_tensor_variable(axis) axis = _as_tensor_variable(all_axes)
#assert axis.ndim == 0 assert axis.ndim == 1
inputs = [x, axis] inputs = [x, axis]
# We keep the original broadcastable flags for dimensions on which # We keep the original broadcastable flags for dimensions on which
# we do not perform the max / argmax. # we do not perform the max / argmax.
broadcastable = [b for i, b in enumerate(x.type.broadcastable) broadcastable = [b for i, b in enumerate(x.type.broadcastable)
...@@ -1350,77 +1354,82 @@ class MaxAndArgmax(Op): ...@@ -1350,77 +1354,82 @@ class MaxAndArgmax(Op):
def perform(self, node, inp, outs): def perform(self, node, inp, outs):
x, axes = inp x, axes = inp
max, max_idx = outs max, max_idx = outs
max[0] = theano._asarray(numpy.max(x, tuple(axes)), if axes is None:
axes = tuple(range(x.ndim))
else:
axes = tuple(axes)
max[0] = theano._asarray(numpy.max(x, axes),
dtype=node.outputs[0].dtype) dtype=node.outputs[0].dtype)
# Numpy does not support multiple axes for argmax # Numpy does not support multiple axes for argmax
# Work around, # Work around
keep_axes = numpy.array([i for i in range(x.ndim) if i not in axes]) keep_axes = numpy.array([i for i in range(x.ndim) if i not in axes])
# Not reduced axes in front # Not-reduced axes in front
transposed_x = numpy.transpose(x, numpy.concatenate((keep_axes, axes))) transposed_x = numpy.transpose(x, numpy.concatenate((keep_axes, axes)))
reshaped_x = x.reshape(transposed_x.shape[:len(keep_axes)] + (-1,)) reshaped_x = transposed_x.reshape(transposed_x.shape[:len(keep_axes)] + (-1,))
max_idx[0] = theano._asarray(numpy.argmax(reshaped_x, -1), dtype='int64')
def c_code(self, node, name, inp, out, sub): max_idx[0] = theano._asarray(numpy.argmax(reshaped_x, axis=-1), dtype='int64')
x, axis = inp
max, argmax = out #def c_code(self, node, name, inp, out, sub):
fail = sub["fail"] #x, axis = inp
if NoneConst.equals(node.inputs[1]): #max, argmax = out
axis_code = "axis = NPY_MAXDIMS;" #fail = sub["fail"]
else: #if NoneConst.equals(node.inputs[1]):
assert node.inputs[1].ndim == 1 #axis_code = "axis = NPY_MAXDIMS;"
# Fall back to perform() if there are multiple axes #else:
if len(node.inputs[1].data) > 1: raise NotImplementedError() #assert node.inputs[1].ndim == 1
axis_code = """ ## Fall back to perform() if there are multiple axes
axis = ((dtype_%(axis)s*)PyArray_DATA(%(axis)s))[0]; #if len(node.inputs[1].data) > 1: raise NotImplementedError()
if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){ #axis_code = """
PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, bad axis argument"); #axis = ((dtype_%(axis)s*)PyArray_DATA(%(axis)s))[0];
%(fail)s #if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){
} #PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, bad axis argument");
""" % locals() #%(fail)s
ret = """ #}
int axis; #""" % locals()
#ret = """
Py_CLEAR(%(max)s); #int axis;
Py_CLEAR(%(argmax)s);//todo pass them as out parameter.
%(axis_code)s #Py_CLEAR(%(max)s);
%(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL); #Py_CLEAR(%(argmax)s);//todo pass them as out parameter.
if(%(max)s == NULL){ #%(axis_code)s
PyErr_SetString(PyExc_ValueError, #%(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL);
"MaxAndArgmax, max failed"); #if(%(max)s == NULL){
%(fail)s; #PyErr_SetString(PyExc_ValueError,
} #"MaxAndArgmax, max failed");
if(!PyArray_CheckExact(%(max)s)){ #%(fail)s;
%(max)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(max)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL); #}
if(%(max)s == NULL){ #if(!PyArray_CheckExact(%(max)s)){
%(fail)s; #%(max)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(max)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
} #if(%(max)s == NULL){
} #%(fail)s;
#}
%(argmax)s = (PyArrayObject*)PyArray_ArgMax(%(x)s, axis, NULL); #}
if(%(argmax)s == NULL){
PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, argmax failed"); #%(argmax)s = (PyArrayObject*)PyArray_ArgMax(%(x)s, axis, NULL);
Py_CLEAR(%(max)s); #if(%(argmax)s == NULL){
%(fail)s; #PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, argmax failed");
} #Py_CLEAR(%(max)s);
if(!PyArray_CheckExact(%(argmax)s)){ #%(fail)s;
%(argmax)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(argmax)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL); #}
if(%(argmax)s == NULL){ #if(!PyArray_CheckExact(%(argmax)s)){
%(fail)s; #%(argmax)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(argmax)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
} #if(%(argmax)s == NULL){
} #%(fail)s;
if(PyArray_TYPE(%(argmax)s) != NPY_INT64){ #}
PyObject * tmp = PyArray_Cast(%(argmax)s, NPY_INT64); #}
if (NULL == tmp){ #if(PyArray_TYPE(%(argmax)s) != NPY_INT64){
%(fail)s; #PyObject * tmp = PyArray_Cast(%(argmax)s, NPY_INT64);
} #if (NULL == tmp){
Py_DECREF(%(argmax)s); #%(fail)s;
%(argmax)s = (PyArrayObject*)tmp; #}
} #Py_DECREF(%(argmax)s);
""" #%(argmax)s = (PyArrayObject*)tmp;
return ret % locals() #}
#"""
def c_code_cache_version(self): #return ret % locals()
return (3,)
#def c_code_cache_version(self):
#return (3,)
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
ishape, axis_shape = shapes ishape, axis_shape = shapes
...@@ -1489,7 +1498,6 @@ class MaxAndArgmax(Op): ...@@ -1489,7 +1498,6 @@ class MaxAndArgmax(Op):
else: else:
axis_ = axis axis_ = axis
xmax = max(x, axis_) xmax = max(x, axis_)
# Raise the g_max and xmax to the same number of dim as the input. # Raise the g_max and xmax to the same number of dim as the input.
pattern = [] pattern = []
...@@ -1638,7 +1646,6 @@ def argmax(x, axis=None, keepdims=False): ...@@ -1638,7 +1646,6 @@ def argmax(x, axis=None, keepdims=False):
# In python (using MaxAndArgmax.perform()) this leads to a wasteful # In python (using MaxAndArgmax.perform()) this leads to a wasteful
# implementation that goes through the data twice instead of once # implementation that goes through the data twice instead of once
# but when Argmax.c_impl() is in place, it should be fine. # but when Argmax.c_impl() is in place, it should be fine.
argout = max_and_argmax(x, axis)[1] argout = max_and_argmax(x, axis)[1]
if keepdims: if keepdims:
......
...@@ -2951,10 +2951,15 @@ class T_max_and_argmax(unittest.TestCase): ...@@ -2951,10 +2951,15 @@ class T_max_and_argmax(unittest.TestCase):
# Test 4d inner dimensions # Test 4d inner dimensions
data = rand(2, 3, 4, 5) data = rand(2, 3, 4, 5)
for i in [0, 1, 2, 3]: for i in [0, 1, 2, 3]:
safe_verify_grad(lambda v: max_and_argmax(v, axis=[i])[0], [data]) safe_verify_grad(lambda v: max_and_argmax(v, axis=[i])[0], [data])
safe_verify_grad(lambda v: max_and_argmax(v, axis=[i])[1], [data]) safe_verify_grad(lambda v: max_and_argmax(v, axis=[i])[1], [data])
# Test grad with multiple axes
for i in [[0, 1], [0, 0]]:
safe_verify_grad(lambda v: max_and_argmax(v, axis=i)[0], [data])
safe_verify_grad(lambda v: max_and_argmax(v, axis=i)[1], [data])
def test_preserve_broadcastable(self): def test_preserve_broadcastable(self):
""" """
...@@ -2965,11 +2970,15 @@ class T_max_and_argmax(unittest.TestCase): ...@@ -2965,11 +2970,15 @@ class T_max_and_argmax(unittest.TestCase):
assert y.type.broadcastable == (True, True, False, True) assert y.type.broadcastable == (True, True, False, True)
def test_multiple_axes(self): def test_multiple_axes(self):
data = as_tensor_variable(numpy.arange(24).reshape(3, 2, 4)) data = numpy.arange(24).reshape(3, 2, 4)
v, i = eval_outputs(max_and_argmax(data, [1, -1])) x = as_tensor_variable(data)
v, i = eval_outputs(max_and_argmax(x, [1, -1]))
assert numpy.all(v == numpy.array([7, 15, 23])) assert numpy.all(v == numpy.array([7, 15, 23]))
assert numpy.all(i == numpy.array([7, 7, 7])) assert numpy.all(i == numpy.array([7, 7, 7]))
v = eval_outputs(max_and_argmax(x, [1, -1])[0].shape)
assert tuple(v) == numpy.max(data, (1, -1)).shape
class T_argmin_argmax(unittest.TestCase): class T_argmin_argmax(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论