提交 bdee7c76 authored 作者: nouiz's avatar nouiz

Merge pull request #201 from hamelphi/det_test_and_doc/master

det OP
...@@ -548,24 +548,29 @@ def diag(x): ...@@ -548,24 +548,29 @@ def diag(x):
raise TypeError('diag requires vector or matrix argument', x) raise TypeError('diag requires vector or matrix argument', x)
class Det(Op): class Det(Op):
"""matrix determinant """Matrix determinant
Input should be a square matrix
TODO: move this op to another file that request scipy.
""" """
def make_node(self, x): def make_node(self, x):
x = as_tensor_variable(x) x = as_tensor_variable(x)
o = theano.tensor.scalar(dtype=x.dtype) o = theano.tensor.scalar(dtype=x.dtype)
return Apply(self, [x], [o]) return Apply(self, [x], [o])
def perform(self, node, (x,), (z, )): def perform(self, node, (x,), (z, )):
try: try:
z[0] = numpy.asarray(scipy.linalg.det(x), dtype=x.dtype) z[0] = numpy.asarray(numpy.linalg.det(x), dtype=x.dtype)
except Exception: except Exception:
print 'Failed to compute determinant', x print 'Failed to compute determinant', x
raise raise
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
gz, = g_outputs gz, = g_outputs
x, = inputs x, = inputs
return [gz * self(x) * matrix_inverse(x).T] return [gz * self(x) * matrix_inverse(x).T]
def infer_shape(self, node, shapes):
return [()]
def __str__(self): def __str__(self):
return "Det" return "Det"
det = Det() det = Det()
......
...@@ -9,14 +9,6 @@ from theano.tensor.tests.test_rop import break_op ...@@ -9,14 +9,6 @@ from theano.tensor.tests.test_rop import break_op
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano import config from theano import config
try:
import scipy
if V(scipy.__version__) < V('0.7'):
raise ImportError()
use_scipy = True
except ImportError:
use_scipy = False
# The one in comment are not tested... # The one in comment are not tested...
from theano.sandbox.linalg.ops import (cholesky, from theano.sandbox.linalg.ops import (cholesky,
matrix_inverse, matrix_inverse,
...@@ -166,15 +158,28 @@ def test_rop_lop(): ...@@ -166,15 +158,28 @@ def test_rop_lop():
assert _allclose(v1, v2), ('LOP mismatch: %s %s' % (v1, v2)) assert _allclose(v1, v2), ('LOP mismatch: %s %s' % (v1, v2))
def test_det():
rng = numpy.random.RandomState(utt.fetch_seed())
r = rng.randn(5,5)
x = tensor.matrix()
f = theano.function([x],det(x))
assert numpy.linalg.det(r) == f(r)
def test_det_grad(): def test_det_grad():
# If scipy is not available, this test will fail, thus we skip it.
if not use_scipy:
raise SkipTest('Scipy is not available')
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
r = rng.randn(5,5) r = rng.randn(5,5)
tensor.verify_grad(det, [r], rng=numpy.random) tensor.verify_grad(det, [r], rng=numpy.random)
def test_det_shape():
rng = numpy.random.RandomState(utt.fetch_seed())
r = rng.randn(5,5)
x = tensor.matrix()
f = theano.function([x],det(x))
f_shape = theano.function([x],det(x).shape)
assert numpy.all(f(r).shape == f_shape(r))
def test_extract_diag(): def test_extract_diag():
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论