提交 4b2c6694 authored 作者: Frederic's avatar Frederic

Use scalar for axis or None in MaxAndArgmax.

This removes a deprecation warning in NumPy, as an ndarray will no longer be accepted as an int.
上级 932555b5
...@@ -19,6 +19,7 @@ from theano.tensor.var import (AsTensorError, TensorVariable, ...@@ -19,6 +19,7 @@ from theano.tensor.var import (AsTensorError, TensorVariable,
TensorConstant, TensorConstant,
_tensor_py_operators) _tensor_py_operators)
from theano.tensor.type import TensorType from theano.tensor.type import TensorType
from theano.tensor.type_other import NoneConst
from theano import scalar as scal from theano import scalar as scal
from theano.gof.python25 import partial, any, all from theano.gof.python25 import partial, any, all
from theano.gof.utils import hashtype from theano.gof.utils import hashtype
...@@ -1359,11 +1360,7 @@ class MaxAndArgmax(Op): ...@@ -1359,11 +1360,7 @@ class MaxAndArgmax(Op):
def make_node(self, x, axis=None): def make_node(self, x, axis=None):
x = _as_tensor_variable(x) x = _as_tensor_variable(x)
if isinstance(axis, (int, numpy.integer)): if isinstance(axis, (tuple, list)):
axis = [axis]
elif isinstance(axis, numpy.ndarray) and axis.ndim == 0:
axis = [int(axis)]
elif isinstance(axis, (tuple, list)):
axis = [int(a) for a in axis] axis = [int(a) for a in axis]
if len(axis) != 1: if len(axis) != 1:
axis = list(axis) axis = list(axis)
...@@ -1376,34 +1373,41 @@ class MaxAndArgmax(Op): ...@@ -1376,34 +1373,41 @@ class MaxAndArgmax(Op):
assert axis == range(x.type.ndim), ( assert axis == range(x.type.ndim), (
"MaxAndArgmax does not support multiple" "MaxAndArgmax does not support multiple"
" axes. the max fct supports it.") " axes. the max fct supports it.")
axis = None
else:
axis = axis[0]
if isinstance(axis, (int, numpy.integer)):
axis = int(axis)
elif isinstance(axis, numpy.ndarray) and axis.ndim == 0:
axis = int(axis)
elif isinstance(axis, Variable): elif isinstance(axis, Variable):
if not isinstance(axis, TensorConstant): if not isinstance(axis, TensorConstant):
raise TypeError("MaxAndArgmax needs a constant axis") raise TypeError("MaxAndArgmax needs a constant axis")
axis = axis.data assert axis.dtype.startswith("int") or axis.dtype.startswith("uint")
if axis.ndim == 0: axis = int(axis.data)
axis = [axis]
# we make the axis all positive to make the infer_shape work # we make the axis all positive to make the infer_shape work
# with negative axis # with negative axis
if x.type.ndim > 0 and axis is not None: if x.type.ndim > 0 and axis is not None:
for id, a in enumerate(axis): if axis < 0:
if not isinstance(a, TensorVariable) and a < 0: if -axis > x.type.ndim:
if -a > x.type.ndim: raise ValueError('axis out of range')
raise ValueError('axis out of range') axis = x.type.ndim + axis
axis[id] = x.type.ndim + a
if axis is None:
axis = _as_tensor_variable(range(x.type.ndim))
else:
axis = _as_tensor_variable(axis)
# Verify that the axis is valid. # Verify that the axis is valid.
all_axes = set() all_axes = set()
for ax in axis.data: if axis is not None:
if ax < 0 or ax >= x.type.ndim: if axis < 0 or axis >= x.type.ndim:
raise ValueError( raise ValueError(
'Invalid axis: %s (the number of dimensions of the ' 'Invalid axis: %s (the number of dimensions of the '
'input is: %s)' % (axis, x.type.ndim)) 'input is: %s)' % (axis, x.type.ndim))
all_axes.add(ax.item()) all_axes.add(axis)
assert axis.ndim == 1 else:
all_axes = range(x.ndim)
if axis is None:
axis = NoneConst
else:
axis = _as_tensor_variable(axis)
assert axis.ndim == 0
inputs = [x, axis] inputs = [x, axis]
# We keep the original broadcastable flags for dimensions on which # We keep the original broadcastable flags for dimensions on which
# we do not perform the max / argmax. # we do not perform the max / argmax.
...@@ -1416,8 +1420,6 @@ class MaxAndArgmax(Op): ...@@ -1416,8 +1420,6 @@ class MaxAndArgmax(Op):
def perform(self, node, inp, outs): def perform(self, node, inp, outs):
x, axis = inp x, axis = inp
max, max_idx = outs max, max_idx = outs
if python_all(axis == range(x.ndim)):
axis = None
max[0] = theano._asarray(numpy.max(x, axis), max[0] = theano._asarray(numpy.max(x, axis),
dtype=node.outputs[0].dtype) dtype=node.outputs[0].dtype)
max_idx[0] = theano._asarray(numpy.argmax(x, axis), dtype='int64') max_idx[0] = theano._asarray(numpy.argmax(x, axis), dtype='int64')
...@@ -1426,20 +1428,17 @@ class MaxAndArgmax(Op): ...@@ -1426,20 +1428,17 @@ class MaxAndArgmax(Op):
x, axis = inp x, axis = inp
max, argmax = out max, argmax = out
fail = sub["fail"] fail = sub["fail"]
assert node.inputs[1].ndim == 1 assert node.inputs[1] is theano.tensor.type_other.NoneConst or node.inputs[1].ndim == 0
ret = """ ret = """
int axis; int axis;
if(PyArray_SIZE(%(axis)s) == PyArray_NDIM(%(x)s)){ if((PyObject*)%(axis)s == Py_None){
axis = NPY_MAXDIMS; axis = NPY_MAXDIMS;
}else if(PyArray_SIZE(%(axis)s) == 1){ }else{
axis = ((dtype_%(axis)s*)PyArray_DATA(%(axis)s))[0]; axis = ((dtype_%(axis)s*)PyArray_DATA(%(axis)s))[0];
if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){ if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){
PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, bad axis argument"); PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, bad axis argument");
%(fail)s %(fail)s
} }
}else{
PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, bad axis argument");
%(fail)s;
} }
%(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL); %(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL);
if(%(max)s == NULL){ if(%(max)s == NULL){
...@@ -1475,7 +1474,7 @@ class MaxAndArgmax(Op): ...@@ -1475,7 +1474,7 @@ class MaxAndArgmax(Op):
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
ishape, axis_shape = shapes ishape, axis_shape = shapes
axis = node.inputs[1] axis = node.inputs[1]
if python_all(axis.data == range(node.inputs[0].ndim)): if node.inputs[1].data is None:
return [(), ()] return [(), ()]
rval = tuple([ishape[i] for (i, b) in enumerate( rval = tuple([ishape[i] for (i, b) in enumerate(
node.inputs[0].type.broadcastable) if i != axis.data]) node.inputs[0].type.broadcastable) if i != axis.data])
...@@ -1533,12 +1532,16 @@ class MaxAndArgmax(Op): ...@@ -1533,12 +1532,16 @@ class MaxAndArgmax(Op):
# the gradient on its inputs is zero # the gradient on its inputs is zero
if g_max_disconnected: if g_max_disconnected:
return [x.zeros_like(), axis_grad] return [x.zeros_like(), axis_grad]
xmax = max(x, axis) if axis is NoneConst:
axis_ = range(x.ndim)
else:
axis_ = axis
xmax = max(x, axis_)
# Raise the g_max and xmax to the same number of dim as the input. # Raise the g_max and xmax to the same number of dim as the input.
pattern = [] pattern = []
out_dim = 0 out_dim = 0
if python_all(axis.data == range(x.ndim)): if axis is NoneConst:
# We are taking the max/argmax over all dimensions. # We are taking the max/argmax over all dimensions.
axis = None axis = None
for i in range(x.ndim): for i in range(x.ndim):
......
...@@ -46,10 +46,13 @@ def local_max_and_argmax(node): ...@@ -46,10 +46,13 @@ def local_max_and_argmax(node):
if len(node.outputs[1].clients) == 0: if len(node.outputs[1].clients) == 0:
#MaxAndArgmax support variable axis, #MaxAndArgmax support variable axis,
#but CAReduce support only constant axis. #but CAReduce support only constant axis.
try: if node.inputs[1].data is None:
axis = get_scalar_constant_value(node.inputs[1]) axis = None
except NotScalarConstantError: else:
return False try:
axis = get_scalar_constant_value(node.inputs[1])
except NotScalarConstantError:
return False
new = CAReduce(scal.maximum, axis)(node.inputs[0]) new = CAReduce(scal.maximum, axis)(node.inputs[0])
return [new, None] return [new, None]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论