提交 4e0a1674 authored 作者: Frederic Bastien's avatar Frederic Bastien

make TensorDot and TensorDotGrad work with scalar axes. Add tests for that and…

Make TensorDot and TensorDotGrad work with scalar axes. Add tests for that, and a test that an error is raised.
上级 5b5b21a6
......@@ -3787,7 +3787,11 @@ class TensorDotGrad(Op):
assert isinstance(gz, Variable)
gx = x.type()
gy = y.type()
return Apply(self, [x,y,gz], [gx, gy])
op = self
if isinstance(self.axes,int):
axes = [range(x.ndim-self.axes,x.ndim),range(self.axes)]
op = TensorDotGrad(axes)
return Apply(op, [x,y,gz], [gx, gy])
def perform(self, node, (x, y, gz), (gx,gy)):
......@@ -3841,8 +3845,13 @@ class TensorDot(Op):
return hashtype(self) ^ hash(self.axes) ^ 89234
def make_node(self, x, y):
op = self
if isinstance(self.axes,int):
axes = [range(x.ndim-self.axes,x.ndim),range(self.axes)]
op = TensorDot(axes)
axesdim = numpy.size(op.axes)/2
axesdim = numpy.size(self.axes)/2
x, y = map(as_tensor_variable, [x, y])
if axesdim > x.type.ndim or axesdim > y.type.ndim:
......@@ -3851,7 +3860,7 @@ class TensorDot(Op):
outdim = x.type.ndim + y.type.ndim - 2*axesdim
output = tensor(dtype=x.dtype, broadcastable=[False]*outdim);
return Apply(self, inputs=[x,y], outputs=[output,])
return Apply(op, inputs=[x,y], outputs=[output,])
def perform(self, node, (x, y), (z,)):
try:
......
......@@ -2513,6 +2513,35 @@ class test_tensordot(unittest.TestCase):
f6(bval,aval)))
utt.verify_grad(TensorDot(axes), [bval,aval])
def test_raise_error(self):
    # Building a tensordot whose two axes lists differ in length (one axis
    # named for the first operand, none for the second) is expected to
    # raise ValueError for each operand combination below.
    #
    # A missing error is reported through self.fail() in an `else` clause
    # rather than with `assert False` inside the `try` body: assert
    # statements are stripped under `python -O`, which would let this test
    # pass even when no error is raised.
    avec = dvector()
    bvec = dvector()
    amat = dmatrix()
    bmat = dmatrix()

    # (label, axes, first operand, second operand)
    cases = [
        ('vector-vector', ((0,), ()), avec, bvec),
        ('matrix-vector', ((0,), ()), avec, bmat),
        ('matrix-matrix', ((1,), ()), amat, bmat),
    ]
    for label, axes, lhs, rhs in cases:
        try:
            # Both constructing the op and applying it stay inside the try,
            # so a ValueError from either step is accepted.
            tensordot(axes)(lhs, rhs)
        except ValueError:
            pass
        else:
            self.fail('tensordot(%r) did not raise ValueError for %s '
                      'inputs' % (axes, label))
def test_list(self):
# test matrix-matrix
amat = dmatrix()
......@@ -2526,6 +2555,47 @@ class test_tensordot(unittest.TestCase):
f3(aval,bval)))
utt.verify_grad(TensorDot(axes), [aval,bval])
def test_scalar(self):
    # An integer `axes` is checked against numpy.tensordot called with the
    # same integer, once for a pair of matrices and once for a pair of
    # 3-d tensors; the gradient is then verified for each case.
    cases = [
        # (symbolic variable constructor, shape of a, shape of b, axes)
        (dmatrix, (4, 5), (5, 3), 1),          # matrix-matrix
        (dtensor3, (3, 4, 5), (4, 5, 3), 2),   # tensor-tensor
    ]
    for make_var, a_shape, b_shape, axes in cases:
        a_sym = make_var()
        b_sym = make_var()
        aval = numpy.random.rand(*a_shape)
        bval = numpy.random.rand(*b_shape)
        out = tensordot(axes)(a_sym, b_sym)
        fn = inplace_func([a_sym, b_sym], out)
        expected = numpy.tensordot(aval, bval, axes)
        self.failUnless(numpy.all(expected == fn(aval, bval)))
        utt.verify_grad(TensorDot(axes), [aval, bval])
def test_tensordot_grad(self):
    # We test it manually as we recreate the op in the make_node.
    # The gradient op built from the scalar axes 1 must produce the same
    # two outputs as the one built from the explicit axes ((1,), (0,)).
    x_sym = dmatrix()
    y_sym = dmatrix()
    gz_sym = dmatrix()

    x_val = numpy.random.rand(4, 5)
    y_val = numpy.random.rand(5, 3)
    gz_val = numpy.random.rand(4, 3)

    scalar_axes = 1
    explicit_axes = ((1,), (0,))
    scalar_fn = inplace_func(
        [x_sym, y_sym, gz_sym],
        tensordot_grad(scalar_axes)(x_sym, y_sym, gz_sym))
    explicit_fn = inplace_func(
        [x_sym, y_sym, gz_sym],
        tensordot_grad(explicit_axes)(x_sym, y_sym, gz_sym))

    from_scalar = scalar_fn(x_val, y_val, gz_val)
    from_explicit = explicit_fn(x_val, y_val, gz_val)

    # Index 0 and 1 are the two outputs returned by the gradient op.
    for idx in (0, 1):
        self.failUnless(numpy.all(from_scalar[idx] == from_explicit[idx]))
def test_smallest_stack():
sx, sy = dscalar(), dscalar()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论