提交 d15d4666 authored 作者: Philippe  Hamel's avatar Philippe Hamel

Det Op #2

fixed infer_shape and added dome tests
上级 e0fbb48b
...@@ -569,7 +569,7 @@ class Det(Op): ...@@ -569,7 +569,7 @@ class Det(Op):
x, = inputs x, = inputs
return [gz * self(x) * matrix_inverse(x).T] return [gz * self(x) * matrix_inverse(x).T]
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
return [(1, )] return [()]
def __str__(self): def __str__(self):
return "Det" return "Det"
det = Det() det = Det()
......
...@@ -166,6 +166,17 @@ def test_rop_lop(): ...@@ -166,6 +166,17 @@ def test_rop_lop():
assert _allclose(v1, v2), ('LOP mismatch: %s %s' % (v1, v2)) assert _allclose(v1, v2), ('LOP mismatch: %s %s' % (v1, v2))
def test_det():
# If scipy is not available, this test will fail, thus we skip it.
if not use_scipy:
raise SkipTest('Scipy is not available')
rng = numpy.random.RandomState(utt.fetch_seed())
r = rng.randn(5,5)
x = tensor.matrix()
f = theano.function([x],det(x))
assert scipy.linalg.det(r) == f(r)
def test_det_grad(): def test_det_grad():
# If scipy is not available, this test will fail, thus we skip it. # If scipy is not available, this test will fail, thus we skip it.
if not use_scipy: if not use_scipy:
...@@ -174,6 +185,18 @@ def test_det_grad(): ...@@ -174,6 +185,18 @@ def test_det_grad():
r = rng.randn(5,5) r = rng.randn(5,5)
tensor.verify_grad(det, [r], rng=numpy.random) tensor.verify_grad(det, [r], rng=numpy.random)
def test_det_shape():
# If scipy is not available, this test will fail, thus we skip it.
if not use_scipy:
raise SkipTest('Scipy is not available')
rng = numpy.random.RandomState(utt.fetch_seed())
r = rng.randn(5,5)
x = tensor.matrix()
f = theano.function([x],det(x))
f_shape = theano.function([x],det(x).shape)
assert numpy.all(f(r).shape == f_shape(r))
def test_extract_diag(): def test_extract_diag():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论