提交 1271c0bc authored 作者: Reyhane Askari's avatar Reyhane Askari 提交者: GitHub

Merge pull request #5980 from bordesf/check_axes

Function that check the axes
...@@ -968,6 +968,53 @@ def _pack(x): ...@@ -968,6 +968,53 @@ def _pack(x):
return [x] return [x]
def check_and_normalize_axes(x, axis):
"""
Check axes, normalize and convert them to a Python list of integers.
Return an empty list if argument is None.
Parameters
----------
x: Tensor variable
axis = Integer, tuple or list of integers
Returns
-------
axis: list of integers
"""
x = as_tensor_variable(x)
if axis is None:
axis = []
elif (isinstance(axis, (integer_types, np.integer)) or
(isinstance(axis, np.ndarray) and axis.ndim == 0)):
axis = [int(axis)]
elif isinstance(axis, (tuple, list, np.ndarray)):
axis = [int(i) for i in axis]
elif isinstance(axis, Variable):
if NoneConst.equals(axis):
axis = []
elif not isinstance(axis, TensorConstant):
raise TypeError("Computation needs a constant axis. Got %s" % axis)
else:
assert axis.dtype in integer_dtypes
if (isinstance(axis.data, (integer_types, np.integer)) or
(isinstance(axis.data, np.ndarray) and axis.data.ndim == 0)):
axis = [int(axis.data)]
elif isinstance(axis.data, (list, np.ndarray)):
axis = [int(i) for i in axis.data]
else:
raise TypeError("Axis must be an integer, tuple, list of integers or a TensorVariable. Got %s" % axis)
if len(axis) > 0:
for i in range(len(axis)):
if axis[i] < 0:
axis[i] += x.type.ndim
if axis[i] < 0 or axis[i] >= x.type.ndim:
raise ValueError("Computation needs a valid axis number for %d-D tensor. Got %d" % (x.type.ndim, axis[i]))
axis = list(set(axis))
axis.sort()
return axis
######################### #########################
# Casting Operations # Casting Operations
######################### #########################
...@@ -1382,54 +1429,14 @@ class Argmax(Op): ...@@ -1382,54 +1429,14 @@ class Argmax(Op):
def make_node(self, x, axis=None): def make_node(self, x, axis=None):
x = _as_tensor_variable(x) x = _as_tensor_variable(x)
# Check axis and convert it to a Python list of integers.
if isinstance(axis, (integer_types, np.integer)): axis = check_and_normalize_axes(x, axis)
axis = [int(axis)] if len(axis) == 0:
elif isinstance(axis, np.ndarray) and axis.ndim == 0:
axis = [int(axis)]
elif isinstance(axis, (tuple, list, np.ndarray)):
axis = [int(a) for a in axis]
if axis == list(range(x.type.ndim)):
axis = None
elif isinstance(axis, Variable):
if NoneConst.equals(axis):
axis = None
elif not isinstance(axis, TensorConstant):
raise TypeError(
"Argmax needs a constant axis. Got %s" % axis)
else:
assert axis.dtype in integer_dtypes
if isinstance(axis.data, (integer_types, np.integer)) or \
(isinstance(axis.data, np.ndarray) and
axis.data.ndim == 0):
axis = [int(axis.data)]
elif isinstance(axis.data, (list, np.ndarray)):
axis = [int(i) for i in axis.data]
# Make axis entries non-negative, and sort them
if isinstance(axis, list):
for idx in xrange(len(axis)):
if axis[idx] < 0:
axis[idx] += x.type.ndim
axis.sort()
# Verify that axes are valid
all_axes = []
if isinstance(axis, list):
for ax in axis:
if ax < 0 or ax >= x.type.ndim:
raise ValueError(
'Invalid axis: %s (the number of dimensions of the '
'input is: %s)' % (ax, x.type.ndim))
if ax not in all_axes:
all_axes.append(ax)
else:
all_axes = list(range(x.ndim))
if axis is None or axis == list(range(x.type.ndim)):
axis = NoneConst.clone() axis = NoneConst.clone()
all_axes = list(range(x.ndim))
else: else:
axis = _as_tensor_variable(all_axes) all_axes = axis
axis = _as_tensor_variable(axis)
assert axis.ndim == 1 assert axis.ndim == 1
inputs = [x, axis] inputs = [x, axis]
...@@ -1591,36 +1598,9 @@ def max_and_argmax(a, axis=None, keepdims=False): ...@@ -1591,36 +1598,9 @@ def max_and_argmax(a, axis=None, keepdims=False):
# Check axis and convert it to a Python list of integers. # Check axis and convert it to a Python list of integers.
# Axis will be used as an op param of MaxAndArgmax. # Axis will be used as an op param of MaxAndArgmax.
a = as_tensor_variable(a) a = as_tensor_variable(a)
if axis is None: axis = check_and_normalize_axes(a, axis)
axis = list(range(a.type.ndim))
elif (isinstance(axis, (integer_types, np.integer)) or
(isinstance(axis, np.ndarray) and axis.ndim == 0)):
axis = [int(axis)]
elif isinstance(axis, (tuple, list, np.ndarray)):
axis = [int(i) for i in axis]
elif isinstance(axis, Variable):
if NoneConst.equals(axis):
axis = list(range(a.type.ndim))
elif not isinstance(axis, TensorConstant):
raise TypeError("max and argmax computation needs a constant axis. Got %s" % axis)
else:
assert axis.dtype in integer_dtypes
if (isinstance(axis.data, (integer_types, np.integer)) or
(isinstance(axis.data, np.ndarray) and axis.data.ndim == 0)):
axis = [int(axis.data)]
elif isinstance(axis.data, (list, np.ndarray)):
axis = [int(i) for i in axis.data]
if len(axis) == 0: if len(axis) == 0:
axis = list(range(a.type.ndim)) axis = list(range(a.type.ndim))
else:
for i in range(len(axis)):
if axis[i] < 0:
axis[i] += a.type.ndim
if axis[i] < 0 or axis[i] >= a.type.ndim:
raise ValueError("max and argmax computation needs a valid axis number for %d-D tensor. Got %d"
% (a.type.ndim, axis[i]))
axis = list(set(axis))
axis.sort()
out, argout = MaxAndArgmax(axis)(a) out, argout = MaxAndArgmax(axis)(a)
if keepdims: if keepdims:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论