提交 d3a4df3d authored 作者: Florian Bordes's avatar Florian Bordes

Move check_and_normalize_axes to Utilies section in basic.py

上级 3e57995d
...@@ -18,7 +18,6 @@ from theano.gof import Apply, Constant, Op, Variable ...@@ -18,7 +18,6 @@ from theano.gof import Apply, Constant, Op, Variable
from theano.gof.type import Generic from theano.gof.type import Generic
from theano.tensor import elemwise from theano.tensor import elemwise
from theano.tensor.utils import check_and_normalize_axes
from theano.tensor.var import (AsTensorError, TensorVariable, from theano.tensor.var import (AsTensorError, TensorVariable,
TensorConstant, TensorConstantSignature, TensorConstant, TensorConstantSignature,
_tensor_py_operators) _tensor_py_operators)
...@@ -969,6 +968,53 @@ def _pack(x): ...@@ -969,6 +968,53 @@ def _pack(x):
return [x] return [x]
def check_and_normalize_axes(x, axis):
"""
Check axes, normalize and convert them to a Python list of integers.
Return an empty list if argument is None.
Parameters
----------
x: Tensor variable
axis = Integer, tuple or list of integers
Returns
-------
axis: list of integers
"""
x = as_tensor_variable(x)
if axis is None:
axis = []
elif (isinstance(axis, (integer_types, np.integer)) or
(isinstance(axis, np.ndarray) and axis.ndim == 0)):
axis = [int(axis)]
elif isinstance(axis, (tuple, list, np.ndarray)):
axis = [int(i) for i in axis]
elif isinstance(axis, Variable):
if NoneConst.equals(axis):
axis = []
elif not isinstance(axis, TensorConstant):
raise TypeError("Computation needs a constant axis. Got %s" % axis)
else:
assert axis.dtype in integer_dtypes
if (isinstance(axis.data, (integer_types, np.integer)) or
(isinstance(axis.data, np.ndarray) and axis.data.ndim == 0)):
axis = [int(axis.data)]
elif isinstance(axis.data, (list, np.ndarray)):
axis = [int(i) for i in axis.data]
else:
axis = []
if len(axis) > 0:
for i in range(len(axis)):
if axis[i] < 0:
axis[i] += x.type.ndim
if axis[i] < 0 or axis[i] >= x.type.ndim:
raise ValueError("Computation needs a valid axis number for %d-D tensor. Got %d" % (x.type.ndim, axis[i]))
axis = list(set(axis))
axis.sort()
return axis
######################### #########################
# Casting Operations # Casting Operations
######################### #########################
......
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
from six import integer_types
import numpy as np import numpy as np
import theano import theano
from theano import scalar
from theano.compat import izip from theano.compat import izip
from theano.gof import Variable
from theano.gof.utils import hash_from_code from theano.gof.utils import hash_from_code
from theano.tensor.type_other import NoneConst
integer_dtypes = list(map(str, scalar.integer_types))
def hash_from_ndarray(data): def hash_from_ndarray(data):
...@@ -95,50 +89,3 @@ def shape_of_variables(fgraph, input_shapes): ...@@ -95,50 +89,3 @@ def shape_of_variables(fgraph, input_shapes):
l[var] = tuple(sym_to_num_dict[sym] l[var] = tuple(sym_to_num_dict[sym]
for sym in fgraph.shape_feature.shape_of[var]) for sym in fgraph.shape_feature.shape_of[var])
return l return l
def check_and_normalize_axes(x, axis):
"""
Check axes, normalize and convert them to a Python list of integers.
Return an empty list if argument is None.
Parameters
----------
x: Tensor variable
axis = Integer, tuple or list of integers
Returns
-------
axis: list of integers
"""
x = tensor.as_tensor_variable(x)
if axis is None:
axis = []
elif (isinstance(axis, (integer_types, np.integer)) or
(isinstance(axis, np.ndarray) and axis.ndim == 0)):
axis = [int(axis)]
elif isinstance(axis, (tuple, list, np.ndarray)):
axis = [int(i) for i in axis]
elif isinstance(axis, Variable):
if NoneConst.equals(axis):
axis = []
elif not isinstance(axis, tensor.var.TensorConstant):
raise TypeError("Computation needs a constant axis. Got %s" % axis)
else:
assert axis.dtype in integer_dtypes
if (isinstance(axis.data, (integer_types, np.integer)) or
(isinstance(axis.data, np.ndarray) and axis.data.ndim == 0)):
axis = [int(axis.data)]
elif isinstance(axis.data, (list, np.ndarray)):
axis = [int(i) for i in axis.data]
else:
axis = []
if len(axis) > 0:
for i in range(len(axis)):
if axis[i] < 0:
axis[i] += x.type.ndim
if axis[i] < 0 or axis[i] >= x.type.ndim:
raise ValueError("Computation needs a valid axis number for %d-D tensor. Got %d" % (x.type.ndim, axis[i]))
axis = list(set(axis))
axis.sort()
return axis
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论