提交 568ba109 authored 作者: Amjad Almahairi's avatar Amjad Almahairi

adding infer_shape test for MatrixInverse Op

上级 36a6ef33
...@@ -61,13 +61,19 @@ def test_pseudoinverse_correctness(): ...@@ -61,13 +61,19 @@ def test_pseudoinverse_correctness():
assert _allclose(ri, numpy.linalg.pinv(r)) assert _allclose(ri, numpy.linalg.pinv(r))
def test_inverse_correctness(): class test_MatrixInverse(utt.InferShapeTester):
rng = numpy.random.RandomState(utt.fetch_seed()) def setUp(self):
super(test_MatrixInverse, self).setUp()
self.op_class = MatrixInverse
self.op = matrix_inverse
self.rng = numpy.random.RandomState(utt.fetch_seed())
r = rng.randn(4, 4).astype(theano.config.floatX) def test_inverse_correctness(self):
r = self.rng.randn(4, 4).astype(theano.config.floatX)
x = tensor.matrix() x = tensor.matrix()
xi = matrix_inverse(x) xi = self.op(x)
ri = function([x], xi)(r) ri = function([x], xi)(r)
assert ri.shape == r.shape assert ri.shape == r.shape
...@@ -79,6 +85,16 @@ def test_inverse_correctness(): ...@@ -79,6 +85,16 @@ def test_inverse_correctness():
assert _allclose(numpy.identity(4), rir), rir assert _allclose(numpy.identity(4), rir), rir
assert _allclose(numpy.identity(4), rri), rri assert _allclose(numpy.identity(4), rri), rri
def test_infer_shape(self):
r = self.rng.randn(4, 4).astype(theano.config.floatX)
x = tensor.matrix()
xi = self.op(x)
self._compile_and_check([x], [xi], [r],
self.op_class, warn=False)
def test_matrix_dot(): def test_matrix_dot():
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论