提交 b6746643 authored 作者: Frederic Bastien's avatar Frederic Bastien

fix TensorDot and TensorDotGrad with float32 and a mix of float32 and float64 inputs. Test it.

上级 4e0a1674
...@@ -3785,8 +3785,10 @@ class TensorDotGrad(Op): ...@@ -3785,8 +3785,10 @@ class TensorDotGrad(Op):
assert isinstance(x, Variable) assert isinstance(x, Variable)
assert isinstance(y, Variable) assert isinstance(y, Variable)
assert isinstance(gz, Variable) assert isinstance(gz, Variable)
gx = x.type() gx = tensor(dtype=scal.upcast(gz.dtype, y.dtype),
gy = y.type() broadcastable = x.broadcastable)
gy = tensor(dtype=scal.upcast(x.dtype, gz.dtype),
broadcastable = y.broadcastable)
op = self op = self
if isinstance(self.axes,int): if isinstance(self.axes,int):
axes = [range(x.ndim-self.axes,x.ndim),range(self.axes)] axes = [range(x.ndim-self.axes,x.ndim),range(self.axes)]
...@@ -3805,14 +3807,12 @@ class TensorDotGrad(Op): ...@@ -3805,14 +3807,12 @@ class TensorDotGrad(Op):
newshapex = numpy.zeros(x.ndim) newshapex = numpy.zeros(x.ndim)
newshapex[[newpos for newpos in idx]] = [i for i in range(x.ndim)] newshapex[[newpos for newpos in idx]] = [i for i in range(x.ndim)]
gx[0] = numpy.transpose(_gx, newshapex) gx[0] = numpy.transpose(_gx, newshapex)
assert str(gx[0].dtype) == 'float64'
_gy = numpy.tensordot(x, gz, [sum_over_x, range(x.ndim-len(self.axes[0]))]) _gy = numpy.tensordot(x, gz, [sum_over_x, range(x.ndim-len(self.axes[0]))])
idy = numpy.hstack((self.axes[1], sum_over_y)) idy = numpy.hstack((self.axes[1], sum_over_y))
newshapey = numpy.zeros(y.ndim) newshapey = numpy.zeros(y.ndim)
newshapey[[newpos for newpos in idy]] = [i for i in range(y.ndim)] newshapey[[newpos for newpos in idy]] = [i for i in range(y.ndim)]
gy[0] = numpy.transpose(_gy, newshapey) gy[0] = numpy.transpose(_gy, newshapey)
assert str(gy[0].dtype) == 'float64'
tensordot_grad = TensorDotGrad tensordot_grad = TensorDotGrad
...@@ -3859,13 +3859,13 @@ class TensorDot(Op): ...@@ -3859,13 +3859,13 @@ class TensorDot(Op):
axesdim, x.type.ndim, y.type.ndim) axesdim, x.type.ndim, y.type.ndim)
outdim = x.type.ndim + y.type.ndim - 2*axesdim outdim = x.type.ndim + y.type.ndim - 2*axesdim
output = tensor(dtype=x.dtype, broadcastable=[False]*outdim); output = tensor(dtype=scal.upcast(x.dtype, y.dtype),
broadcastable=[False]*outdim);
return Apply(op, inputs=[x,y], outputs=[output,]) return Apply(op, inputs=[x,y], outputs=[output,])
def perform(self, node, (x, y), (z,)): def perform(self, node, (x, y), (z,)):
try: try:
z[0] = numpy.asarray(numpy.tensordot(x, y, self.axes)) z[0] = numpy.asarray(numpy.tensordot(x, y, self.axes))
assert str(z[0].dtype) == 'float64'
except ValueError, e: except ValueError, e:
# The error raised by numpy has no shape information, we mean to add that # The error raised by numpy has no shape information, we mean to add that
e.args = e.args + (x.shape, y.shape, self.axes) e.args = e.args + (x.shape, y.shape, self.axes)
......
...@@ -2447,61 +2447,64 @@ class test_tensordot(unittest.TestCase): ...@@ -2447,61 +2447,64 @@ class test_tensordot(unittest.TestCase):
def setUp(self): def setUp(self):
utt.seed_rng() utt.seed_rng()
def rand(self, *shape):
return numpy.asarray(numpy.random.rand(*shape), dtype=config.floatX)
def test0(self): def test0(self):
# test vector-vector # test vector-vector
avec = dvector() avec = vector()
bvec = dvector() bvec = vector()
axes = ((0,),(0,)) axes = ((0,),(0,))
c = tensordot(axes)(avec, bvec) c = tensordot(axes)(avec, bvec)
f1 = inplace_func([avec,bvec],c) f1 = inplace_func([avec,bvec],c)
aval = numpy.random.rand(5); aval = self.rand(5);
bval = numpy.random.rand(5); bval = self.rand(5);
self.failUnless(numpy.tensordot(aval,bval,axes) == \ self.failUnless(numpy.tensordot(aval,bval,axes) == \
f1(aval,bval)) f1(aval,bval))
utt.verify_grad(TensorDot(axes), [aval,bval]) utt.verify_grad(TensorDot(axes), [aval,bval])
# test matrix-vector # test matrix-vector
bmat = dmatrix() bmat = matrix()
axes = ((0,),(1,)) axes = ((0,),(1,))
c = tensordot(axes)(avec, bmat) c = tensordot(axes)(avec, bmat)
f2 = inplace_func([avec,bmat],c) f2 = inplace_func([avec,bmat],c)
aval = numpy.random.rand(5); aval = self.rand(5);
bval = numpy.random.rand(8,5); bval = self.rand(8,5);
self.failUnless(numpy.all(numpy.tensordot(aval,bval,axes) == \ self.failUnless(numpy.all(numpy.tensordot(aval,bval,axes) == \
f2(aval,bval))) f2(aval,bval)))
utt.verify_grad(TensorDot(axes), [aval,bval]) utt.verify_grad(TensorDot(axes), [aval,bval])
# test matrix-matrix # test matrix-matrix
amat = dmatrix() amat = matrix()
axes = ((1,),(0,)) axes = ((1,),(0,))
c = tensordot(axes)(amat, bmat) c = tensordot(axes)(amat, bmat)
f3 = inplace_func([amat,bmat],c) f3 = inplace_func([amat,bmat],c)
aval = numpy.random.rand(4,7); aval = self.rand(4,7);
bval = numpy.random.rand(7,9); bval = self.rand(7,9);
self.failUnless(numpy.all(numpy.tensordot(aval,bval,axes) == \ self.failUnless(numpy.all(numpy.tensordot(aval,bval,axes) == \
f3(aval,bval))) f3(aval,bval)))
utt.verify_grad(TensorDot(axes), [aval,bval]) utt.verify_grad(TensorDot(axes), [aval,bval])
# test ndarray-matrix, sum over one dim of matrix # test ndarray-matrix, sum over one dim of matrix
atens = TensorType('float64', broadcastable=(False,)*4)() atens = tensor4()
axes = ((2,),(1,)) axes = ((2,),(1,))
c = tensordot(axes)(atens, bmat) c = tensordot(axes)(atens, bmat)
f4 = inplace_func([atens,bmat],c) f4 = inplace_func([atens,bmat],c)
aval = numpy.random.rand(1,2,3,4); aval = self.rand(1,2,3,4);
bval = numpy.random.rand(2,3); bval = self.rand(2,3);
self.failUnless(numpy.all(numpy.tensordot(aval,bval,axes) == \ self.failUnless(numpy.all(numpy.tensordot(aval,bval,axes) == \
f4(aval,bval))) f4(aval,bval)))
utt.verify_grad(TensorDot(axes), [aval,bval]) utt.verify_grad(TensorDot(axes), [aval,bval])
# test ndarray-ndarray # test ndarray-ndarray
atens = TensorType('float64', broadcastable=(False,)*4)() atens = tensor4()
btens = TensorType('float64', broadcastable=(False,)*3)() btens = tensor3()
axes = ((1,3),(0,2)) axes = ((1,3),(0,2))
c = tensordot(axes)(atens, btens) c = tensordot(axes)(atens, btens)
f5 = inplace_func([atens,btens],c) f5 = inplace_func([atens,btens],c)
aval = numpy.random.rand(4,3,5,2); aval = self.rand(4,3,5,2);
bval = numpy.random.rand(3,4,2); bval = self.rand(3,4,2);
self.failUnless(numpy.all(numpy.tensordot(aval,bval,axes) == \ self.failUnless(numpy.all(numpy.tensordot(aval,bval,axes) == \
f5(aval,bval))) f5(aval,bval)))
utt.verify_grad(TensorDot(axes), [aval,bval]) utt.verify_grad(TensorDot(axes), [aval,bval])
...@@ -2516,8 +2519,8 @@ class test_tensordot(unittest.TestCase): ...@@ -2516,8 +2519,8 @@ class test_tensordot(unittest.TestCase):
def test_raise_error(self): def test_raise_error(self):
# test vector-vector # test vector-vector
avec = dvector() avec = vector()
bvec = dvector() bvec = vector()
axes = ((0,),()) axes = ((0,),())
try: try:
c = tensordot(axes)(avec, bvec) c = tensordot(axes)(avec, bvec)
...@@ -2525,7 +2528,7 @@ class test_tensordot(unittest.TestCase): ...@@ -2525,7 +2528,7 @@ class test_tensordot(unittest.TestCase):
except ValueError: except ValueError:
pass pass
# test matrix-vector # test matrix-vector
bmat = dmatrix() bmat = matrix()
axes = ((0,),()) axes = ((0,),())
try: try:
c = tensordot(axes)(avec, bmat) c = tensordot(axes)(avec, bmat)
...@@ -2534,7 +2537,7 @@ class test_tensordot(unittest.TestCase): ...@@ -2534,7 +2537,7 @@ class test_tensordot(unittest.TestCase):
pass pass
# test matrix-matrix # test matrix-matrix
amat = dmatrix() amat = matrix()
axes = ((1,),()) axes = ((1,),())
try: try:
c = tensordot(axes)(amat, bmat) c = tensordot(axes)(amat, bmat)
...@@ -2544,23 +2547,23 @@ class test_tensordot(unittest.TestCase): ...@@ -2544,23 +2547,23 @@ class test_tensordot(unittest.TestCase):
def test_list(self): def test_list(self):
# test matrix-matrix # test matrix-matrix
amat = dmatrix() amat = matrix()
bmat = dmatrix() bmat = matrix()
axes = [[1,],[0,]] axes = [[1,],[0,]]
c = tensordot(axes)(amat, bmat) c = tensordot(axes)(amat, bmat)
f3 = inplace_func([amat,bmat],c) f3 = inplace_func([amat,bmat],c)
aval = numpy.random.rand(4,7); aval = self.rand(4,7);
bval = numpy.random.rand(7,9); bval = self.rand(7,9);
self.failUnless(numpy.all(numpy.tensordot(aval,bval,axes) == \ self.failUnless(numpy.all(numpy.tensordot(aval,bval,axes) == \
f3(aval,bval))) f3(aval,bval)))
utt.verify_grad(TensorDot(axes), [aval,bval]) utt.verify_grad(TensorDot(axes), [aval,bval])
def test_scalar(self): def test_scalar(self):
# test matrix-matrix # test matrix-matrix
amat = dmatrix() amat = fmatrix()
bmat = dmatrix() bmat = dmatrix()#we let at float64 to test mix of float32 and float64.
axes = 1 axes = 1
aval = numpy.random.rand(4,5) aval = self.rand(4,5)
bval = numpy.random.rand(5,3) bval = numpy.random.rand(5,3)
c = tensordot(axes)(amat, bmat) c = tensordot(axes)(amat, bmat)
f3 = inplace_func([amat,bmat],c) f3 = inplace_func([amat,bmat],c)
...@@ -2569,11 +2572,11 @@ class test_tensordot(unittest.TestCase): ...@@ -2569,11 +2572,11 @@ class test_tensordot(unittest.TestCase):
utt.verify_grad(TensorDot(axes), [aval,bval]) utt.verify_grad(TensorDot(axes), [aval,bval])
# test tensor-tensor # test tensor-tensor
amat = dtensor3() amat = tensor3()
bmat = dtensor3() bmat = tensor3()
axes = 2 axes = 2
aval = numpy.random.rand(3,4,5) aval = self.rand(3,4,5)
bval = numpy.random.rand(4,5,3) bval = self.rand(4,5,3)
c = tensordot(axes)(amat, bmat) c = tensordot(axes)(amat, bmat)
f3 = inplace_func([amat,bmat],c) f3 = inplace_func([amat,bmat],c)
self.failUnless(numpy.all(numpy.tensordot(aval,bval,axes) == \ self.failUnless(numpy.all(numpy.tensordot(aval,bval,axes) == \
...@@ -2583,13 +2586,13 @@ class test_tensordot(unittest.TestCase): ...@@ -2583,13 +2586,13 @@ class test_tensordot(unittest.TestCase):
def test_tensordot_grad(self): def test_tensordot_grad(self):
#We test it manually as we recreate the op in the make_node #We test it manually as we recreate the op in the make_node
amat = dmatrix() amat = matrix()
bmat = dmatrix() bmat = matrix()
gzmat = dmatrix() gzmat = matrix()
axes = 1 axes = 1
aval = numpy.random.rand(4,5) aval = self.rand(4,5)
bval = numpy.random.rand(5,3) bval = self.rand(5,3)
gzval = numpy.random.rand(4,3) gzval = self.rand(4,3)
f1 = inplace_func([amat,bmat,gzmat],tensordot_grad(axes)(amat, bmat, gzmat)) f1 = inplace_func([amat,bmat,gzmat],tensordot_grad(axes)(amat, bmat, gzmat))
f2 = inplace_func([amat,bmat,gzmat],tensordot_grad(((1,),(0,)))(amat, bmat, gzmat)) f2 = inplace_func([amat,bmat,gzmat],tensordot_grad(((1,),(0,)))(amat, bmat, gzmat))
o1=f1(aval,bval,gzval) o1=f1(aval,bval,gzval)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论