提交 3e57995d authored 作者: Florian Bordes's avatar Florian Bordes

Use check_and_normalize_axes for Argmax

上级 2eb5c0a1
...@@ -18,6 +18,7 @@ from theano.gof import Apply, Constant, Op, Variable ...@@ -18,6 +18,7 @@ from theano.gof import Apply, Constant, Op, Variable
from theano.gof.type import Generic from theano.gof.type import Generic
from theano.tensor import elemwise from theano.tensor import elemwise
from theano.tensor.utils import check_and_normalize_axes
from theano.tensor.var import (AsTensorError, TensorVariable, from theano.tensor.var import (AsTensorError, TensorVariable,
TensorConstant, TensorConstantSignature, TensorConstant, TensorConstantSignature,
_tensor_py_operators) _tensor_py_operators)
...@@ -1382,54 +1383,13 @@ class Argmax(Op): ...@@ -1382,54 +1383,13 @@ class Argmax(Op):
def make_node(self, x, axis=None): def make_node(self, x, axis=None):
x = _as_tensor_variable(x) x = _as_tensor_variable(x)
# Check axis and convert it to a Python list of integers.
if isinstance(axis, (integer_types, np.integer)): axis = check_and_normalize_axes(x, axis)
axis = [int(axis)] if len(axis) == 0:
elif isinstance(axis, np.ndarray) and axis.ndim == 0:
axis = [int(axis)]
elif isinstance(axis, (tuple, list, np.ndarray)):
axis = [int(a) for a in axis]
if axis == list(range(x.type.ndim)):
axis = None
elif isinstance(axis, Variable):
if NoneConst.equals(axis):
axis = None
elif not isinstance(axis, TensorConstant):
raise TypeError(
"Argmax needs a constant axis. Got %s" % axis)
else:
assert axis.dtype in integer_dtypes
if isinstance(axis.data, (integer_types, np.integer)) or \
(isinstance(axis.data, np.ndarray) and
axis.data.ndim == 0):
axis = [int(axis.data)]
elif isinstance(axis.data, (list, np.ndarray)):
axis = [int(i) for i in axis.data]
# Make axis entries non-negative, and sort them
if isinstance(axis, list):
for idx in xrange(len(axis)):
if axis[idx] < 0:
axis[idx] += x.type.ndim
axis.sort()
# Verify that axes are valid
all_axes = []
if isinstance(axis, list):
for ax in axis:
if ax < 0 or ax >= x.type.ndim:
raise ValueError(
'Invalid axis: %s (the number of dimensions of the '
'input is: %s)' % (ax, x.type.ndim))
if ax not in all_axes:
all_axes.append(ax)
else:
all_axes = list(range(x.ndim))
if axis is None or axis == list(range(x.type.ndim)):
axis = NoneConst.clone() axis = NoneConst.clone()
all_axes = list(range(x.ndim))
else: else:
axis = _as_tensor_variable(all_axes) axis = _as_tensor_variable(axis)
assert axis.ndim == 1 assert axis.ndim == 1
inputs = [x, axis] inputs = [x, axis]
...@@ -1591,36 +1551,9 @@ def max_and_argmax(a, axis=None, keepdims=False): ...@@ -1591,36 +1551,9 @@ def max_and_argmax(a, axis=None, keepdims=False):
# Check axis and convert it to a Python list of integers. # Check axis and convert it to a Python list of integers.
# Axis will be used as an op param of MaxAndArgmax. # Axis will be used as an op param of MaxAndArgmax.
a = as_tensor_variable(a) a = as_tensor_variable(a)
if axis is None: axis = check_and_normalize_axes(a, axis)
axis = list(range(a.type.ndim))
elif (isinstance(axis, (integer_types, np.integer)) or
(isinstance(axis, np.ndarray) and axis.ndim == 0)):
axis = [int(axis)]
elif isinstance(axis, (tuple, list, np.ndarray)):
axis = [int(i) for i in axis]
elif isinstance(axis, Variable):
if NoneConst.equals(axis):
axis = list(range(a.type.ndim))
elif not isinstance(axis, TensorConstant):
raise TypeError("max and argmax computation needs a constant axis. Got %s" % axis)
else:
assert axis.dtype in integer_dtypes
if (isinstance(axis.data, (integer_types, np.integer)) or
(isinstance(axis.data, np.ndarray) and axis.data.ndim == 0)):
axis = [int(axis.data)]
elif isinstance(axis.data, (list, np.ndarray)):
axis = [int(i) for i in axis.data]
if len(axis) == 0: if len(axis) == 0:
axis = list(range(a.type.ndim)) axis = list(range(a.type.ndim))
else:
for i in range(len(axis)):
if axis[i] < 0:
axis[i] += a.type.ndim
if axis[i] < 0 or axis[i] >= a.type.ndim:
raise ValueError("max and argmax computation needs a valid axis number for %d-D tensor. Got %d"
% (a.type.ndim, axis[i]))
axis = list(set(axis))
axis.sort()
out, argout = MaxAndArgmax(axis)(a) out, argout = MaxAndArgmax(axis)(a)
if keepdims: if keepdims:
......
...@@ -5,8 +5,6 @@ import numpy as np ...@@ -5,8 +5,6 @@ import numpy as np
import theano import theano
from theano import scalar from theano import scalar
from theano.compat import izip from theano.compat import izip
from theano.tensor import as_tensor_variable
from theano.tensor.var import TensorConstant
from theano.gof import Variable from theano.gof import Variable
from theano.gof.utils import hash_from_code from theano.gof.utils import hash_from_code
from theano.tensor.type_other import NoneConst from theano.tensor.type_other import NoneConst
...@@ -113,7 +111,7 @@ def check_and_normalize_axes(x, axis): ...@@ -113,7 +111,7 @@ def check_and_normalize_axes(x, axis):
------- -------
axis: list of integers axis: list of integers
""" """
x = as_tensor_variable(x) x = tensor.as_tensor_variable(x)
if axis is None: if axis is None:
axis = [] axis = []
elif (isinstance(axis, (integer_types, np.integer)) or elif (isinstance(axis, (integer_types, np.integer)) or
...@@ -124,7 +122,7 @@ def check_and_normalize_axes(x, axis): ...@@ -124,7 +122,7 @@ def check_and_normalize_axes(x, axis):
elif isinstance(axis, Variable): elif isinstance(axis, Variable):
if NoneConst.equals(axis): if NoneConst.equals(axis):
axis = [] axis = []
elif not isinstance(axis, TensorConstant): elif not isinstance(axis, tensor.var.TensorConstant):
raise TypeError("Computation needs a constant axis. Got %s" % axis) raise TypeError("Computation needs a constant axis. Got %s" % axis)
else: else:
assert axis.dtype in integer_dtypes assert axis.dtype in integer_dtypes
...@@ -133,6 +131,8 @@ def check_and_normalize_axes(x, axis): ...@@ -133,6 +131,8 @@ def check_and_normalize_axes(x, axis):
axis = [int(axis.data)] axis = [int(axis.data)]
elif isinstance(axis.data, (list, np.ndarray)): elif isinstance(axis.data, (list, np.ndarray)):
axis = [int(i) for i in axis.data] axis = [int(i) for i in axis.data]
else:
axis = []
if len(axis) > 0: if len(axis) > 0:
for i in range(len(axis)): for i in range(len(axis)):
if axis[i] < 0: if axis[i] < 0:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论