提交 03a31568 authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

support weird axes in tensordot: both empty [ [ ], [ ] ]

上级 f4af5d9c
...@@ -6803,7 +6803,7 @@ def tensordot(a, b, axes = 2): ...@@ -6803,7 +6803,7 @@ def tensordot(a, b, axes = 2):
# check that a_axes doesn't contain an axis greater than or equal to # check that a_axes doesn't contain an axis greater than or equal to
# a's dimensions. # a's dimensions.
if numpy.max(numpy.array(a_axes)) >= a.ndim: if len(a_axes) > 0 and numpy.max(numpy.array(a_axes)) >= a.ndim:
raise ValueError('axes[0] contains dimensions greater than or ' raise ValueError('axes[0] contains dimensions greater than or '
'equal to a.ndim (a.ndim=%i, max(axes[0])=%i).' % 'equal to a.ndim (a.ndim=%i, max(axes[0])=%i).' %
(a.ndim, numpy.max(numpy.array(a_axes)))) (a.ndim, numpy.max(numpy.array(a_axes))))
...@@ -6817,7 +6817,7 @@ def tensordot(a, b, axes = 2): ...@@ -6817,7 +6817,7 @@ def tensordot(a, b, axes = 2):
# check that b_axes doesn't contain an axis greater than or equal to # check that b_axes doesn't contain an axis greater than or equal to
# b's dimensions. # b's dimensions.
if numpy.max(numpy.array(b_axes)) >= b.ndim: if len(b_axes) != 0 and numpy.max(numpy.array(b_axes)) >= b.ndim:
raise ValueError('axes[1] contains dimensions greater than or ' raise ValueError('axes[1] contains dimensions greater than or '
'equal to b.ndim (b.ndim=%i, max(axes[1])=%i).' % 'equal to b.ndim (b.ndim=%i, max(axes[1])=%i).' %
(b.ndim, numpy.max(numpy.array(b_axes)))) (b.ndim, numpy.max(numpy.array(b_axes))))
......
...@@ -5471,7 +5471,13 @@ class test_tensordot(unittest.TestCase): ...@@ -5471,7 +5471,13 @@ class test_tensordot(unittest.TestCase):
# Test matrix-matrix # Test matrix-matrix
amat = matrix() amat = matrix()
bmat = matrix() bmat = matrix()
for axes in 0, (1, 0), [1, 0], (1, (0, )), ((1, ), 0), ([1], [0]): for axes in [0,
(1, 0),
[1, 0],
(1, (0, )),
((1, ), 0),
([1], [0]),
([], [])]:
c = tensordot(amat, bmat, axes) c = tensordot(amat, bmat, axes)
f3 = inplace_func([amat, bmat], c) f3 = inplace_func([amat, bmat], c)
aval = rand(4, 7) aval = rand(4, 7)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论