提交 2eb5c0a1 authored 作者: Florian Bordes's avatar Florian Bordes

Add check_and_normalize_axes function

上级 177cc884
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
from six import integer_types
import numpy as np import numpy as np
import theano import theano
from theano import scalar
from theano.compat import izip from theano.compat import izip
from theano.tensor import as_tensor_variable
from theano.tensor.var import TensorConstant
from theano.gof import Variable
from theano.gof.utils import hash_from_code from theano.gof.utils import hash_from_code
from theano.tensor.type_other import NoneConst
integer_dtypes = list(map(str, scalar.integer_types))
def hash_from_ndarray(data): def hash_from_ndarray(data):
...@@ -89,3 +97,48 @@ def shape_of_variables(fgraph, input_shapes): ...@@ -89,3 +97,48 @@ def shape_of_variables(fgraph, input_shapes):
l[var] = tuple(sym_to_num_dict[sym] l[var] = tuple(sym_to_num_dict[sym]
for sym in fgraph.shape_feature.shape_of[var]) for sym in fgraph.shape_feature.shape_of[var])
return l return l
def check_and_normalize_axes(x, axis):
"""
Check axes, normalize and convert them to a Python list of integers.
Return an empty list if argument is None.
Parameters
----------
x: Tensor variable
axis = Integer, tuple or list of integers
Returns
-------
axis: list of integers
"""
x = as_tensor_variable(x)
if axis is None:
axis = []
elif (isinstance(axis, (integer_types, np.integer)) or
(isinstance(axis, np.ndarray) and axis.ndim == 0)):
axis = [int(axis)]
elif isinstance(axis, (tuple, list, np.ndarray)):
axis = [int(i) for i in axis]
elif isinstance(axis, Variable):
if NoneConst.equals(axis):
axis = []
elif not isinstance(axis, TensorConstant):
raise TypeError("Computation needs a constant axis. Got %s" % axis)
else:
assert axis.dtype in integer_dtypes
if (isinstance(axis.data, (integer_types, np.integer)) or
(isinstance(axis.data, np.ndarray) and axis.data.ndim == 0)):
axis = [int(axis.data)]
elif isinstance(axis.data, (list, np.ndarray)):
axis = [int(i) for i in axis.data]
if len(axis) > 0:
for i in range(len(axis)):
if axis[i] < 0:
axis[i] += x.type.ndim
if axis[i] < 0 or axis[i] >= x.type.ndim:
raise ValueError("Computation needs a valid axis number for %d-D tensor. Got %d" % (x.type.ndim, axis[i]))
axis = list(set(axis))
axis.sort()
return axis
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论