提交 67f13b01 authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

check for 0-length axes to avoid error with numpy.max; also unit test

上级 63e69e9e
...@@ -7289,8 +7289,9 @@ def tensordot(a, b, axes = 2): ...@@ -7289,8 +7289,9 @@ def tensordot(a, b, axes = 2):
(a.ndim, len(a_axes))) (a.ndim, len(a_axes)))
# check that a_axes doesn't contain an axis greater than or equal to # check that a_axes doesn't contain an axis greater than or equal to
# a's dimensions. # a's dimensions. also check if len > 0 so numpy.max won't raise an
if numpy.max(numpy.array(a_axes)) >= a.ndim: # error.
if len(a_axes) > 0 and numpy.max(numpy.array(a_axes)) >= a.ndim:
raise ValueError('axes[0] contains dimensions greater than or ' raise ValueError('axes[0] contains dimensions greater than or '
'equal to a.ndim (a.ndim=%i, max(axes[0])=%i).' % 'equal to a.ndim (a.ndim=%i, max(axes[0])=%i).' %
(a.ndim, numpy.max(numpy.array(a_axes)))) (a.ndim, numpy.max(numpy.array(a_axes))))
...@@ -7303,8 +7304,9 @@ def tensordot(a, b, axes = 2): ...@@ -7303,8 +7304,9 @@ def tensordot(a, b, axes = 2):
(b.ndim, len(b_axes))) (b.ndim, len(b_axes)))
# check that b_axes doesn't contain an axis greater than or equal to # check that b_axes doesn't contain an axis greater than or equal to
# b's dimensions. # b's dimensions. also check if len > 0 so numpy.max won't raise an
if numpy.max(numpy.array(b_axes)) >= b.ndim: # error.
if len(b_axes) > 0 and numpy.max(numpy.array(b_axes)) >= b.ndim:
raise ValueError('axes[1] contains dimensions greater than or ' raise ValueError('axes[1] contains dimensions greater than or '
'equal to b.ndim (b.ndim=%i, max(axes[1])=%i).' % 'equal to b.ndim (b.ndim=%i, max(axes[1])=%i).' %
(b.ndim, numpy.max(numpy.array(b_axes)))) (b.ndim, numpy.max(numpy.array(b_axes))))
......
...@@ -5541,7 +5541,13 @@ class test_tensordot(unittest.TestCase): ...@@ -5541,7 +5541,13 @@ class test_tensordot(unittest.TestCase):
# Test matrix-matrix # Test matrix-matrix
amat = matrix() amat = matrix()
bmat = matrix() bmat = matrix()
for axes in 0, (1, 0), [1, 0], (1, (0, )), ((1, ), 0), ([1], [0]): for axes in [0,
(1, 0),
[1, 0],
(1, (0, )),
((1, ), 0),
([1], [0]),
([], [])]:
c = tensordot(amat, bmat, axes) c = tensordot(amat, bmat, axes)
f3 = inplace_func([amat, bmat], c) f3 = inplace_func([amat, bmat], c)
aval = rand(4, 7) aval = rand(4, 7)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论