提交 3a3f9342 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5210 from tfjgeorge/solve

Solve
...@@ -20,3 +20,8 @@ API ...@@ -20,3 +20,8 @@ API
.. automodule:: theano.tensor.slinalg .. automodule:: theano.tensor.slinalg
:members: :members:
:exclude-members: solve, solve_lower_triangular, solve_upper_triangular
.. autofunction:: solve(a, b)
.. autofunction:: solve_lower_triangular(a, b)
.. autofunction:: solve_upper_triangular(a, b)
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import warnings
from theano.tensor.slinalg import solve # noqa
import unittest message = ("The module theano.sandbox.solve will soon be deprecated.\n"
import sys "Please use tensor.slinalg.solve instead.")
import numpy warnings.warn(message)
import scipy.linalg
import theano
from theano import gof, tensor, scalar
from theano.tests import unittest_tools as utt
class Solve(gof.Op):
    """
    Op computing the solution ``x`` of the linear system ``A x = b``.

    ``A`` has to be a 2d matrix and ``b`` either a vector or a matrix.
    The numerical work is delegated to ``scipy.linalg.solve``.

    """

    # TODO: Add class options to use the performance-enhancing flags
    #     sym_pos, lower, overwrite_a, overwrite_b

    # TODO: Add C code that calls the underlying LAPACK routines
    #     and keeps a memory workspace from call to call as a non-default Op
    #     output

    __props__ = ()

    def make_node(self, A, b):
        """
        Build the Apply node for ``solve(A, b)``.

        Both inputs are converted to tensor variables. The single output
        reuses the broadcastable pattern of ``b`` and takes the upcast of
        the two input dtypes.

        Raises
        ------
        TypeError
            If the broadcastable pattern of ``A`` is not
            ``(False, False)``, or if the pattern of ``b`` is not one of
            ``(False,)``, ``(True, False)`` or ``(False, False)``.

        """
        lhs = tensor.as_tensor_variable(A)
        rhs = tensor.as_tensor_variable(b)

        if lhs.broadcastable != (False, False):
            raise TypeError("A must be a matrix", lhs.type)

        accepted_patterns = ((False,), (True, False), (False, False))
        if rhs.broadcastable not in accepted_patterns:
            raise TypeError("b must be a matrix or vector", rhs.type)

        out_dtype = scalar.upcast(lhs.dtype, rhs.dtype)
        out_type = tensor.TensorType(broadcastable=rhs.broadcastable,
                                     dtype=out_dtype)
        return gof.Apply(op=self, inputs=[lhs, rhs], outputs=[out_type()])

    def perform(self, node, inp, out):
        """
        Solve the system numerically and store the result in ``out[0][0]``.

        When the dtype returned by ``scipy.linalg.solve`` differs from the
        dtype declared on the node's output, a warning is printed to
        stderr and the result is cast to the declared dtype.

        """
        A, b = inp
        (storage,) = out

        solution = scipy.linalg.solve(A, b)
        declared_dtype = node.outputs[0].dtype
        if solution.dtype != declared_dtype:
            print("WARNING: Solve.perform() required cast.", file=sys.stderr)
            solution = theano._asarray(solution, dtype=node.outputs[0].dtype)

        storage[0] = solution


solve = Solve()
# TODO: test dtype conversion
# TODO: test that invalid types are rejected by make_node
# TODO: test that each valid type for A and b works correctly
class T_solve(unittest.TestCase):
    """Sanity check of ``scipy.linalg.solve`` on a random 5x5 system."""

    def setUp(self):
        # Seeded generator so the random system is reproducible.
        seed = utt.fetch_seed(666)
        self.rng = numpy.random.RandomState(seed)

    def test0(self):
        """Solve ``A x = b`` and check that ``A x`` reproduces ``b``."""
        lhs = self.rng.randn(5, 5)
        rhs = numpy.arange(5, dtype=float)

        solution = scipy.linalg.solve(lhs, rhs)
        reconstructed = numpy.dot(lhs, solution)

        err = tensor.numeric_grad.abs_rel_err(reconstructed, rhs)
        within_tolerance = numpy.all(err < 1.0e-5)
        self.assertTrue(within_tolerance, (err, reconstructed, rhs))
...@@ -202,10 +202,16 @@ class Solve(Op): ...@@ -202,10 +202,16 @@ class Solve(Op):
b = as_tensor_variable(b) b = as_tensor_variable(b)
assert A.ndim == 2 assert A.ndim == 2
assert b.ndim in [1, 2] assert b.ndim in [1, 2]
otype = tensor.tensor(
# infer dtype by solving the most simple
# case with (1, 1) matrices
o_dtype = scipy.linalg.solve(
numpy.eye(1).astype(A.dtype),
numpy.eye(1).astype(b.dtype)).dtype
x = tensor.tensor(
broadcastable=b.broadcastable, broadcastable=b.broadcastable,
dtype=(A * b).dtype) dtype=o_dtype)
return Apply(self, [A, b], [otype]) return Apply(self, [A, b], [x])
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
A, b = inputs A, b = inputs
...@@ -263,10 +269,31 @@ class Solve(Op): ...@@ -263,10 +269,31 @@ class Solve(Op):
A_bar = tensor.triu(A_bar) A_bar = tensor.triu(A_bar)
return [A_bar, b_bar] return [A_bar, b_bar]
solve = Solve() # general solve solve = Solve()
"""
Solves the equation ``a x = b`` for x, where ``a`` is a matrix and
``b`` can be either a vector or a matrix.
Note
Parameters
----------
a : (M, M) symbolic matrix
A square matrix
b : (M,) or (M, N) symbolic vector or matrix
Right hand side matrix in ``a x = b``
Returns
-------
x : (M, ) or (M, N) symbolic vector or matrix
x will have the same shape as b
"""
# lower and upper triangular solves # lower and upper triangular solves
solve_lower_triangular = Solve(A_structure='lower_triangular', lower=True) solve_lower_triangular = Solve(A_structure='lower_triangular', lower=True)
"""Optimized implementation of :func:`theano.tensor.slinalg.solve` when A is lower triangular."""
solve_upper_triangular = Solve(A_structure='upper_triangular', lower=False) solve_upper_triangular = Solve(A_structure='upper_triangular', lower=False)
"""Optimized implementation of :func:`theano.tensor.slinalg.solve` when A is upper triangular."""
# TODO: Optimizations to replace multiplication by matrix inverse # TODO: Optimizations to replace multiplication by matrix inverse
# with solve() Op (still unwritten) # with solve() Op (still unwritten)
......
...@@ -7,6 +7,8 @@ from numpy.testing import assert_array_almost_equal ...@@ -7,6 +7,8 @@ from numpy.testing import assert_array_almost_equal
from numpy.testing import dec, assert_array_equal, assert_allclose from numpy.testing import dec, assert_array_equal, assert_allclose
from numpy import inf from numpy import inf
import itertools
import theano import theano
from theano import tensor, function from theano import tensor, function
from theano.tensor.basic import _allclose from theano.tensor.basic import _allclose
...@@ -229,6 +231,27 @@ class test_Solve(utt.InferShapeTester): ...@@ -229,6 +231,27 @@ class test_Solve(utt.InferShapeTester):
assert numpy.allclose(scipy.linalg.solve_triangular(U_val, b_val, lower=False), assert numpy.allclose(scipy.linalg.solve_triangular(U_val, b_val, lower=False),
upper_solve_func(U_val, b_val)) upper_solve_func(U_val, b_val))
def test_solve_dtype(self):
    """
    Check that the symbolic output dtype of ``solve`` matches the dtype
    of the computed result for every pair of input dtypes tried.

    """
    if not imported_scipy:
        raise SkipTest("Scipy needed for the Solve op.")

    unsigned_ints = ['uint8', 'uint16', 'uint32', 'uint64']
    signed_ints = ['int8', 'int16', 'int32', 'int64']
    floats = ['float16', 'float32', 'float64']
    dtypes = unsigned_ints + signed_ints + floats

    A_val = numpy.eye(2)
    b_val = numpy.ones((2, 1))

    # try all dtype combinations
    for A_dtype in dtypes:
        for b_dtype in dtypes:
            A = tensor.matrix(dtype=A_dtype)
            b = tensor.matrix(dtype=b_dtype)
            x = solve(A, b)
            fn = function([A, b], x)

            A_arg = A_val.astype(A_dtype)
            b_arg = b_val.astype(b_dtype)
            x_result = fn(A_arg, b_arg)

            assert x.dtype == x_result.dtype
def verify_solve_grad(self, m, n, A_structure, lower, rng): def verify_solve_grad(self, m, n, A_structure, lower, rng):
# ensure diagonal elements of A relatively large to avoid numerical # ensure diagonal elements of A relatively large to avoid numerical
# precision issues # precision issues
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论