提交 3a3f9342 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5210 from tfjgeorge/solve

Solve
......@@ -20,3 +20,8 @@ API
.. automodule:: theano.tensor.slinalg
:members:
:exclude-members: solve, solve_lower_triangular, solve_upper_triangular
.. autofunction:: solve(a, b)
.. autofunction:: solve_lower_triangular(a, b)
.. autofunction:: solve_upper_triangular(a, b)
from __future__ import absolute_import, print_function, division
import warnings
from theano.tensor.slinalg import solve # noqa
import unittest
import sys
# Text of the notice passed to warnings.warn() at module level.
message = ("The module theano.sandbox.solve will soon be deprecated.\n"
"Please use tensor.slinalg.solve instead.")
import numpy
import scipy.linalg
import theano
from theano import gof, tensor, scalar
from theano.tests import unittest_tools as utt
class Solve(gof.Op):
    """
    Op that solves the linear system ``A x = b`` for ``x``.

    ``A`` has to be a 2d matrix; ``b`` may be a vector or a 2d matrix.
    The numerical work is delegated to ``scipy.linalg.solve``.

    """

    # TODO: Add class options to use the performance-enhancing flags
    #     sym_pos, lower, overwrite_a, overwrite_b

    # TODO: Add C code that calls the underlying LAPACK routines
    #     and keeps a memory workspace from call to call as a non-default Op
    #     output

    __props__ = ()

    def make_node(self, A, b):
        """
        Build the Apply node for ``solve(A, b)``.

        Parameters
        ----------
        A
            Left-hand side; converted with ``tensor.as_tensor_variable``.
            Its broadcastable pattern must be ``(False, False)``.
        b
            Right-hand side; converted with ``tensor.as_tensor_variable``.
            Its broadcastable pattern must be ``(False,)``,
            ``(True, False)`` or ``(False, False)``.

        Returns
        -------
        gof.Apply
            Node with a single output whose broadcastable pattern is the
            one of ``b`` and whose dtype is the upcast of both input dtypes.

        Raises
        ------
        TypeError
            If ``A`` or ``b`` does not have an accepted broadcastable
            pattern.

        """
        lhs = tensor.as_tensor_variable(A)
        rhs = tensor.as_tensor_variable(b)

        if lhs.broadcastable != (False, False):
            raise TypeError("A must be a matrix", lhs.type)

        accepted_patterns = ((False,), (True, False), (False, False))
        if rhs.broadcastable not in accepted_patterns:
            raise TypeError("b must be a matrix or vector", rhs.type)

        out_dtype = scalar.upcast(lhs.dtype, rhs.dtype)
        out_type = tensor.TensorType(broadcastable=rhs.broadcastable,
                                     dtype=out_dtype)
        return gof.Apply(op=self, inputs=[lhs, rhs], outputs=[out_type()])

    def perform(self, node, inp, out):
        """
        Compute the solution and store it in the output storage cell.

        ``inp`` holds the values of ``A`` and ``b``; ``out`` holds the single
        output storage cell, whose first element receives the result.
        """
        A, b = inp
        (storage,) = out

        solution = scipy.linalg.solve(A, b)
        declared_dtype = node.outputs[0].dtype
        if solution.dtype != declared_dtype:
            # scipy returned a dtype other than the one declared in
            # make_node: report it on stderr and cast to the declared dtype.
            print("WARNING: Solve.perform() required cast.", file=sys.stderr)
            solution = theano._asarray(solution, dtype=declared_dtype)

        storage[0] = solution
# Module-level instance of the Solve Op.
solve = Solve()
# TODO: test dtype conversion
# TODO: test that invalid types are rejected by make_node
# TODO: test that each valid type for A and b works correctly
class T_solve(unittest.TestCase):
    """
    Sanity check for solving a random 5x5 linear system.

    NOTE(review): the test calls ``scipy.linalg.solve`` directly and never
    goes through the ``Solve`` op defined in this module.
    """

    def setUp(self):
        """Create a random generator seeded through ``utt.fetch_seed(666)``."""
        seed = utt.fetch_seed(666)
        self.rng = numpy.random.RandomState(seed)

    def test0(self):
        """Solve ``A x = b`` and check that ``A . x`` is close to ``b``."""
        matrix = self.rng.randn(5, 5)
        rhs = numpy.arange(5, dtype=float)

        solution = scipy.linalg.solve(matrix, rhs)
        product = numpy.dot(matrix, solution)

        # Element-wise error between the reconstructed and the expected
        # right-hand side.
        err = tensor.numeric_grad.abs_rel_err(product, rhs)
        self.assertTrue(numpy.all(err < 1.0e-5), (err, product, rhs))
