提交 ca8cb2c2 authored 作者: Thomas George's avatar Thomas George

Removed sandbox/solve.py as a solve Op that uses scipy can already be found in tensor/slinalg.py

上级 d1705c09
from __future__ import absolute_import, print_function, division
import unittest
import sys
import numpy
import scipy.linalg
import theano
from theano import gof, tensor, scalar
from theano.tests import unittest_tools as utt
class Solve(gof.Op):
"""
Find the solution to the linear equation Ax=b.
A is a 2d matrix and b is a 1d or 2d matrix.
It use numpy.solve to find the solution.
"""
# TODO: Add class options to use the performance-enhancing flags
# sym_pos, lower, overwrite_a, overwrite_b
# TODO: Add C code that calls the underlying LAPACK routines
# and keeps a memory workspace from call to call as a non-default Op
# output
__props__ = ()
def make_node(self, A, b):
A_ = tensor.as_tensor_variable(A)
b_ = tensor.as_tensor_variable(b)
if A_.broadcastable != (False, False):
raise TypeError("A must be a matrix", A_.type)
if b_.broadcastable not in ((False,), (True, False), (False, False)):
raise TypeError("b must be a matrix or vector", b_.type)
odtype = scalar.upcast(A_.dtype, b_.dtype)
otype = tensor.TensorType(broadcastable=b_.broadcastable, dtype=odtype)
return gof.Apply(op=self, inputs=[A_, b_], outputs=[otype()])
def perform(self, node, inp, out):
A, b = inp
output, = out
ret = scipy.linalg.solve(A, b)
if ret.dtype != node.outputs[0].dtype:
print("WARNING: Solve.perform() required cast.", file=sys.stderr)
ret = theano._asarray(ret, dtype=node.outputs[0].dtype)
output[0] = ret
solve = Solve()
# TODO: test dtype conversion
# TODO: test that invalid types are rejected by make_node
# TODO: test that each valid type for A and b works correctly
class T_solve(unittest.TestCase):
def setUp(self):
self.rng = numpy.random.RandomState(utt.fetch_seed(666))
def test0(self):
A = self.rng.randn(5, 5)
b = numpy.arange(5, dtype=float)
x = scipy.linalg.solve(A, b)
Ax = numpy.dot(A, x)
are = tensor.numeric_grad.abs_rel_err(Ax, b)
self.assertTrue(numpy.all(are < 1.0e-5), (are, Ax, b))
# print A,b
# print numpy.dot(A,x)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论