提交 0db3f9af authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #3736 from lamblin/fix_kron_test

Update test for recent scipy.
......@@ -306,7 +306,14 @@ class TestKron(utt.InferShapeTester):
f = function([x, y], kron(x, y))
b = self.rng.rand(*shp1).astype(config.floatX)
out = f(a, b)
assert numpy.allclose(out, scipy.linalg.kron(a, b))
# Newer versions of scipy want 4 dimensions at least,
# so we have to add a dimension to a and flatten the result.
if len(shp0) + len(shp1) == 3:
scipy_val = scipy.linalg.kron(
a[numpy.newaxis, :], b).flatten()
else:
scipy_val = scipy.linalg.kron(a, b)
utt.assert_allclose(out, scipy_val)
def test_numpy_2d(self):
for shp0 in [(2, 3)]:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论