提交 7fd2e38f authored 作者: Razvan Pascanu's avatar Razvan Pascanu

rop for matrix inverse plus test

上级 ab1c7862
......@@ -392,6 +392,28 @@ class MatrixInverse(Op):
#TT.dot(gz.T,xi)
return [-matrix_dot(xi,gz.T,xi).T]
def R_op(self, inputs, eval_points):
"""The gradient function should return:
:math:`\\frac{\partial X^{-1}}{\partial X}V`
where :math:`V` corresponds to ``g_outputs`` and :math:`X` to
``inputs``. Using the matrix cookbook
``http://www2.imm.dtu.dk/pubdb/views/publication_details.php?id=3274``,
once can deduce that the relation corresponds to :
:math:`X^{-1} \cdot V \cdot X^{-1}`
"""
x, = inputs
xi = self(x)
ev, = eval_points
if ev is None:
return [None]
#TT.dot(gz.T,xi)
return [-matrix_dot(xi,ev,xi)]
def __str__(self):
return "MatrixInverse"
......
......@@ -5,6 +5,7 @@ import numpy
import theano
from theano import tensor, function
from theano.tensor.basic import _allclose
from theano.tensor.tests.test_rop import break_op
from theano.tests import unittest_tools as utt
from theano import config
......@@ -81,11 +82,57 @@ def test_inverse_correctness():
def test_inverse_grad():
rng = numpy.random.RandomState(1234)
rng = numpy.random.RandomState(utt.fetch_seed())
r = rng.randn(4,4)
tensor.verify_grad(matrix_inverse, [r], rng=numpy.random)
def test_rop_lop():
mx = tensor.matrix('mx')
mv = tensor.matrix('mv')
v = tensor.vector('v')
y = matrix_inverse(mx).sum(axis=0)
yv = tensor.Rop(y, mx, mv)
rop_f = function([mx, mv], yv)
sy, _ = theano.scan( lambda i,y,x,v: (tensor.grad(y[i],x)*v).sum(),
sequences = tensor.arange(y.shape[0]),
non_sequences = [y,mx,mv])
scan_f = function([mx,mv], sy)
rng = numpy.random.RandomState(utt.fetch_seed())
vx = numpy.asarray(rng.randn(4,4), theano.config.floatX)
vv = numpy.asarray(rng.randn(4,4), theano.config.floatX)
v1 = rop_f(vx,vv)
v2 = scan_f(vx,vv)
assert numpy.allclose(v1,v2), ('ROP mismatch: %s %s' % (v1, v2))
raised = False
try:
tmp = tensor.Rop(theano.clone(y,
replace={mx:break_op(mx)}), mx, mv)
except ValueError:
raised = True
if not raised:
raise Exception((
'Op did not raised an error even though the function'
' is not differentiable'))
vv = numpy.asarray(rng.uniform(size=(4,)), theano.config.floatX)
yv = tensor.Lop(y, mx, v)
lop_f = function([mx, v], yv)
sy = tensor.grad((v*y).sum(), mx)
scan_f = function([mx, v], sy)
v1 = lop_f(vx,vv)
v2 = scan_f(vx,vv)
assert numpy.allclose(v1,v2), ('LOP mismatch: %s %s' % (v1, v2))
def test_det_grad():
# If scipy is not available, this test will fail, thus we skip it.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论