提交 a0286e23 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Use random values in test.

上级 a06e92eb
...@@ -871,31 +871,30 @@ def test_shared_cudandarray(): ...@@ -871,31 +871,30 @@ def test_shared_cudandarray():
assert isinstance(a.type, tcn.CudaNdarrayType) assert isinstance(a.type, tcn.CudaNdarrayType)
def test_tensordot_reshape(): class test_tensordot_reshape(unittest.TestCase):
'''Test alternative tensordot implementation. '''Test alternative tensordot implementation.
Test that the tensordot implementation using dimshuffle, reshape and dot Test that the tensordot implementation using dimshuffle, reshape and dot
gives the same results as the default (numpy) version. gives the same results as the default (numpy) version.
''' '''
def setUp(self):
self.rng = numpy.random.RandomState(utt.fetch_seed())
def test1(self):
# define some tensors # define some tensors
a = numpy.arange(20, dtype=theano.config.floatX) / 20.0 tensor1 = self.rng.rand(20, 10, 5, 8).astype(theano.config.floatX)
b = numpy.arange(10, dtype=theano.config.floatX) / 10.0 tensor2 = self.rng.rand(5, 8, 20).astype(theano.config.floatX)
c = numpy.arange(5, dtype=theano.config.floatX) / 5.0 tensor3 = self.rng.rand(8, 20, 5).astype(theano.config.floatX)
d = numpy.arange(8, dtype=theano.config.floatX) / 8.0
tensor1 = numpy.tensordot(
a,
numpy.tensordot(b, numpy.tensordot(c, d, 0), 0),
0)
tensor2 = numpy.tensordot(c, numpy.tensordot(d, a, 0), 0)
tensor3 = tensor2.swapaxes(1, 2).swapaxes(0, 2) # d, a, c
x = T.tensor4('x') x = T.tensor4('x')
y = T.tensor3('y') y = T.tensor3('y')
# case 1: number of axes to sum over # case 1: number of axes to sum over
default1 = theano.function([x, y], T.tensordot(x, y, 2))(tensor1, tensor2) default1 = theano.function([x, y], T.tensordot(x, y, 2))(
reshape1 = theano.function([x, y], B.tensordot(x, y, 2))(tensor1, tensor2) tensor1, tensor2)
reshape1 = theano.function([x, y], B.tensordot(x, y, 2))(
tensor1, tensor2)
assert numpy.allclose(default1, reshape1) assert numpy.allclose(default1, reshape1)
# case 2: axis pairs # case 2: axis pairs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论