提交 749c474f authored 作者: nouiz's avatar nouiz

Merge pull request #158 from dwf/matrix_inv_take2

Matrix inv take2
# Merge residue fix: the old and new diff lines both survived here, so `det`
# was imported twice; keep a single `det` and both the new `psd` hint and the
# legacy `PSD_hint` alias.
from ops import (cholesky, matrix_inverse, solve,
                 diag, extract_diag, alloc_diag,
                 det, PSD_hint, psd,
                 trace, spectral_radius_bound)
......@@ -164,8 +164,14 @@ class HintsOptimizer(Optimizer):
theano.compile.mode.optdb.register('HintsOpt', HintsOptimizer(), -1, 'fast_run', 'fast_compile')
def PSD_hint(v):
    """Legacy alias for :func:`psd`.

    Kept for backward compatibility; delegates to :func:`psd` so the
    psd+symmetric hint is applied in exactly one place.
    """
    return psd(v)
def psd(v):
    """
    Apply a hint that the variable `v` is positive semi-definite, i.e.
    it is a symmetric matrix and :math:`x^T A x \ge 0` for any vector `x`.
    """
    return Hint(psd=True, symmetric=True)(v)
def is_psd(v):
    """Return True when `v` carries the ``psd`` hint, False otherwise."""
    hint_map = hints(v)
    return hint_map.get('psd', False)
def is_symmetric(v):
......@@ -275,6 +281,12 @@ def local_log_pow(node):
def matrix_dot(*args):
    """ Shorthand for product between several dots

    Given :math:`N` matrices :math:`A_0, A_1, .., A_N`, ``matrix_dot`` will
    generate the matrix product between all in the given order, namely
    :math:`A_0 \cdot A_1 \cdot A_2 \cdot .. \cdot A_N`.
    """
    # Left-fold the operands into a chain of symbolic dot products.
    rval = args[0]
    for a in args[1:]:
        rval = theano.tensor.dot(rval, a)
    # NOTE(review): the diff view truncates the function here; the complete
    # file presumably returns `rval` — confirm against the full source.
......@@ -329,32 +341,88 @@ class Cholesky(Op):
cholesky = Cholesky()
class MatrixInverse(Op):
    """Computes the inverse of a matrix :math:`A`.

    Given a square matrix :math:`A`, ``matrix_inverse`` returns a square
    matrix :math:`A_{inv}` such that the dot product :math:`A \cdot A_{inv}`
    and :math:`A_{inv} \cdot A` equals the identity matrix :math:`I`.

    :note: When possible, the call to this op will be optimized to the call
           of ``solve``.
    """

    def __init__(self):
        pass

    def props(self):
        """Function exposing different properties of each instance of the
        op.

        For the ``MatrixInverse`` op, there are no properties to be exposed.
        """
        return ()

    def __hash__(self):
        # Hash must agree with __eq__: both are based on type + props().
        return hash((type(self), self.props()))

    def __eq__(self, other):
        return (type(self) == type(other) and self.props() == other.props())

    def make_node(self, x):
        x = as_tensor_variable(x)
        # Output has the same type (dtype/broadcast pattern) as the input.
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, outputs):
        # Unpack positionally instead of using py2-only tuple parameters so
        # this method is valid on both Python 2 and Python 3.  The stale
        # pre-merge `except Exception: print ...` handler was removed: it
        # shadowed the LinAlgError clause and made it unreachable.
        (x,) = inputs
        (z,) = outputs
        try:
            z[0] = numpy.linalg.inv(x).astype(x.dtype)
        except numpy.linalg.LinAlgError:
            logger.debug('Failed to invert %s' % str(node.inputs[0]))
            raise

    def grad(self, inputs, g_outputs):
        r"""The gradient function should return

            :math:`V\frac{\partial X^{-1}}{\partial X}`,

        where :math:`V` corresponds to ``g_outputs`` and :math:`X` to
        ``inputs``. Using the matrix cookbook
        ``http://www2.imm.dtu.dk/pubdb/views/publication_details.php?id=3274``,
        one can deduce that the relation corresponds to

            :math:`(X^{-1} \cdot V^{T} \cdot X^{-1})^T`.
        """
        x, = inputs
        xi = self(x)
        gz, = g_outputs
        return [-matrix_dot(xi, gz.T, xi).T]

    def R_op(self, inputs, eval_points):
        r"""The R-operator (Jacobian-times-vector) should return

            :math:`\frac{\partial X^{-1}}{\partial X}V`,

        where :math:`V` corresponds to ``eval_points`` and :math:`X` to
        ``inputs``. Using the matrix cookbook
        ``http://www2.imm.dtu.dk/pubdb/views/publication_details.php?id=3274``,
        one can deduce that the relation corresponds to

            :math:`X^{-1} \cdot V \cdot X^{-1}`.
        """
        x, = inputs
        xi = self(x)
        ev, = eval_points
        # No evaluation point for this input: nothing to propagate.
        if ev is None:
            return [None]
        return [-matrix_dot(xi, ev, xi)]

    def __str__(self):
        return "MatrixInverse"

matrix_inverse = MatrixInverse()
class Solve(Op):
......@@ -459,7 +527,7 @@ alloc_diag = AllocDiag()
def diag(x):
"""Numpy-compatibility method
If `x` is a matrix, return its diagonal.
If `x` is a vector return a matrix with it as its diagonal.
......
......@@ -5,11 +5,10 @@ import numpy
import theano
from theano import tensor, function
from theano.tensor.basic import _allclose
from theano.tensor.tests.test_rop import break_op
from theano.tests import unittest_tools as utt
from theano import config
utt.seed_rng()
try:
import scipy
if V(scipy.__version__) < V('0.7'):
......@@ -39,7 +38,7 @@ if 0:
def test_cholesky():
#TODO: test upper and lower triangular
#todo: unittest randomseed
rng = numpy.random.RandomState(1234)
rng = numpy.random.RandomState(utt.fetch_seed())
r = rng.randn(5,5)
......@@ -61,8 +60,7 @@ if 0:
def test_inverse_correctness():
#todo: unittest randomseed
rng = numpy.random.RandomState(12345)
rng = numpy.random.RandomState(utt.fetch_seed())
r = rng.randn(4,4).astype(theano.config.floatX)
......@@ -79,19 +77,81 @@ def test_inverse_correctness():
assert _allclose(numpy.identity(4), rir), rir
assert _allclose(numpy.identity(4), rri), rri
def test_inverse_singular():
    """Inverting a singular matrix must raise numpy.linalg.LinAlgError."""
    # Rows 2 and 3 are identical, so the matrix is singular by construction.
    singular = numpy.array([[1, 0, 0]] + [[0, 1, 0]] * 2,
                           dtype=theano.config.floatX)
    a = tensor.matrix()
    f = function([a], matrix_inverse(a))
    raised = False
    try:
        f(singular)
    except numpy.linalg.LinAlgError:
        raised = True
    assert raised
def test_inverse_grad():
    """Check the symbolic gradient of matrix_inverse by finite differences.

    Merge residue fix: the pre-merge ``RandomState(1234)`` seeding and a
    duplicate ``verify_grad`` call survived the diff; keep only the
    ``utt.fetch_seed()`` version and run the check once.
    """
    rng = numpy.random.RandomState(utt.fetch_seed())
    r = rng.randn(4, 4)
    tensor.verify_grad(matrix_inverse, [r], rng=numpy.random)
def test_rop_lop():
    """Check Rop/Lop of matrix_inverse against scan-based references,
    and that Rop raises on a non-differentiable op."""
    mx = tensor.matrix('mx')
    mv = tensor.matrix('mv')
    v = tensor.vector('v')
    y = matrix_inverse(mx).sum(axis=0)

    # R-operator: Jacobian of y (w.r.t. mx) times mv.
    yv = tensor.Rop(y, mx, mv)
    rop_f = function([mx, mv], yv)

    # Reference: compute each row of the Jacobian-vector product explicitly
    # with grad inside a scan, one output element at a time.
    sy, _ = theano.scan( lambda i,y,x,v: (tensor.grad(y[i],x)*v).sum(),
                        sequences = tensor.arange(y.shape[0]),
                        non_sequences = [y,mx,mv])
    scan_f = function([mx,mv], sy)

    rng = numpy.random.RandomState(utt.fetch_seed())
    vx = numpy.asarray(rng.randn(4,4), theano.config.floatX)
    vv = numpy.asarray(rng.randn(4,4), theano.config.floatX)

    v1 = rop_f(vx,vv)
    v2 = scan_f(vx,vv)
    assert _allclose(v1, v2), ('ROP mismatch: %s %s' % (v1, v2))

    # break_op has no R_op; building Rop through it must raise ValueError.
    raised = False
    try:
        tmp = tensor.Rop(theano.clone(y,
                         replace={mx:break_op(mx)}), mx, mv)
    except ValueError:
        raised = True
    if not raised:
        raise Exception((
            'Op did not raised an error even though the function'
            ' is not differentiable'))

    # L-operator: v (a vector, matching y's shape) times the Jacobian.
    vv = numpy.asarray(rng.uniform(size=(4,)), theano.config.floatX)
    yv = tensor.Lop(y, mx, v)
    lop_f = function([mx, v], yv)

    # Reference: Lop(y, mx, v) equals grad of (v * y).sum() w.r.t. mx.
    sy = tensor.grad((v*y).sum(), mx)
    scan_f = function([mx, v], sy)

    v1 = lop_f(vx, vv)
    v2 = scan_f(vx, vv)
    assert _allclose(v1, v2), ('LOP mismatch: %s %s' % (v1, v2))
def test_det_grad():
    """Check the symbolic gradient of det by finite differences.

    Merge residue fix: a dead pre-merge ``RandomState(1234)`` line survived
    the diff just before the ``utt.fetch_seed()`` seeding; it is removed.
    """
    # If scipy is not available, this test will fail, thus we skip it.
    if not use_scipy:
        raise SkipTest('Scipy is not available')
    rng = numpy.random.RandomState(utt.fetch_seed())
    r = rng.randn(5, 5)
    tensor.verify_grad(det, [r], rng=numpy.random)
......@@ -152,4 +212,4 @@ def test_trace():
ok = True
assert ok
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论