提交 6080bef5 authored 作者: James Bergstra's avatar James Bergstra

did a little work on sandbox.solve to clean it up

上级 3b831aab
import numpy
from theano import gof, tensor
import numpy, scipy.linalg
from theano import gof, tensor, scalar
import unittest
class Solve(gof.Op):
"""
Find the solution to the linear equation Ax=b,
......@@ -9,25 +10,43 @@ class Solve(gof.Op):
It use numpy.solve to find the solution.
"""
def make_node(self, A, b):
if not isinstance(A, gof.Variable) or not A.type==tensor.matrix().type:
raise TypeError("We expected that A had a matrix type")
if not isinstance(B, gof.Variable) or not B.type==tensor.matrix().type:
raise TypeError("We expected that B had a matrix type")
#TODO: Add class options to use the performance-enhancing flags
# sym_pos, lower, overwrite_a, overwrite_b
node = gof.Apply(op=self, inputs=[A, B], outputs=[tensor.matrix()])
return node
#TODO: Add C code that calls the underlying LAPACK routines
# and keeps a memory workspace from call to call as a non-default Op output
def perform(self, node, (A, B), (output, )):
ret=numpy.solve(A,B)
output[0]=ret
def __eq__(self, other):
return type(self) == type(other)
def grad(self, (theta, A, B), (gtheta,)):
raise NotImplementedError()
def __hash__(self):
return hash(type(self))
def make_node(self, A, b):
A_ = tensor.as_tensor_variable(A)
b_ = tensor.as_tensor_variable(b)
if A_.broadcastable != (False, False):
raise TypeError("A must be a matrix", A_.type)
if b_.broadcastable not in ((False,), (True, False), (False, False)):
raise TypeError("b must be a matrix or vector", b_.type)
odtype = scalar.upcast(A_.dtype, b_.dtype)
otype = tensor.TensorType(broadcastable=b_.broadcastable, dtype=odtype)
return gof.Apply(op=self, inputs=[A, B], outputs=[otype()])
def perform(self, node, (A, b), (output, )):
ret=scipy.linalg.solve(A,b)
if ret.dtype != node.outputs[0].dtype:
print >> sys.stderr, "WARNING: Solve.perform() required cast."
ret = theano._asarray(ret, dtype=node.outputs[0].dtype)
output[0]=ret
solve = Solve()
## TODO: test dtype conversion
## TODO: test that invalid types are rejected by make_node
## TODO: test that each valid type for A and b works correctly
from theano.tests import unittest_tools as utt
class T_solve(unittest.TestCase):
def setUp(self):
self.rng = numpy.random.RandomState(utt.fetch_seed(666))
......@@ -35,7 +54,7 @@ class T_solve(unittest.TestCase):
def test0(self):
A=self.rng.randn(5,5)
b=numpy.array(range(5),dtype=float)
x=numpy.linalg.solve(A,b)
x=scipy.linalg.solve(A,b)
Ax = numpy.dot(A,x)
are = tensor.numeric_grad.abs_rel_err(Ax, b)
self.failUnless(numpy.all(are < 1.0e-5), (are, Ax, b))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论