# Issue the warning at module level, using the text stored in `message`.
warnings.warn(message)
......@@ -202,10 +202,16 @@ class Solve(Op):
b = as_tensor_variable(b)
assert A.ndim == 2
assert b.ndim in [1, 2]
otype = tensor.tensor(
# infer dtype by solving the most simple
# case with (1, 1) matrices
o_dtype = scipy.linalg.solve(
numpy.eye(1).astype(A.dtype),
numpy.eye(1).astype(b.dtype)).dtype
x = tensor.tensor(
broadcastable=b.broadcastable,
dtype=(A * b).dtype)
return Apply(self, [A, b], [otype])
dtype=o_dtype)
return Apply(self, [A, b], [x])
def perform(self, node, inputs, output_storage):
A, b = inputs
......@@ -263,10 +269,31 @@ class Solve(Op):
A_bar = tensor.triu(A_bar)
return [A_bar, b_bar]
# General solve: a single instance of the Op (the back-to-back duplicate
# assignment was redundant and has been merged into this one).
solve = Solve()  # general solve
"""
Solves the equation ``a x = b`` for x, where ``a`` is a matrix and
``b`` can be either a vector or a matrix.

Parameters
----------
a : (M, M) symbolic matrix
    A square matrix
b : (M,) or (M, N) symbolic vector or matrix
    Right hand side matrix in ``a x = b``

Returns
-------
x : (M, ) or (M, N) symbolic vector or matrix
    x will have the same shape as b
"""

# lower and upper triangular solves
solve_lower_triangular = Solve(A_structure='lower_triangular', lower=True)
"""Optimized implementation of :func:`theano.tensor.slinalg.solve` when A is lower triangular."""

solve_upper_triangular = Solve(A_structure='upper_triangular', lower=False)
"""Optimized implementation of :func:`theano.tensor.slinalg.solve` when A is upper triangular."""
# TODO: Optimizations to replace multiplication by matrix inverse
# with solve() Op (still unwritten)
......
......@@ -7,6 +7,8 @@ from numpy.testing import assert_array_almost_equal
from numpy.testing import dec, assert_array_equal, assert_allclose
from numpy import inf
import itertools
import theano
from theano import tensor, function
from theano.tensor.basic import _allclose
......@@ -229,6 +231,27 @@ class test_Solve(utt.InferShapeTester):
assert numpy.allclose(scipy.linalg.solve_triangular(U_val, b_val, lower=False),
upper_solve_func(U_val, b_val))
def test_solve_dtype(self):
    """
    Check that the symbolic output dtype of ``solve`` equals the dtype of
    the value actually computed, for every pair of input dtypes.
    """
    if not imported_scipy:
        raise SkipTest("Scipy needed for the Solve op.")

    unsigned_ints = ['uint8', 'uint16', 'uint32', 'uint64']
    signed_ints = ['int8', 'int16', 'int32', 'int64']
    floats = ['float16', 'float32', 'float64']
    all_dtypes = unsigned_ints + signed_ints + floats

    # Identity system: values are exactly representable in every dtype.
    lhs_val = numpy.eye(2)
    rhs_val = numpy.ones((2, 1))

    # try all dtype combinations
    for lhs_dtype in all_dtypes:
        for rhs_dtype in all_dtypes:
            lhs = tensor.matrix(dtype=lhs_dtype)
            rhs = tensor.matrix(dtype=rhs_dtype)
            sym_out = solve(lhs, rhs)
            compiled = function([lhs, rhs], sym_out)

            computed = compiled(lhs_val.astype(lhs_dtype),
                                rhs_val.astype(rhs_dtype))

            assert sym_out.dtype == computed.dtype
def verify_solve_grad(self, m, n, A_structure, lower, rng):
# ensure diagonal elements of A relatively large to avoid numerical
# precision issues
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论