提交 549d73cc authored 作者: benanne's avatar benanne

Test for the dimshuffle/reshape-based tensordot implementation, which compares…

Test for the dimshuffle/reshape-based tensordot implementation, which compares it against the traditional implementation
上级 2f4c5868
......@@ -870,6 +870,36 @@ def test_shared_cudandarray():
a = cuda.shared_constructor(cuda.CudaNdarray.zeros((2,3)))
assert isinstance(a.type, tcn.CudaNdarrayType)
def test_tensordot_reshape():
'''Test that the tensordot implementation using dimshuffle, reshape and dot
gives the same results as the default (numpy) version'''
# define some tensors
a = numpy.arange(20, dtype=theano.config.floatX) / 20.0
b = numpy.arange(10, dtype=theano.config.floatX) / 10.0
c = numpy.arange(5, dtype=theano.config.floatX) / 5.0
d = numpy.arange(8, dtype=theano.config.floatX) / 8.0
tensor1 = numpy.tensordot(a, numpy.tensordot(b, numpy.tensordot(c, d, 0), 0), 0)
tensor2 = numpy.tensordot(c, numpy.tensordot(d, a, 0), 0)
tensor3 = tensor2.swapaxes(1, 2).swapaxes(0, 2) # d, a, c
x = T.tensor4('x')
y = T.tensor3('y')
# case 1: number of axes to sum over
default1 = theano.function([x,y], T.tensordot(x, y, 2))(tensor1, tensor2)
reshape1 = theano.function([x,y], B.tensordot(x, y, 2))(tensor1, tensor2)
assert numpy.allclose(default1, reshape1)
# case 2: axis pairs
default2 = theano.function([x,y], T.tensordot(x, y, axes=[(0, 3), (1, 0)]))(tensor1, tensor3)
reshape2 = theano.function([x,y], B.tensordot(x, y, axes=[(0, 3), (1, 0)]))(tensor1, tensor3)
assert numpy.allclose(default2, reshape2)
default3 = theano.function([x,y], T.tensordot(x, y, axes=[(0, 3, 2), (1, 0, 2)]))(tensor1, tensor3)
reshape3 = theano.function([x,y], B.tensordot(x, y, axes=[(0, 3, 2), (1, 0, 2)]))(tensor1, tensor3)
assert numpy.allclose(default3, reshape3)
class test_size(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论