提交 1442f517 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Added a bit more documentation on the matrix inverse op.

上级 b87e163f
...@@ -329,32 +329,54 @@ class Cholesky(Op): ...@@ -329,32 +329,54 @@ class Cholesky(Op):
cholesky = Cholesky() cholesky = Cholesky()
class MatrixInverse(Op): class MatrixInverse(Op):
"""Compute a matrix inverse""" """Computes the inverse of a matrix :math:`A`.
Given a square matrix :math:`A`, ``matrix_inverse`` returns a square
matrix :math:`A_{inv}` such that the dot product :math:`A \cdot A_{inv}`
and :math:`A_{inv} \cdot A` equals the identity matrix :math:`I`.
:note: When possible, the call to this op will be optimized to the call
of ``solve``.
"""
def __init__(self): def __init__(self):
pass pass
def props(self): def props(self):
"""Function exposing different properties of each instance of the
op.
For the ``MatrixInverse`` op, there are no properties to be exposed.
"""
return () return ()
def __hash__(self): def __hash__(self):
return hash((type(self), self.props())) return hash((type(self), self.props()))
def __eq__(self, other): def __eq__(self, other):
return (type(self)==type(other) and self.props() == other.props()) return (type(self)==type(other) and self.props() == other.props())
def make_node(self, x): def make_node(self, x):
x = as_tensor_variable(x) x = as_tensor_variable(x)
return Apply(self, [x], [x.type()]) return Apply(self, [x], [x.type()])
def perform(self, node, (x,), (z, )): def perform(self, node, (x,), (z, )):
try: try:
z[0] = numpy.linalg.inv(x).astype(x.dtype) z[0] = numpy.linalg.inv(x).astype(x.dtype)
except Exception: except Exception:
print 'Failed to invert', node.inputs[0] print 'Failed to invert', node.inputs[0]
raise raise
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
x, = inputs x, = inputs
xi = self(x) xi = self(x)
gz, = g_outputs gz, = g_outputs
#TT.dot(gz.T,xi) #TT.dot(gz.T,xi)
return [-matrix_dot(xi,gz.T,xi).T] return [-matrix_dot(xi,gz.T,xi).T]
def __str__(self): def __str__(self):
return "MatrixInverse" return "MatrixInverse"
matrix_inverse = MatrixInverse() matrix_inverse = MatrixInverse()
class Solve(Op): class Solve(Op):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论