提交 51b5463f authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

improve error checking for axis arguments

1. make sure they are tuples 2. was checking if the max axis was > available dim, should have been >= available dim
上级 7f49bd67
......@@ -7261,31 +7261,49 @@ def tensordot(a, b, axes = 2):
# if 'axes' is a list, transpose a and b such that the summed axes of a
# are last and the summed axes of b are first.
else:
a_axes, b_axes = tuple(axes[0]), tuple(axes[1])
#get first axis element as a tuple
try:
a_axes = tuple(axes[0])
except TypeError:
a_axes = tuple([axes[0]])
#get second axis element as a tuple
try:
b_axes = tuple(axes[1])
except TypeError:
b_axes = tuple([axes[1]])
# the two axes lists must have the same length
if len(a_axes) != len(b_axes):
raise ValueError('Axes elements must have the same length.')
# check that axes is valid given dimension of a and b
# check that there aren't more axes than a has dimensions
if len(a_axes) > a.ndim:
raise ValueError('axes[0] should be array_like, of length '
'smaller than the dimension of a '
raise ValueError('axes[0] should be array_like with length '
'less than the dimensions of a '
'(a.ndim=%i, len(axes[0])=%i).' %
(a.ndim, a_axes))
if numpy.max(numpy.array(a_axes)) > a.ndim:
raise ValueError('axes[0] contains dimensions higher than a.ndim '
'(a.ndim=%i, max(axes[0])=%i).' %
(a.ndim, len(a_axes)))
# check that a_axes doesn't contain an axis greater than or equal to
# a's dimensions.
if numpy.max(numpy.array(a_axes)) >= a.ndim:
raise ValueError('axes[0] contains dimensions greater than or '
'equal to a.ndim (a.ndim=%i, max(axes[0])=%i).' %
(a.ndim, numpy.max(numpy.array(a_axes))))
# check that there aren't more axes than b has dimensions
if len(b_axes) > b.ndim:
raise ValueError('axes[1] should be array_like, of length '
'smaller than the dimension of b '
'(a.ndim=%i, len(axes[0])=%i).' %
(b.ndim, b_axes))
if numpy.max(numpy.array(b_axes)) > b.ndim:
raise ValueError('axes[1] contains dimensions higher than b.ndim '
'(b.ndim=%i, max(axes[1])=%i).' %
(b.ndim, numpy.max(numpy.array(b_axes))))
(b.ndim, len(b_axes)))
# the two axes lists must have the same length
if len(a_axes) != len(b_axes):
raise ValueError('Axes elements must have the same length.')
# check that b_axes doesn't contain an axis greater than or equal to
# b's dimensions.
if numpy.max(numpy.array(b_axes)) >= b.ndim:
raise ValueError('axes[1] contains dimensions greater than or '
'equal to b.ndim (b.ndim=%i, max(axes[1])=%i).' %
(b.ndim, numpy.max(numpy.array(b_axes))))
a_order = (tuple(x for x in tuple(xrange(a.ndim)) if x not in a_axes)
+ a_axes)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论