提交 08f7a7d3 authored 作者: Frederic Bastien's avatar Frederic Bastien

implement a new op that wrap around numpy.solve to solve linear equation system.

上级 dae9bef7
......@@ -5,6 +5,7 @@ from theano.scalar import *
import theano.tensor as T
import theano
from numpy import *
class Prepend_scalar_constant_to_each_row(Op):
def __init__(self, val = 0):
if isinstance(val, float):
......@@ -65,6 +66,30 @@ class Prepend_scalar_to_each_row(Op):
def grad(self, (val, mat), (goutput,)):
return goutput[:,0], goutput[:,1:]
class solve(Op):
"""
Find the solution to the linear equation Ax=b,
where A is a 2d matrix and b is a 1d or 2d matrix.
It use numpy.solve to find the solution.
"""
def make_node(self, A, b):
if not isinstance(A, Result) or not A.type==T.matrix().type:
raise TypeError("We expected that A had a matrix type")
if not isinstance(B, Result) or not B.type==T.matrix().type:
raise TypeError("We expected that B had a matrix type")
node = Apply(op=self, inputs=[A, B], outputs=[T.matrix()])
return node
def perform(self, node, (A, B), (output, )):
ret=numpy.solve(A,B)
output[0]=ret
def grad(self, (theta, A, B), (gtheta,)):
raise NotImplementedError()
if __name__ == '__main__':
x=T.matrix('x')
......@@ -80,3 +105,9 @@ if __name__ == '__main__':
mat=numpy.ones((3,5),dtype="float32")
print f(mat)
A=numpy.random.randn(5,5)
b=numpy.array(range(5),dtype=float)
x=linalg.solve(A,b)
print A,b
print numpy.dot(A,x)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论