提交 b91ab02b authored 作者: Frederic's avatar Frederic

Speed up a test

上级 0ce1949f
......@@ -296,30 +296,30 @@ class TestKron(utt.InferShapeTester):
raise SkipTest('kron tests need the scipy package to be installed')
for shp0 in [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]:
x = tensor.tensor(dtype='floatX',
broadcastable=(False,) * len(shp0))
a = numpy.asarray(self.rng.rand(*shp0)).astype(config.floatX)
for shp1 in [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)]:
if len(shp0) + len(shp1) == 2:
continue
x = tensor.tensor(dtype='floatX',
broadcastable=(False,) * len(shp0))
y = tensor.tensor(dtype='floatX',
broadcastable=(False,) * len(shp1))
f = function([x, y], kron(x, y))
a = numpy.asarray(self.rng.rand(*shp0)).astype(config.floatX)
b = self.rng.rand(*shp1).astype(config.floatX)
out = f(a, b)
assert numpy.allclose(out, scipy.linalg.kron(a, b))
def test_numpy_2d(self):
for shp0 in [(2, 3)]:
x = tensor.tensor(dtype='floatX',
broadcastable=(False,) * len(shp0))
a = numpy.asarray(self.rng.rand(*shp0)).astype(config.floatX)
for shp1 in [(6, 7)]:
if len(shp0) + len(shp1) == 2:
continue
x = tensor.tensor(dtype='floatX',
broadcastable=(False,) * len(shp0))
y = tensor.tensor(dtype='floatX',
broadcastable=(False,) * len(shp1))
f = function([x, y], kron(x, y))
a = numpy.asarray(self.rng.rand(*shp0)).astype(config.floatX)
b = self.rng.rand(*shp1).astype(config.floatX)
out = f(a, b)
assert numpy.allclose(out, numpy.kron(a, b))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论