提交 51b5463f authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

improve error checking for axis arguments

1. make sure they are tuples 2. was checking if the max axis was > available dim, should have been >= available dim
上级 7f49bd67
...@@ -7261,31 +7261,49 @@ def tensordot(a, b, axes = 2): ...@@ -7261,31 +7261,49 @@ def tensordot(a, b, axes = 2):
# if 'axes' is a list, transpose a and b such that the summed axes of a # if 'axes' is a list, transpose a and b such that the summed axes of a
# are last and the summed axes of b are first. # are last and the summed axes of b are first.
else: else:
a_axes, b_axes = tuple(axes[0]), tuple(axes[1]) #get first axis element as a tuple
try:
a_axes = tuple(axes[0])
except TypeError:
a_axes = tuple([axes[0]])
#get second axis element as a tuple
try:
b_axes = tuple(axes[1])
except TypeError:
b_axes = tuple([axes[1]])
# the two axes lists must have the same length
if len(a_axes) != len(b_axes):
raise ValueError('Axes elements must have the same length.')
# check that axes is valid given dimension of a and b # check that there aren't more axes than a has dimensions
if len(a_axes) > a.ndim: if len(a_axes) > a.ndim:
raise ValueError('axes[0] should be array_like, of length ' raise ValueError('axes[0] should be array_like with length '
'smaller than the dimension of a ' 'less than the dimensions of a '
'(a.ndim=%i, len(axes[0])=%i).' % '(a.ndim=%i, len(axes[0])=%i).' %
(a.ndim, a_axes)) (a.ndim, len(a_axes)))
if numpy.max(numpy.array(a_axes)) > a.ndim:
raise ValueError('axes[0] contains dimensions higher than a.ndim ' # check that a_axes doesn't contain an axis greater than or equal to
'(a.ndim=%i, max(axes[0])=%i).' % # a's dimensions.
if numpy.max(numpy.array(a_axes)) >= a.ndim:
raise ValueError('axes[0] contains dimensions greater than or '
'equal to a.ndim (a.ndim=%i, max(axes[0])=%i).' %
(a.ndim, numpy.max(numpy.array(a_axes)))) (a.ndim, numpy.max(numpy.array(a_axes))))
# check that there aren't more axes than b has dimensions
if len(b_axes) > b.ndim: if len(b_axes) > b.ndim:
raise ValueError('axes[1] should be array_like, of length ' raise ValueError('axes[1] should be array_like, of length '
'smaller than the dimension of b ' 'smaller than the dimension of b '
'(a.ndim=%i, len(axes[0])=%i).' % '(a.ndim=%i, len(axes[0])=%i).' %
(b.ndim, b_axes)) (b.ndim, len(b_axes)))
if numpy.max(numpy.array(b_axes)) > b.ndim:
raise ValueError('axes[1] contains dimensions higher than b.ndim '
'(b.ndim=%i, max(axes[1])=%i).' %
(b.ndim, numpy.max(numpy.array(b_axes))))
# the two axes lists must have the same length # check that b_axes doesn't contain an axis greater than or equal to
if len(a_axes) != len(b_axes): # b's dimensions.
raise ValueError('Axes elements must have the same length.') if numpy.max(numpy.array(b_axes)) >= b.ndim:
raise ValueError('axes[1] contains dimensions greater than or '
'equal to b.ndim (b.ndim=%i, max(axes[1])=%i).' %
(b.ndim, numpy.max(numpy.array(b_axes))))
a_order = (tuple(x for x in tuple(xrange(a.ndim)) if x not in a_axes) a_order = (tuple(x for x in tuple(xrange(a.ndim)) if x not in a_axes)
+ a_axes) + a_axes)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论