提交 8c270d71 authored 作者: Philippe  Hamel's avatar Philippe Hamel

Det Op 3

removed dependency to scipy and used numpy instead fixed some convention discrepencies
上级 d15d4666
......@@ -550,26 +550,27 @@ def diag(x):
class Det(Op):
"""Matrix determinant
Input should be a square matrix
:note: Requires scipy
TODO: move this op to another file that request scipy.
"""
def make_node(self, x):
x = as_tensor_variable(x)
o = theano.tensor.scalar(dtype=x.dtype)
return Apply(self, [x], [o])
def perform(self, node, (x,), (z, )):
try:
z[0] = numpy.asarray(scipy.linalg.det(x), dtype=x.dtype)
z[0] = numpy.asarray(numpy.linalg.det(x), dtype=x.dtype)
except Exception:
print 'Failed to compute determinant', x
raise
def grad(self, inputs, g_outputs):
gz, = g_outputs
x, = inputs
return [gz * self(x) * matrix_inverse(x).T]
def infer_shape(self, node, shapes):
return [()]
def __str__(self):
return "Det"
det = Det()
......
......@@ -9,14 +9,6 @@ from theano.tensor.tests.test_rop import break_op
from theano.tests import unittest_tools as utt
from theano import config
try:
import scipy
if V(scipy.__version__) < V('0.7'):
raise ImportError()
use_scipy = True
except ImportError:
use_scipy = False
# The one in comment are not tested...
from theano.sandbox.linalg.ops import (cholesky,
matrix_inverse,
......@@ -167,37 +159,27 @@ def test_rop_lop():
def test_det():
# If scipy is not available, this test will fail, thus we skip it.
if not use_scipy:
raise SkipTest('Scipy is not available')
rng = numpy.random.RandomState(utt.fetch_seed())
r = rng.randn(5,5)
x = tensor.matrix()
f = theano.function([x],det(x))
assert scipy.linalg.det(r) == f(r)
assert numpy.linalg.det(r) == f(r)
def test_det_grad():
# If scipy is not available, this test will fail, thus we skip it.
if not use_scipy:
raise SkipTest('Scipy is not available')
rng = numpy.random.RandomState(utt.fetch_seed())
r = rng.randn(5,5)
tensor.verify_grad(det, [r], rng=numpy.random)
def test_det_shape():
# If scipy is not available, this test will fail, thus we skip it.
if not use_scipy:
raise SkipTest('Scipy is not available')
rng = numpy.random.RandomState(utt.fetch_seed())
r = rng.randn(5,5)
x = tensor.matrix()
f = theano.function([x],det(x))
f_shape = theano.function([x],det(x).shape)
assert numpy.all(f(r).shape == f_shape(r))
assert numpy.all(f(r).shape == f_shape(r))
def test_extract_diag():
rng = numpy.random.RandomState(utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论