提交 b6a15d9d authored 作者: carriepl's avatar carriepl

Merge pull request #2237 from aalmah/jaberg-may21

fixing #1868 Misc accumulated fixes
...@@ -256,6 +256,18 @@ def is_positive(v): ...@@ -256,6 +256,18 @@ def is_positive(v):
return False return False
@register_canonicalize
@local_optimizer([DimShuffle])
def transinv_to_invtrans(node):
if isinstance(node.op, DimShuffle):
if node.op.new_order == (1, 0):
A, = node.inputs
if A.owner:
if isinstance(A.owner.op, MatrixInverse):
X, = A.owner.inputs
return [A.owner.op(node.op(X))]
@register_stabilize @register_stabilize
@local_optimizer([Dot, Dot22]) @local_optimizer([Dot, Dot22])
def inv_as_solve(node): def inv_as_solve(node):
...@@ -272,6 +284,32 @@ def inv_as_solve(node): ...@@ -272,6 +284,32 @@ def inv_as_solve(node):
return [solve(r.owner.inputs[0].T, l.T).T] return [solve(r.owner.inputs[0].T, l.T).T]
@register_stabilize
@register_canonicalize
@local_optimizer([Solve])
def tag_solve_triangular(node):
"""
If a general solve() is applied to the output of a cholesky op, then
replace it with a triangular solve.
"""
if node.op == solve:
if node.op.A_structure == 'general':
A, b = node.inputs # result is solution Ax=b
if A.owner and isinstance(A.owner.op, type(cholesky)):
if A.owner.op.lower:
return [Solve('lower_triangular')(A, b)]
else:
return [Solve('upper_triangular')(A, b)]
if (A.owner and isinstance(A.owner.op, DimShuffle)
and A.owner.op.new_order == (1, 0)):
A_T, = A.owner.inputs
if A_T.owner and isinstance(A_T.owner.op, type(cholesky)):
if A_T.owner.op.lower:
return [Solve('upper_triangular')(A, b)]
else:
return [Solve('lower_triangular')(A, b)]
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@register_specialize @register_specialize
......
...@@ -12,6 +12,8 @@ from theano.tensor.basic import _allclose ...@@ -12,6 +12,8 @@ from theano.tensor.basic import _allclose
from theano.tests.test_rop import break_op from theano.tests.test_rop import break_op
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano import config from theano import config
from theano.tensor.nlinalg import MatrixInverse
from theano.tensor import DimShuffle
# The one in comment are not tested... # The one in comment are not tested...
from theano.sandbox.linalg.ops import (cholesky, from theano.sandbox.linalg.ops import (cholesky,
...@@ -20,6 +22,7 @@ from theano.sandbox.linalg.ops import (cholesky, ...@@ -20,6 +22,7 @@ from theano.sandbox.linalg.ops import (cholesky,
matrix_inverse, matrix_inverse,
pinv, pinv,
Solve, Solve,
solve,
diag, diag,
ExtractDiag, ExtractDiag,
extract_diag, extract_diag,
...@@ -137,3 +140,37 @@ def test_spectral_radius_bound(): ...@@ -137,3 +140,37 @@ def test_spectral_radius_bound():
except ValueError: except ValueError:
ok = True ok = True
assert ok assert ok
def test_transinv_to_invtrans():
X = tensor.matrix('X')
Y = tensor.nlinalg.matrix_inverse(X)
Z = Y.transpose()
f = theano.function([X], Z)
for node in f.maker.fgraph.toposort():
if isinstance(node.op, MatrixInverse):
assert isinstance(node.inputs[0].owner.op, DimShuffle)
if isinstance(node.op, DimShuffle):
assert node.inputs[0].name == 'X'
def test_tag_solve_triangular():
cholesky_lower = Cholesky(lower=True)
cholesky_upper = Cholesky(lower=False)
A = tensor.matrix('A')
x = tensor.vector('x')
L = cholesky_lower(A)
U = cholesky_upper(A)
b1 = solve(L, x)
b2 = solve(U, x)
f = theano.function([A,x], b1)
for node in f.maker.fgraph.toposort():
if isinstance(node.op, Solve):
assert node.op.A_structure == 'lower_triangular'
f = theano.function([A,x], b2)
for node in f.maker.fgraph.toposort():
if isinstance(node.op, Solve):
assert node.op.A_structure == 'upper_triangular'
差异被折叠。
...@@ -111,6 +111,9 @@ class MatrixInverse(Op): ...@@ -111,6 +111,9 @@ class MatrixInverse(Op):
return [None] return [None]
return [-matrix_dot(xi, ev, xi)] return [-matrix_dot(xi, ev, xi)]
def infer_shape(self, node, shapes):
return shapes
matrix_inverse = MatrixInverse() matrix_inverse = MatrixInverse()
......
...@@ -4309,6 +4309,11 @@ def local_log_add(node): ...@@ -4309,6 +4309,11 @@ def local_log_add(node):
z = node.inputs[0] z = node.inputs[0]
if z.owner and z.owner.op == T.add: if z.owner and z.owner.op == T.add:
zi = z.owner.inputs zi = z.owner.inputs
if len(zi) != 2:
# -- upgrading Maximum to handle multiple inputs wasn't trivial
# TODO
#raise NotImplementedError()
return
pre_exp = [x.owner.inputs[0] for x in zi pre_exp = [x.owner.inputs[0] for x in zi
if x.owner and x.owner.op == T.exp] if x.owner and x.owner.op == T.exp]
if len(pre_exp) == len(zi): if len(pre_exp) == len(zi):
......
...@@ -171,9 +171,16 @@ class Solve(Op): ...@@ -171,9 +171,16 @@ class Solve(Op):
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
A, b = inputs A, b = inputs
#TODO: use the A_structure to go faster if self.A_structure == 'lower_triangular':
output_storage[0][0] = scipy.linalg.solve(A, b) rval = scipy.linalg.solve_triangular(
A, b, lower=True)
elif self.A_structure == 'upper_triangular':
rval = scipy.linalg.solve_triangular(
A, b, lower=False)
else:
rval = scipy.linalg.solve(A, b)
output_storage[0][0] = rval
# computes shape of x where x = inv(A) * b # computes shape of x where x = inv(A) * b
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
Ashape, Bshape = shapes Ashape, Bshape = shapes
......
...@@ -61,23 +61,39 @@ def test_pseudoinverse_correctness(): ...@@ -61,23 +61,39 @@ def test_pseudoinverse_correctness():
assert _allclose(ri, numpy.linalg.pinv(r)) assert _allclose(ri, numpy.linalg.pinv(r))
def test_inverse_correctness(): class test_MatrixInverse(utt.InferShapeTester):
rng = numpy.random.RandomState(utt.fetch_seed()) def setUp(self):
super(test_MatrixInverse, self).setUp()
self.op_class = MatrixInverse
self.op = matrix_inverse
self.rng = numpy.random.RandomState(utt.fetch_seed())
r = rng.randn(4, 4).astype(theano.config.floatX) def test_inverse_correctness(self):
x = tensor.matrix() r = self.rng.randn(4, 4).astype(theano.config.floatX)
xi = matrix_inverse(x)
ri = function([x], xi)(r) x = tensor.matrix()
assert ri.shape == r.shape xi = self.op(x)
assert ri.dtype == r.dtype
ri = function([x], xi)(r)
assert ri.shape == r.shape
assert ri.dtype == r.dtype
rir = numpy.dot(ri, r)
rri = numpy.dot(r, ri)
assert _allclose(numpy.identity(4), rir), rir
assert _allclose(numpy.identity(4), rri), rri
def test_infer_shape(self):
r = self.rng.randn(4, 4).astype(theano.config.floatX)
rir = numpy.dot(ri, r) x = tensor.matrix()
rri = numpy.dot(r, ri) xi = self.op(x)
assert _allclose(numpy.identity(4), rir), rir self._compile_and_check([x], [xi], [r],
assert _allclose(numpy.identity(4), rri), rri self.op_class, warn=False)
def test_matrix_dot(): def test_matrix_dot():
...@@ -490,4 +506,4 @@ class T_NormTests(unittest.TestCase): ...@@ -490,4 +506,4 @@ class T_NormTests(unittest.TestCase):
f = function([A[1][i]], norm(A[1][i], A[0][i])) f = function([A[1][i]], norm(A[1][i], A[0][i]))
t_n = f(A[2][i]) t_n = f(A[2][i])
n_n = numpy.linalg.norm(A[2][i], A[3][i]) n_n = numpy.linalg.norm(A[2][i], A[3][i])
assert _allclose(n_n, t_n) assert _allclose(n_n, t_n)
\ No newline at end of file
...@@ -189,3 +189,41 @@ class test_Solve(utt.InferShapeTester): ...@@ -189,3 +189,41 @@ class test_Solve(utt.InferShapeTester):
dtype=config.floatX)], dtype=config.floatX)],
self.op_class, self.op_class,
warn=False) warn=False)
def test_solve_correctness(self):
if not imported_scipy:
raise SkipTest("Scipy needed for the Cholesky op.")
rng = numpy.random.RandomState(utt.fetch_seed())
A = theano.tensor.matrix()
b = theano.tensor.matrix()
y = self.op(A, b)
gen_solve_func = theano.function([A,b],y)
cholesky_lower = Cholesky(lower=True)
L = cholesky_lower(A)
y_lower = self.op(L, b)
lower_solve_func = theano.function([L,b],y_lower)
cholesky_upper = Cholesky(lower=False)
U = cholesky_upper(A)
y_upper = self.op(U, b)
upper_solve_func = theano.function([U,b],y_upper)
b_val = numpy.asarray(rng.rand(5, 1), dtype=config.floatX)
# 1-test general case
A_val = numpy.asarray(rng.rand(5, 5), dtype=config.floatX)
# positive definite matrix:
A_val = numpy.dot(A_val.transpose(), A_val)
assert numpy.allclose(scipy.linalg.solve(A_val, b_val),
gen_solve_func(A_val, b_val))
# 2-test lower traingular case
L_val = scipy.linalg.cholesky(A_val, lower=True)
assert numpy.allclose(scipy.linalg.solve_triangular(L_val, b_val, lower=True),
lower_solve_func(L_val, b_val))
# 3-test upper traingular case
U_val = scipy.linalg.cholesky(A_val, lower=False)
assert numpy.allclose(scipy.linalg.solve_triangular(U_val, b_val, lower=False),
upper_solve_func(U_val, b_val))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论