Commit 6080bef5 authored by James Bergstra

did a little work on sandbox.solve to clean it up

Parent 3b831aab
import sys
import unittest

import numpy, scipy.linalg

from theano import gof, tensor, scalar
class Solve(gof.Op):
    """
    Find the solution x to the linear equation Ax = b.

    A must be a matrix; b may be a vector or a matrix.  The solution is
    computed with scipy.linalg.solve.
    """
    # NOTE(review): part of the original docstring was lost in the diff view
    # this file was recovered from; the text above keeps only what was visible.

    #TODO: Add class options to use the performance-enhancing flags
    # sym_pos, lower, overwrite_a, overwrite_b

    #TODO: Add C code that calls the underlying LAPACK routines
    # and keeps a memory workspace from call to call as a non-default Op output

    def __eq__(self, other):
        # The Op carries no parameters, so any two Solve instances are
        # interchangeable: compare (and hash) by type alone.
        return type(self) == type(other)

    def __hash__(self):
        return hash(type(self))

    def make_node(self, A, b):
        """Build the Apply node computing solve(A, b).

        A must be a (non-broadcastable) matrix; b may be a vector or a
        matrix.  The output has b's broadcastable pattern and the upcast
        of the two input dtypes.

        Raises TypeError when A or b has an unsupported pattern.
        """
        A_ = tensor.as_tensor_variable(A)
        b_ = tensor.as_tensor_variable(b)
        if A_.broadcastable != (False, False):
            raise TypeError("A must be a matrix", A_.type)
        if b_.broadcastable not in ((False,), (True, False), (False, False)):
            raise TypeError("b must be a matrix or vector", b_.type)
        odtype = scalar.upcast(A_.dtype, b_.dtype)
        otype = tensor.TensorType(broadcastable=b_.broadcastable, dtype=odtype)
        # BUG FIX: the original passed [A, B] -- `B` was an undefined name
        # (NameError at graph-build time).  Use the converted variables.
        return gof.Apply(op=self, inputs=[A_, b_], outputs=[otype()])

    def perform(self, node, inputs, output_storage):
        """Compute the solution numerically with scipy.linalg.solve."""
        A, b = inputs
        (output,) = output_storage
        ret = scipy.linalg.solve(A, b)
        if ret.dtype != node.outputs[0].dtype:
            # scipy may return a dtype other than the one make_node promised;
            # warn and cast so the graph's dtype contract holds.
            sys.stderr.write("WARNING: Solve.perform() required cast.\n")
            ret = numpy.asarray(ret, dtype=node.outputs[0].dtype)
        output[0] = ret
# Module-level singleton: the canonical way for users to apply the Op.
solve = Solve()
## TODO: test dtype conversion
## TODO: test that invalid types are rejected by make_node
## TODO: test that each valid type for A and b works correctly
from theano.tests import unittest_tools as utt
class T_solve(unittest.TestCase): class T_solve(unittest.TestCase):
def setUp(self): def setUp(self):
self.rng = numpy.random.RandomState(utt.fetch_seed(666)) self.rng = numpy.random.RandomState(utt.fetch_seed(666))
...@@ -35,7 +54,7 @@ class T_solve(unittest.TestCase): ...@@ -35,7 +54,7 @@ class T_solve(unittest.TestCase):
def test0(self): def test0(self):
A=self.rng.randn(5,5) A=self.rng.randn(5,5)
b=numpy.array(range(5),dtype=float) b=numpy.array(range(5),dtype=float)
x=numpy.linalg.solve(A,b) x=scipy.linalg.solve(A,b)
Ax = numpy.dot(A,x) Ax = numpy.dot(A,x)
are = tensor.numeric_grad.abs_rel_err(Ax, b) are = tensor.numeric_grad.abs_rel_err(Ax, b)
self.failUnless(numpy.all(are < 1.0e-5), (are, Ax, b)) self.failUnless(numpy.all(are < 1.0e-5), (are, Ax, b))
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment