提交 67f13b01 authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

check for 0-length axes to avoid error with numpy.max; also unit test

上级 63e69e9e
......@@ -7289,8 +7289,9 @@ def tensordot(a, b, axes = 2):
(a.ndim, len(a_axes)))
# check that a_axes doesn't contain an axis greater than or equal to
# a's dimensions.
if numpy.max(numpy.array(a_axes)) >= a.ndim:
# a's dimensions. also check if len > 0 so numpy.max won't raise an
# error.
if len(a_axes) > 0 and numpy.max(numpy.array(a_axes)) >= a.ndim:
raise ValueError('axes[0] contains dimensions greater than or '
'equal to a.ndim (a.ndim=%i, max(axes[0])=%i).' %
(a.ndim, numpy.max(numpy.array(a_axes))))
......@@ -7303,8 +7304,9 @@ def tensordot(a, b, axes = 2):
(b.ndim, len(b_axes)))
# check that b_axes doesn't contain an axis greater than or equal to
# b's dimensions.
if numpy.max(numpy.array(b_axes)) >= b.ndim:
# b's dimensions. also check if len > 0 so numpy.max won't raise an
# error.
if len(b_axes) > 0 and numpy.max(numpy.array(b_axes)) >= b.ndim:
raise ValueError('axes[1] contains dimensions greater than or '
'equal to b.ndim (b.ndim=%i, max(axes[1])=%i).' %
(b.ndim, numpy.max(numpy.array(b_axes))))
......
......@@ -5541,7 +5541,13 @@ class test_tensordot(unittest.TestCase):
# Test matrix-matrix
amat = matrix()
bmat = matrix()
for axes in 0, (1, 0), [1, 0], (1, (0, )), ((1, ), 0), ([1], [0]):
for axes in [0,
(1, 0),
[1, 0],
(1, (0, )),
((1, ), 0),
([1], [0]),
([], [])]:
c = tensordot(amat, bmat, axes)
f3 = inplace_func([amat, bmat], c)
aval = rand(4, 7)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论