提交 ccc259cd authored 作者: Frederic's avatar Frederic

Update new kron doc, and add a test with numpy.kron.

上级 149326f5
...@@ -4,7 +4,11 @@ from theano import tensor ...@@ -4,7 +4,11 @@ from theano import tensor
def kron(a, b): def kron(a, b):
""" Kronecker product """ Kronecker product
Same as scipy.linalg.kron(a, b) or numpy.kron(a, b) Same as scipy.linalg.kron(a, b).
:note: numpy.kron(a, b) != scipy.linalg.kron(a, b)!
They don't have the same shape and order when
a.ndim != b.ndim != 2.
:param a: array_like :param a: array_like
:param b: array_like :param b: array_like
......
...@@ -14,9 +14,6 @@ try: ...@@ -14,9 +14,6 @@ try:
except ImportError: except ImportError:
imported_scipy = False imported_scipy = False
if not imported_scipy:
raise SkipTest('Kron Op need the scipy package to be installed')
class TestKron(utt.InferShapeTester): class TestKron(utt.InferShapeTester):
...@@ -27,6 +24,9 @@ class TestKron(utt.InferShapeTester): ...@@ -27,6 +24,9 @@ class TestKron(utt.InferShapeTester):
self.op = kron self.op = kron
def test_perform(self): def test_perform(self):
if not imported_scipy:
raise SkipTest('kron tests need the scipy package to be installed')
for shp0 in [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]: for shp0 in [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]:
for shp1 in [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)]: for shp1 in [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)]:
if len(shp0) + len(shp1) == 2: if len(shp0) + len(shp1) == 2:
...@@ -41,6 +41,21 @@ class TestKron(utt.InferShapeTester): ...@@ -41,6 +41,21 @@ class TestKron(utt.InferShapeTester):
out = f(a, b) out = f(a, b)
assert numpy.allclose(out, scipy.linalg.kron(a, b)) assert numpy.allclose(out, scipy.linalg.kron(a, b))
def test_numpy_2d(self):
for shp0 in [(2, 3)]:
for shp1 in [(6, 7)]:
if len(shp0) + len(shp1) == 2:
continue
x = tensor.tensor(dtype='floatX',
broadcastable=(False,) * len(shp0))
y = tensor.tensor(dtype='floatX',
broadcastable=(False,) * len(shp1))
f = function([x, y], kron(x, y))
a = numpy.asarray(self.rng.rand(*shp0)).astype(floatX)
b = self.rng.rand(*shp1).astype(floatX)
out = f(a, b)
assert numpy.allclose(out, numpy.kron(a, b))
if __name__ == "__main__": if __name__ == "__main__":
t = TestKron('setUp') t = TestKron('setUp')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论