提交 4b2c6694 authored 作者: Frederic's avatar Frederic

Use scalar for axis or None in MaxAndArgmax.

this remove a deprecation warning in NumPy as ndarray won't be accepted as int anymore.
上级 932555b5
......@@ -19,6 +19,7 @@ from theano.tensor.var import (AsTensorError, TensorVariable,
TensorConstant,
_tensor_py_operators)
from theano.tensor.type import TensorType
from theano.tensor.type_other import NoneConst
from theano import scalar as scal
from theano.gof.python25 import partial, any, all
from theano.gof.utils import hashtype
......@@ -1359,11 +1360,7 @@ class MaxAndArgmax(Op):
def make_node(self, x, axis=None):
x = _as_tensor_variable(x)
if isinstance(axis, (int, numpy.integer)):
axis = [axis]
elif isinstance(axis, numpy.ndarray) and axis.ndim == 0:
axis = [int(axis)]
elif isinstance(axis, (tuple, list)):
if isinstance(axis, (tuple, list)):
axis = [int(a) for a in axis]
if len(axis) != 1:
axis = list(axis)
......@@ -1376,34 +1373,41 @@ class MaxAndArgmax(Op):
assert axis == range(x.type.ndim), (
"MaxAndArgmax does not support multiple"
" axes. the max fct supports it.")
axis = None
else:
axis = axis[0]
if isinstance(axis, (int, numpy.integer)):
axis = int(axis)
elif isinstance(axis, numpy.ndarray) and axis.ndim == 0:
axis = int(axis)
elif isinstance(axis, Variable):
if not isinstance(axis, TensorConstant):
raise TypeError("MaxAndArgmax needs a constant axis")
axis = axis.data
if axis.ndim == 0:
axis = [axis]
assert axis.dtype.startswith("int") or axis.dtype.startswith("uint")
axis = int(axis.data)
# we make the axis all positive to make the infer_shape work
# with negative axis
if x.type.ndim > 0 and axis is not None:
for id, a in enumerate(axis):
if not isinstance(a, TensorVariable) and a < 0:
if -a > x.type.ndim:
raise ValueError('axis out of range')
axis[id] = x.type.ndim + a
if axis is None:
axis = _as_tensor_variable(range(x.type.ndim))
else:
axis = _as_tensor_variable(axis)
if axis < 0:
if -axis > x.type.ndim:
raise ValueError('axis out of range')
axis = x.type.ndim + axis
# Verify that the axis is valid.
all_axes = set()
for ax in axis.data:
if ax < 0 or ax >= x.type.ndim:
if axis is not None:
if axis < 0 or axis >= x.type.ndim:
raise ValueError(
'Invalid axis: %s (the number of dimensions of the '
'input is: %s)' % (axis, x.type.ndim))
all_axes.add(ax.item())
assert axis.ndim == 1
all_axes.add(axis)
else:
all_axes = range(x.ndim)
if axis is None:
axis = NoneConst
else:
axis = _as_tensor_variable(axis)
assert axis.ndim == 0
inputs = [x, axis]
# We keep the original broadcastable flags for dimensions on which
# we do not perform the max / argmax.
......@@ -1416,8 +1420,6 @@ class MaxAndArgmax(Op):
def perform(self, node, inp, outs):
x, axis = inp
max, max_idx = outs
if python_all(axis == range(x.ndim)):
axis = None
max[0] = theano._asarray(numpy.max(x, axis),
dtype=node.outputs[0].dtype)
max_idx[0] = theano._asarray(numpy.argmax(x, axis), dtype='int64')
......@@ -1426,20 +1428,17 @@ class MaxAndArgmax(Op):
x, axis = inp
max, argmax = out
fail = sub["fail"]
assert node.inputs[1].ndim == 1
assert node.inputs[1] is theano.tensor.type_other.NoneConst or node.inputs[1].ndim == 0
ret = """
int axis;
if(PyArray_SIZE(%(axis)s) == PyArray_NDIM(%(x)s)){
if((PyObject*)%(axis)s == Py_None){
axis = NPY_MAXDIMS;
}else if(PyArray_SIZE(%(axis)s) == 1){
}else{
axis = ((dtype_%(axis)s*)PyArray_DATA(%(axis)s))[0];
if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){
PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, bad axis argument");
%(fail)s
}
}else{
PyErr_SetString(PyExc_ValueError, "MaxAndArgmax, bad axis argument");
%(fail)s;
}
%(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL);
if(%(max)s == NULL){
......@@ -1475,7 +1474,7 @@ class MaxAndArgmax(Op):
def infer_shape(self, node, shapes):
ishape, axis_shape = shapes
axis = node.inputs[1]
if python_all(axis.data == range(node.inputs[0].ndim)):
if node.inputs[1].data is None:
return [(), ()]
rval = tuple([ishape[i] for (i, b) in enumerate(
node.inputs[0].type.broadcastable) if i != axis.data])
......@@ -1533,12 +1532,16 @@ class MaxAndArgmax(Op):
# the gradient on its inputs is zero
if g_max_disconnected:
return [x.zeros_like(), axis_grad]
xmax = max(x, axis)
if axis is NoneConst:
axis_ = range(x.ndim)
else:
axis_ = axis
xmax = max(x, axis_)
# Raise the g_max and xmax to the same number of dim as the input.
pattern = []
out_dim = 0
if python_all(axis.data == range(x.ndim)):
if axis is NoneConst:
# We are taking the max/argmax over all dimensions.
axis = None
for i in range(x.ndim):
......
......@@ -46,10 +46,13 @@ def local_max_and_argmax(node):
if len(node.outputs[1].clients) == 0:
#MaxAndArgmax support variable axis,
#but CAReduce support only constant axis.
try:
axis = get_scalar_constant_value(node.inputs[1])
except NotScalarConstantError:
return False
if node.inputs[1].data is None:
axis = None
else:
try:
axis = get_scalar_constant_value(node.inputs[1])
except NotScalarConstantError:
return False
new = CAReduce(scal.maximum, axis)(node.inputs[0])
return [new, None]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论