提交 1a305253 authored 作者: gdesjardins's avatar gdesjardins

Made tensordot interface the same as the documentation.

Also fixed tensordot to accept numpy syntax for axes.
上级 e01051f3
......@@ -3841,19 +3841,7 @@ pprint.assign(dot, printing.OperatorPrinter(printing.special['middle_dot'], -1,
#########################
class TensorDotGrad(Op):
def __init__(self, axes):
if isinstance(axes,list):
for i,a in enumerate(axes):
if isinstance(a,list):
axes[i]=tuple(a)
axes=tuple(axes)
if isinstance(axes, tuple):
if len(axes)!=2:
raise ValueError("We need the list/tuple of axes to be of length 2")
if len(axes[0])!=len(axes[1]):
raise ValueError("We need that the axes 2 sub list of axes are of the same size")
assert len(axes[0])==len(axes[1])
self.axes = axes;
self.axes = TensorDot.parse_axes(axes)
def __eq__(self, other):
return type(self) == type(other) and self.axes == other.axes
......@@ -3903,20 +3891,31 @@ class TensorDot(Op):
"""
def __init__(self, axes):
if isinstance(axes,list):
@classmethod
def parse_axes(cls, axes):
if not numpy.isscalar(axes) and len(axes)!=2:
raise ValueError("Axes should be scalar valued or a list/tuple of len 2.")
if isinstance(axes,(list,tuple)):
axes_out = []
# cast axes[0] and axes[1] to tuples
for i,a in enumerate(axes):
if isinstance(a,list):
axes[i]=tuple(a)
axes=tuple(axes)
if isinstance(axes, tuple):
if len(axes)!=2:
raise ValueError("We need the list/tuple of axes to be of length 2")
if len(axes[0])!=len(axes[1]):
raise ValueError("We need that the axes 2 sub list of axes are of the same size")
assert len(axes[0])==len(axes[1])
self.axes = axes
if numpy.isscalar(a):
axes_out.append((a,))
else:
axes_out.append(tuple(a))
# these should be of same length
if len(axes_out[0])!=len(axes_out[1]):
raise ValueError("Elements of the axes list/tuple need to be of the same size.")
axes = tuple(axes_out)
return axes
def __init__(self, axes):
self.axes = self.parse_axes(axes)
def __eq__(self, other):
return type(self) == type(other) and self.axes == other.axes
......@@ -3957,7 +3956,40 @@ class TensorDot(Op):
def __str__(self):
return "tensordot"
tensordot = TensorDot
def tensordot(x, y, axes=2):
if x.ndim==0 or y.ndim==0:
raise ValueError('Cannot perform tensordot of 0-d inputs.')
axes = TensorDot.parse_axes(axes)
# check whether axes is valid given the dimensions of x and y
if numpy.isscalar(axes):
if axes >= x.ndim or axes >= y.ndim:
raise ValueError('axes should be smaller than the dimension of '\
'x and y (x.ndim=%i, y.ndim=%i)' % (x.ndim,y.ndim))
elif isinstance(axes, (list,tuple)):
if isinstance(axes[0],(list,tuple)) and \
(len(axes[0]) > x.ndim or (numpy.array(axes[0]) >= x.ndim).any()):
raise ValueError('axes[0] should be array_like, of length smaller'\
' than the dimension of x (x.ndim=%i, len(axes[0])=%i).' %
(x.ndim, len(axes[0])))
if isinstance(axes[1],(list,tuple)) and \
(len(axes[1]) > y.ndim or (numpy.array(axes[1]) >= y.ndim).any()):
raise ValueError('axes[1] should be array_like, of length smaller'\
'than the dimension of y (y.ndim=%i, len(axes[1])=%i).' %
(y.ndim, len(axes[1])))
if not hasattr(tensordot, 'op'):
tensordot.op = {}
if axes not in tensordot.op:
tensordot.op[axes] = TensorDot(axes)
return tensordot.op[axes](x, y)
#TODO: tensordot should be function as described in rst docs.
......
......@@ -2920,7 +2920,7 @@ class test_tensordot(unittest.TestCase):
avec = vector()
bvec = vector()
axes = ((0,),(0,))
c = tensordot(axes)(avec, bvec)
c = tensordot(avec, bvec, axes)
f1 = inplace_func([avec,bvec],c)
aval = self.rand(5);
bval = self.rand(5);
......@@ -2931,7 +2931,7 @@ class test_tensordot(unittest.TestCase):
# test matrix-vector
bmat = matrix()
axes = ((0,),(1,))
c = tensordot(axes)(avec, bmat)
c = tensordot(avec, bmat, axes)
f2 = inplace_func([avec,bmat],c)
aval = self.rand(5);
bval = self.rand(8,5);
......@@ -2942,7 +2942,7 @@ class test_tensordot(unittest.TestCase):
# test matrix-matrix
amat = matrix()
axes = ((1,),(0,))
c = tensordot(axes)(amat, bmat)
c = tensordot(amat, bmat, axes)
f3 = inplace_func([amat,bmat],c)
aval = self.rand(4,7);
bval = self.rand(7,9);
......@@ -2953,7 +2953,7 @@ class test_tensordot(unittest.TestCase):
# test ndarray-matrix, sum over one dim of matrix
atens = tensor4()
axes = ((2,),(1,))
c = tensordot(axes)(atens, bmat)
c = tensordot(atens, bmat, axes)
f4 = inplace_func([atens,bmat],c)
aval = self.rand(1,2,3,4);
bval = self.rand(2,3);
......@@ -2965,7 +2965,7 @@ class test_tensordot(unittest.TestCase):
atens = tensor4()
btens = tensor3()
axes = ((1,3),(0,2))
c = tensordot(axes)(atens, btens)
c = tensordot(atens, btens, axes)
f5 = inplace_func([atens,btens],c)
aval = self.rand(4,3,5,2);
bval = self.rand(3,4,2);
......@@ -2974,62 +2974,74 @@ class test_tensordot(unittest.TestCase):
utt.verify_grad(TensorDot(axes), [aval,bval])
axes = (axes[1],axes[0])
c = tensordot(axes)(btens, atens)
c = tensordot(btens, atens, axes)
f6 = inplace_func([btens,atens],c)
self.failUnless(numpy.allclose(numpy.tensordot(bval,aval,axes),
f6(bval,aval)))
utt.verify_grad(TensorDot(axes), [bval,aval])
def test_raise_error(self):
# test vector-vector
avec = vector()
amat = matrix()
bmat = matrix()
bvec = vector()
axes = ((0,),())
# test invalid length for axes
try:
c = tensordot(axes)(avec, bvec)
c = tensordot(amat, bmat, (0,1,2))
assert False
except ValueError:
pass
# test matrix-vector
bmat = matrix()
axes = ((0,),())
# test axes of uneven length
try:
c = tensordot(axes)(avec, bmat)
c = tensordot(amat, bmat, ((0,1),(0)))
assert False
except ValueError:
pass
# test matrix-matrix
amat = matrix()
axes = ((1,),())
# test invalid len(axes) given inputs are matrices
try:
c = tensordot(axes)(amat, bmat)
c = tensordot(amat, bmat, ((0,1,2),(0,1,2)))
assert False
except ValueError:
pass
def test_list(self):
# test invalid axes[1] given that y is a vector
try:
c = tensordot(amat, bvec, (0,1))
assert False
except ValueError:
pass
# test invalid scalar axes given inputs are matrices
try:
c = tensordot(amat, bvec, 2)
assert False
except ValueError:
pass
def test_weird_valid_axes(self):
# test matrix-matrix
amat = matrix()
bmat = matrix()
axes = [[1,],[0,]]
c = tensordot(axes)(amat, bmat)
f3 = inplace_func([amat,bmat],c)
aval = self.rand(4,7);
bval = self.rand(7,9);
self.failUnless(numpy.allclose(numpy.tensordot(aval,bval,axes),
f3(aval,bval)))
utt.verify_grad(TensorDot(axes), [aval,bval])
def test_scalar(self):
for axes in 0, (1,0), [1,0], (1,(0,)), ((1,),0), ([1],[0]):
c = tensordot(amat, bmat, axes)
f3 = inplace_func([amat,bmat],c)
aval = self.rand(4,7);
bval = self.rand(7,9);
self.failUnless(numpy.allclose(numpy.tensordot(aval,bval,axes),
f3(aval,bval)))
utt.verify_grad(TensorDot(axes), [aval,bval])
def test_scalar_axes(self):
# test matrix-matrix
amat = fmatrix()
bmat = dmatrix()#we let at float64 to test mix of float32 and float64.
axes = 1
aval = self.rand(4,5)
bval = numpy.random.rand(5,3)
c = tensordot(axes)(amat, bmat)
c = tensordot(amat, bmat, axes)
f3 = inplace_func([amat,bmat],c)
self.failUnless(numpy.allclose(numpy.tensordot(aval,bval,axes),
f3(aval,bval)))
......@@ -3041,7 +3053,7 @@ class test_tensordot(unittest.TestCase):
axes = 2
aval = self.rand(3,4,5)
bval = self.rand(4,5,3)
c = tensordot(axes)(amat, bmat)
c = tensordot(amat, bmat, axes)
f3 = inplace_func([amat,bmat],c)
self.failUnless(numpy.allclose(numpy.tensordot(aval,bval,axes),
f3(aval,bval)))
......@@ -3054,7 +3066,7 @@ class test_tensordot(unittest.TestCase):
axes = 0
aval = self.rand(4,5)
bval = self.rand(5,4)
c = tensordot(axes)(amat, bmat)
c = tensordot(amat, bmat, axes)
f3 = inplace_func([amat,bmat],c)
self.failUnless(numpy.allclose(numpy.tensordot(aval,bval,axes),
f3(aval,bval)))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论