提交 1a305253 authored 作者: gdesjardins's avatar gdesjardins

Made tensordot interface the same as the documentation.

Also fixed tensordot to accept numpy syntax for axes.
上级 e01051f3
...@@ -3841,19 +3841,7 @@ pprint.assign(dot, printing.OperatorPrinter(printing.special['middle_dot'], -1, ...@@ -3841,19 +3841,7 @@ pprint.assign(dot, printing.OperatorPrinter(printing.special['middle_dot'], -1,
######################### #########################
class TensorDotGrad(Op): class TensorDotGrad(Op):
def __init__(self, axes): def __init__(self, axes):
if isinstance(axes,list): self.axes = TensorDot.parse_axes(axes)
for i,a in enumerate(axes):
if isinstance(a,list):
axes[i]=tuple(a)
axes=tuple(axes)
if isinstance(axes, tuple):
if len(axes)!=2:
raise ValueError("We need the list/tuple of axes to be of length 2")
if len(axes[0])!=len(axes[1]):
raise ValueError("We need that the axes 2 sub list of axes are of the same size")
assert len(axes[0])==len(axes[1])
self.axes = axes;
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.axes == other.axes return type(self) == type(other) and self.axes == other.axes
...@@ -3903,20 +3891,31 @@ class TensorDot(Op): ...@@ -3903,20 +3891,31 @@ class TensorDot(Op):
""" """
def __init__(self, axes): @classmethod
if isinstance(axes,list): def parse_axes(cls, axes):
if not numpy.isscalar(axes) and len(axes)!=2:
raise ValueError("Axes should be scalar valued or a list/tuple of len 2.")
if isinstance(axes,(list,tuple)):
axes_out = []
# cast axes[0] and axes[1] to tuples
for i,a in enumerate(axes): for i,a in enumerate(axes):
if isinstance(a,list): if numpy.isscalar(a):
axes[i]=tuple(a) axes_out.append((a,))
axes=tuple(axes) else:
if isinstance(axes, tuple): axes_out.append(tuple(a))
if len(axes)!=2:
raise ValueError("We need the list/tuple of axes to be of length 2") # these should be of same length
if len(axes[0])!=len(axes[1]): if len(axes_out[0])!=len(axes_out[1]):
raise ValueError("We need that the axes 2 sub list of axes are of the same size") raise ValueError("Elements of the axes list/tuple need to be of the same size.")
assert len(axes[0])==len(axes[1])
axes = tuple(axes_out)
self.axes = axes
return axes
def __init__(self, axes):
self.axes = self.parse_axes(axes)
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.axes == other.axes return type(self) == type(other) and self.axes == other.axes
...@@ -3957,7 +3956,40 @@ class TensorDot(Op): ...@@ -3957,7 +3956,40 @@ class TensorDot(Op):
def __str__(self): def __str__(self):
return "tensordot" return "tensordot"
tensordot = TensorDot
def tensordot(x, y, axes=2):
if x.ndim==0 or y.ndim==0:
raise ValueError('Cannot perform tensordot of 0-d inputs.')
axes = TensorDot.parse_axes(axes)
# check whether axes is valid given the dimensions of x and y
if numpy.isscalar(axes):
if axes >= x.ndim or axes >= y.ndim:
raise ValueError('axes should be smaller than the dimension of '\
'x and y (x.ndim=%i, y.ndim=%i)' % (x.ndim,y.ndim))
elif isinstance(axes, (list,tuple)):
if isinstance(axes[0],(list,tuple)) and \
(len(axes[0]) > x.ndim or (numpy.array(axes[0]) >= x.ndim).any()):
raise ValueError('axes[0] should be array_like, of length smaller'\
' than the dimension of x (x.ndim=%i, len(axes[0])=%i).' %
(x.ndim, len(axes[0])))
if isinstance(axes[1],(list,tuple)) and \
(len(axes[1]) > y.ndim or (numpy.array(axes[1]) >= y.ndim).any()):
raise ValueError('axes[1] should be array_like, of length smaller'\
'than the dimension of y (y.ndim=%i, len(axes[1])=%i).' %
(y.ndim, len(axes[1])))
if not hasattr(tensordot, 'op'):
tensordot.op = {}
if axes not in tensordot.op:
tensordot.op[axes] = TensorDot(axes)
return tensordot.op[axes](x, y)
#TODO: tensordot should be function as described in rst docs. #TODO: tensordot should be function as described in rst docs.
......
...@@ -2920,7 +2920,7 @@ class test_tensordot(unittest.TestCase): ...@@ -2920,7 +2920,7 @@ class test_tensordot(unittest.TestCase):
avec = vector() avec = vector()
bvec = vector() bvec = vector()
axes = ((0,),(0,)) axes = ((0,),(0,))
c = tensordot(axes)(avec, bvec) c = tensordot(avec, bvec, axes)
f1 = inplace_func([avec,bvec],c) f1 = inplace_func([avec,bvec],c)
aval = self.rand(5); aval = self.rand(5);
bval = self.rand(5); bval = self.rand(5);
...@@ -2931,7 +2931,7 @@ class test_tensordot(unittest.TestCase): ...@@ -2931,7 +2931,7 @@ class test_tensordot(unittest.TestCase):
# test matrix-vector # test matrix-vector
bmat = matrix() bmat = matrix()
axes = ((0,),(1,)) axes = ((0,),(1,))
c = tensordot(axes)(avec, bmat) c = tensordot(avec, bmat, axes)
f2 = inplace_func([avec,bmat],c) f2 = inplace_func([avec,bmat],c)
aval = self.rand(5); aval = self.rand(5);
bval = self.rand(8,5); bval = self.rand(8,5);
...@@ -2942,7 +2942,7 @@ class test_tensordot(unittest.TestCase): ...@@ -2942,7 +2942,7 @@ class test_tensordot(unittest.TestCase):
# test matrix-matrix # test matrix-matrix
amat = matrix() amat = matrix()
axes = ((1,),(0,)) axes = ((1,),(0,))
c = tensordot(axes)(amat, bmat) c = tensordot(amat, bmat, axes)
f3 = inplace_func([amat,bmat],c) f3 = inplace_func([amat,bmat],c)
aval = self.rand(4,7); aval = self.rand(4,7);
bval = self.rand(7,9); bval = self.rand(7,9);
...@@ -2953,7 +2953,7 @@ class test_tensordot(unittest.TestCase): ...@@ -2953,7 +2953,7 @@ class test_tensordot(unittest.TestCase):
# test ndarray-matrix, sum over one dim of matrix # test ndarray-matrix, sum over one dim of matrix
atens = tensor4() atens = tensor4()
axes = ((2,),(1,)) axes = ((2,),(1,))
c = tensordot(axes)(atens, bmat) c = tensordot(atens, bmat, axes)
f4 = inplace_func([atens,bmat],c) f4 = inplace_func([atens,bmat],c)
aval = self.rand(1,2,3,4); aval = self.rand(1,2,3,4);
bval = self.rand(2,3); bval = self.rand(2,3);
...@@ -2965,7 +2965,7 @@ class test_tensordot(unittest.TestCase): ...@@ -2965,7 +2965,7 @@ class test_tensordot(unittest.TestCase):
atens = tensor4() atens = tensor4()
btens = tensor3() btens = tensor3()
axes = ((1,3),(0,2)) axes = ((1,3),(0,2))
c = tensordot(axes)(atens, btens) c = tensordot(atens, btens, axes)
f5 = inplace_func([atens,btens],c) f5 = inplace_func([atens,btens],c)
aval = self.rand(4,3,5,2); aval = self.rand(4,3,5,2);
bval = self.rand(3,4,2); bval = self.rand(3,4,2);
...@@ -2974,47 +2974,59 @@ class test_tensordot(unittest.TestCase): ...@@ -2974,47 +2974,59 @@ class test_tensordot(unittest.TestCase):
utt.verify_grad(TensorDot(axes), [aval,bval]) utt.verify_grad(TensorDot(axes), [aval,bval])
axes = (axes[1],axes[0]) axes = (axes[1],axes[0])
c = tensordot(axes)(btens, atens) c = tensordot(btens, atens, axes)
f6 = inplace_func([btens,atens],c) f6 = inplace_func([btens,atens],c)
self.failUnless(numpy.allclose(numpy.tensordot(bval,aval,axes), self.failUnless(numpy.allclose(numpy.tensordot(bval,aval,axes),
f6(bval,aval))) f6(bval,aval)))
utt.verify_grad(TensorDot(axes), [bval,aval]) utt.verify_grad(TensorDot(axes), [bval,aval])
def test_raise_error(self): def test_raise_error(self):
amat = matrix()
# test vector-vector bmat = matrix()
avec = vector()
bvec = vector() bvec = vector()
axes = ((0,),())
# test invalid length for axes
try: try:
c = tensordot(axes)(avec, bvec) c = tensordot(amat, bmat, (0,1,2))
assert False assert False
except ValueError: except ValueError:
pass pass
# test matrix-vector
bmat = matrix() # test axes of uneven length
axes = ((0,),())
try: try:
c = tensordot(axes)(avec, bmat) c = tensordot(amat, bmat, ((0,1),(0)))
assert False assert False
except ValueError: except ValueError:
pass pass
# test matrix-matrix # test invalid len(axes) given inputs are matrices
amat = matrix()
axes = ((1,),())
try: try:
c = tensordot(axes)(amat, bmat) c = tensordot(amat, bmat, ((0,1,2),(0,1,2)))
assert False assert False
except ValueError: except ValueError:
pass pass
def test_list(self): # test invalid axes[1] given that y is a vector
try:
c = tensordot(amat, bvec, (0,1))
assert False
except ValueError:
pass
# test invalid scalar axes given inputs are matrices
try:
c = tensordot(amat, bvec, 2)
assert False
except ValueError:
pass
def test_weird_valid_axes(self):
# test matrix-matrix # test matrix-matrix
amat = matrix() amat = matrix()
bmat = matrix() bmat = matrix()
axes = [[1,],[0,]] for axes in 0, (1,0), [1,0], (1,(0,)), ((1,),0), ([1],[0]):
c = tensordot(axes)(amat, bmat) c = tensordot(amat, bmat, axes)
f3 = inplace_func([amat,bmat],c) f3 = inplace_func([amat,bmat],c)
aval = self.rand(4,7); aval = self.rand(4,7);
bval = self.rand(7,9); bval = self.rand(7,9);
...@@ -3022,14 +3034,14 @@ class test_tensordot(unittest.TestCase): ...@@ -3022,14 +3034,14 @@ class test_tensordot(unittest.TestCase):
f3(aval,bval))) f3(aval,bval)))
utt.verify_grad(TensorDot(axes), [aval,bval]) utt.verify_grad(TensorDot(axes), [aval,bval])
def test_scalar(self): def test_scalar_axes(self):
# test matrix-matrix # test matrix-matrix
amat = fmatrix() amat = fmatrix()
bmat = dmatrix()#we let at float64 to test mix of float32 and float64. bmat = dmatrix()#we let at float64 to test mix of float32 and float64.
axes = 1 axes = 1
aval = self.rand(4,5) aval = self.rand(4,5)
bval = numpy.random.rand(5,3) bval = numpy.random.rand(5,3)
c = tensordot(axes)(amat, bmat) c = tensordot(amat, bmat, axes)
f3 = inplace_func([amat,bmat],c) f3 = inplace_func([amat,bmat],c)
self.failUnless(numpy.allclose(numpy.tensordot(aval,bval,axes), self.failUnless(numpy.allclose(numpy.tensordot(aval,bval,axes),
f3(aval,bval))) f3(aval,bval)))
...@@ -3041,7 +3053,7 @@ class test_tensordot(unittest.TestCase): ...@@ -3041,7 +3053,7 @@ class test_tensordot(unittest.TestCase):
axes = 2 axes = 2
aval = self.rand(3,4,5) aval = self.rand(3,4,5)
bval = self.rand(4,5,3) bval = self.rand(4,5,3)
c = tensordot(axes)(amat, bmat) c = tensordot(amat, bmat, axes)
f3 = inplace_func([amat,bmat],c) f3 = inplace_func([amat,bmat],c)
self.failUnless(numpy.allclose(numpy.tensordot(aval,bval,axes), self.failUnless(numpy.allclose(numpy.tensordot(aval,bval,axes),
f3(aval,bval))) f3(aval,bval)))
...@@ -3054,7 +3066,7 @@ class test_tensordot(unittest.TestCase): ...@@ -3054,7 +3066,7 @@ class test_tensordot(unittest.TestCase):
axes = 0 axes = 0
aval = self.rand(4,5) aval = self.rand(4,5)
bval = self.rand(5,4) bval = self.rand(5,4)
c = tensordot(axes)(amat, bmat) c = tensordot(amat, bmat, axes)
f3 = inplace_func([amat,bmat],c) f3 = inplace_func([amat,bmat],c)
self.failUnless(numpy.allclose(numpy.tensordot(aval,bval,axes), self.failUnless(numpy.allclose(numpy.tensordot(aval,bval,axes),
f3(aval,bval))) f3(aval,bval)))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论