提交 b6a15d9d authored 作者: carriepl's avatar carriepl

Merge pull request #2237 from aalmah/jaberg-may21

fixing #1868 Misc accumulated fixes
...@@ -256,6 +256,18 @@ def is_positive(v): ...@@ -256,6 +256,18 @@ def is_positive(v):
return False return False
@register_canonicalize
@local_optimizer([DimShuffle])
def transinv_to_invtrans(node):
if isinstance(node.op, DimShuffle):
if node.op.new_order == (1, 0):
A, = node.inputs
if A.owner:
if isinstance(A.owner.op, MatrixInverse):
X, = A.owner.inputs
return [A.owner.op(node.op(X))]
@register_stabilize @register_stabilize
@local_optimizer([Dot, Dot22]) @local_optimizer([Dot, Dot22])
def inv_as_solve(node): def inv_as_solve(node):
...@@ -272,6 +284,32 @@ def inv_as_solve(node): ...@@ -272,6 +284,32 @@ def inv_as_solve(node):
return [solve(r.owner.inputs[0].T, l.T).T] return [solve(r.owner.inputs[0].T, l.T).T]
@register_stabilize
@register_canonicalize
@local_optimizer([Solve])
def tag_solve_triangular(node):
"""
If a general solve() is applied to the output of a cholesky op, then
replace it with a triangular solve.
"""
if node.op == solve:
if node.op.A_structure == 'general':
A, b = node.inputs # result is solution Ax=b
if A.owner and isinstance(A.owner.op, type(cholesky)):
if A.owner.op.lower:
return [Solve('lower_triangular')(A, b)]
else:
return [Solve('upper_triangular')(A, b)]
if (A.owner and isinstance(A.owner.op, DimShuffle)
and A.owner.op.new_order == (1, 0)):
A_T, = A.owner.inputs
if A_T.owner and isinstance(A_T.owner.op, type(cholesky)):
if A_T.owner.op.lower:
return [Solve('upper_triangular')(A, b)]
else:
return [Solve('lower_triangular')(A, b)]
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@register_specialize @register_specialize
......
...@@ -12,6 +12,8 @@ from theano.tensor.basic import _allclose ...@@ -12,6 +12,8 @@ from theano.tensor.basic import _allclose
from theano.tests.test_rop import break_op from theano.tests.test_rop import break_op
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano import config from theano import config
from theano.tensor.nlinalg import MatrixInverse
from theano.tensor import DimShuffle
# The one in comment are not tested... # The one in comment are not tested...
from theano.sandbox.linalg.ops import (cholesky, from theano.sandbox.linalg.ops import (cholesky,
...@@ -20,6 +22,7 @@ from theano.sandbox.linalg.ops import (cholesky, ...@@ -20,6 +22,7 @@ from theano.sandbox.linalg.ops import (cholesky,
matrix_inverse, matrix_inverse,
pinv, pinv,
Solve, Solve,
solve,
diag, diag,
ExtractDiag, ExtractDiag,
extract_diag, extract_diag,
...@@ -137,3 +140,37 @@ def test_spectral_radius_bound(): ...@@ -137,3 +140,37 @@ def test_spectral_radius_bound():
except ValueError: except ValueError:
ok = True ok = True
assert ok assert ok
def test_transinv_to_invtrans():
X = tensor.matrix('X')
Y = tensor.nlinalg.matrix_inverse(X)
Z = Y.transpose()
f = theano.function([X], Z)
for node in f.maker.fgraph.toposort():
if isinstance(node.op, MatrixInverse):
assert isinstance(node.inputs[0].owner.op, DimShuffle)
if isinstance(node.op, DimShuffle):
assert node.inputs[0].name == 'X'
def test_tag_solve_triangular():
cholesky_lower = Cholesky(lower=True)
cholesky_upper = Cholesky(lower=False)
A = tensor.matrix('A')
x = tensor.vector('x')
L = cholesky_lower(A)
U = cholesky_upper(A)
b1 = solve(L, x)
b2 = solve(U, x)
f = theano.function([A,x], b1)
for node in f.maker.fgraph.toposort():
if isinstance(node.op, Solve):
assert node.op.A_structure == 'lower_triangular'
f = theano.function([A,x], b2)
for node in f.maker.fgraph.toposort():
if isinstance(node.op, Solve):
assert node.op.A_structure == 'upper_triangular'
...@@ -110,15 +110,25 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -110,15 +110,25 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
{ {
float * zoutdata = (float*)PyArray_DATA(%(Z)s); float * zoutdata = (float*)PyArray_DATA(%(Z)s);
const float * zdata = (float*)PyArray_DATA(%(A)s); const float * zdata = (float*)PyArray_DATA(%(A)s);
const float * xdata = (float*)PyArray_DATA(%(x)s);
const float * ydata = (float*)PyArray_DATA(%(y)s);
const float * adata = (float*)PyArray_DATA(%(a)s);
const float alpha = adata[0];
float tmp, xx;
int Ai = PyArray_STRIDES(%(A)s)[0]/sizeof(float); int Ai = PyArray_STRIDES(%(A)s)[0]/sizeof(float);
int Aj = PyArray_STRIDES(%(A)s)[1]/sizeof(float); int Aj = PyArray_STRIDES(%(A)s)[1]/sizeof(float);
int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(float); int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(float);
int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(float); int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(float);
int xi = PyArray_STRIDES(%(x)s)[0]/sizeof(float);
int yj = PyArray_STRIDES(%(y)s)[0]/sizeof(float);
for (int i = 0; i < dims[0]; ++i) for (int i = 0; i < dims[0]; ++i)
{ {
xx = alpha * xdata[xi * i];
for (int j = 0; j < dims[1]; ++j) for (int j = 0; j < dims[1]; ++j)
{ {
zoutdata[Zi*i+Zj*j] = zdata[Ai*i+Aj*j]; tmp = zdata[Ai*i+Aj*j];
tmp += xx * ydata[yj * j];
zoutdata[Zi*i+Zj*j] = tmp;
} }
} }
} }
...@@ -126,15 +136,26 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -126,15 +136,26 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
{ {
double * zoutdata = (double*) PyArray_DATA(%(Z)s); double * zoutdata = (double*) PyArray_DATA(%(Z)s);
const double * zdata = (double*)PyArray_DATA(%(A)s); const double * zdata = (double*)PyArray_DATA(%(A)s);
const double * xdata = (double*)PyArray_DATA(%(x)s);
const double * ydata = (double*)PyArray_DATA(%(y)s);
const double * adata = (double*)PyArray_DATA(%(a)s);
const double alpha = adata[0];
double tmp, xx;
int Ai = PyArray_STRIDES(%(A)s)[0]/sizeof(double); int Ai = PyArray_STRIDES(%(A)s)[0]/sizeof(double);
int Aj = PyArray_STRIDES(%(A)s)[1]/sizeof(double); int Aj = PyArray_STRIDES(%(A)s)[1]/sizeof(double);
int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(double); int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(double);
int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(double); int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(double);
int xi = PyArray_STRIDES(%(x)s)[0]/sizeof(double);
int yj = PyArray_STRIDES(%(y)s)[0]/sizeof(double);
for (int i = 0; i < dims[0]; ++i) for (int i = 0; i < dims[0]; ++i)
{ {
xx = alpha * xdata[xi * i];
for (int j = 0; j < dims[1]; ++j) for (int j = 0; j < dims[1]; ++j)
{ {
zoutdata[Zi*i+Zj*j] = zdata[Ai*i+Aj*j]; tmp = zdata[Ai*i+Aj*j];
tmp += xx * ydata[yj * j];
zoutdata[Zi*i+Zj*j] = tmp;
} }
} }
} }
...@@ -154,8 +175,56 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -154,8 +175,56 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
%(Z)s = %(A)s; %(Z)s = %(A)s;
Py_INCREF(%(Z)s); Py_INCREF(%(Z)s);
} }
npy_intp dims[2];
dims[0] = PyArray_DIMS(%(A)s)[0];
dims[1] = PyArray_DIMS(%(A)s)[1];
if ((dims[0] * dims[1]) < 100000)
{
if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT)
{
float * zoutdata = (float*)PyArray_DATA(%(Z)s);
const float * xdata = (float*)PyArray_DATA(%(x)s);
const float * ydata = (float*)PyArray_DATA(%(y)s);
const float * adata = (float*)PyArray_DATA(%(a)s);
const float alpha = adata[0];
float tmp, axi;
int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(float);
int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(float);
int xi = PyArray_STRIDES(%(x)s)[0]/sizeof(float);
int yj = PyArray_STRIDES(%(y)s)[0]/sizeof(float);
for (int i = 0; i < dims[0]; ++i)
{
axi = alpha * xdata[xi * i];
for (int j = 0; j < dims[1]; ++j)
{
zoutdata[Zi*i+Zj*j] += axi * ydata[yj * j];
}
} }
}
else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE)
{
double * zoutdata = (double*) PyArray_DATA(%(Z)s);
const double * xdata = (double*)PyArray_DATA(%(x)s);
const double * ydata = (double*)PyArray_DATA(%(y)s);
const double * adata = (double*)PyArray_DATA(%(a)s);
const double alpha = adata[0];
double tmp, axi;
int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(double);
int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(double);
int xi = PyArray_STRIDES(%(x)s)[0]/sizeof(double);
int yj = PyArray_STRIDES(%(y)s)[0]/sizeof(double);
for (int i = 0; i < dims[0]; ++i)
{
axi = alpha * xdata[xi * i];
for (int j = 0; j < dims[1]; ++j)
{
zoutdata[Zi*i+Zj*j] += axi * ydata[yj * j];
}
}
}
}
else
{ {
int Nz0 = PyArray_DIMS(%(Z)s)[0]; int Nz0 = PyArray_DIMS(%(Z)s)[0];
int Nz1 = PyArray_DIMS(%(Z)s)[1]; int Nz1 = PyArray_DIMS(%(Z)s)[1];
...@@ -199,6 +268,8 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -199,6 +268,8 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
(double*)x_data, &Sx, (double*)x_data, &Sx,
(double*)y_data, &Sy, (double*)y_data, &Sy,
(double*)(PyArray_DATA(%(Z)s)), &Sz1); (double*)(PyArray_DATA(%(Z)s)), &Sz1);
} }
else { else {
PyErr_SetString(PyExc_NotImplementedError, PyErr_SetString(PyExc_NotImplementedError,
...@@ -210,10 +281,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -210,10 +281,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
{ {
if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT) if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT)
{ {
//fprintf(stderr, "B %%i %%i %%i %%i\\n", Nz0, Nz1, Sz0, Sz1);
float alpha = ((dtype_%(a)s*)(PyArray_DATA(%(a)s)))[0]; float alpha = ((dtype_%(a)s*)(PyArray_DATA(%(a)s)))[0];
//fprintf(stderr, "alpha=%%f\\n", alpha);
//fprintf(stderr, "sx sy %%i %%i\\n", Sx, Sy);
sger_(&Nz1, &Nz0, &alpha, sger_(&Nz1, &Nz0, &alpha,
(float*)y_data, &Sy, (float*)y_data, &Sy,
(float*)x_data, &Sx, (float*)x_data, &Sx,
...@@ -242,6 +310,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -242,6 +310,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
%(fail)s %(fail)s
} }
} }
}
""" % locals() """ % locals()
...@@ -256,7 +325,7 @@ class CGer(BaseBLAS, Ger): ...@@ -256,7 +325,7 @@ class CGer(BaseBLAS, Ger):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (8, blas_header_version()) return (9, blas_header_version())
cger_inplace = CGer(True) cger_inplace = CGer(True)
cger_no_inplace = CGer(False) cger_no_inplace = CGer(False)
......
...@@ -111,6 +111,9 @@ class MatrixInverse(Op): ...@@ -111,6 +111,9 @@ class MatrixInverse(Op):
return [None] return [None]
return [-matrix_dot(xi, ev, xi)] return [-matrix_dot(xi, ev, xi)]
def infer_shape(self, node, shapes):
return shapes
matrix_inverse = MatrixInverse() matrix_inverse = MatrixInverse()
......
...@@ -4309,6 +4309,11 @@ def local_log_add(node): ...@@ -4309,6 +4309,11 @@ def local_log_add(node):
z = node.inputs[0] z = node.inputs[0]
if z.owner and z.owner.op == T.add: if z.owner and z.owner.op == T.add:
zi = z.owner.inputs zi = z.owner.inputs
if len(zi) != 2:
# -- upgrading Maximum to handle multiple inputs wasn't trivial
# TODO
#raise NotImplementedError()
return
pre_exp = [x.owner.inputs[0] for x in zi pre_exp = [x.owner.inputs[0] for x in zi
if x.owner and x.owner.op == T.exp] if x.owner and x.owner.op == T.exp]
if len(pre_exp) == len(zi): if len(pre_exp) == len(zi):
......
...@@ -171,8 +171,15 @@ class Solve(Op): ...@@ -171,8 +171,15 @@ class Solve(Op):
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
A, b = inputs A, b = inputs
#TODO: use the A_structure to go faster if self.A_structure == 'lower_triangular':
output_storage[0][0] = scipy.linalg.solve(A, b) rval = scipy.linalg.solve_triangular(
A, b, lower=True)
elif self.A_structure == 'upper_triangular':
rval = scipy.linalg.solve_triangular(
A, b, lower=False)
else:
rval = scipy.linalg.solve(A, b)
output_storage[0][0] = rval
# computes shape of x where x = inv(A) * b # computes shape of x where x = inv(A) * b
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
......
...@@ -61,13 +61,19 @@ def test_pseudoinverse_correctness(): ...@@ -61,13 +61,19 @@ def test_pseudoinverse_correctness():
assert _allclose(ri, numpy.linalg.pinv(r)) assert _allclose(ri, numpy.linalg.pinv(r))
def test_inverse_correctness(): class test_MatrixInverse(utt.InferShapeTester):
rng = numpy.random.RandomState(utt.fetch_seed()) def setUp(self):
super(test_MatrixInverse, self).setUp()
self.op_class = MatrixInverse
self.op = matrix_inverse
self.rng = numpy.random.RandomState(utt.fetch_seed())
r = rng.randn(4, 4).astype(theano.config.floatX) def test_inverse_correctness(self):
r = self.rng.randn(4, 4).astype(theano.config.floatX)
x = tensor.matrix() x = tensor.matrix()
xi = matrix_inverse(x) xi = self.op(x)
ri = function([x], xi)(r) ri = function([x], xi)(r)
assert ri.shape == r.shape assert ri.shape == r.shape
...@@ -79,6 +85,16 @@ def test_inverse_correctness(): ...@@ -79,6 +85,16 @@ def test_inverse_correctness():
assert _allclose(numpy.identity(4), rir), rir assert _allclose(numpy.identity(4), rir), rir
assert _allclose(numpy.identity(4), rri), rri assert _allclose(numpy.identity(4), rri), rri
def test_infer_shape(self):
r = self.rng.randn(4, 4).astype(theano.config.floatX)
x = tensor.matrix()
xi = self.op(x)
self._compile_and_check([x], [xi], [r],
self.op_class, warn=False)
def test_matrix_dot(): def test_matrix_dot():
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
......
...@@ -189,3 +189,41 @@ class test_Solve(utt.InferShapeTester): ...@@ -189,3 +189,41 @@ class test_Solve(utt.InferShapeTester):
dtype=config.floatX)], dtype=config.floatX)],
self.op_class, self.op_class,
warn=False) warn=False)
def test_solve_correctness(self):
if not imported_scipy:
raise SkipTest("Scipy needed for the Cholesky op.")
rng = numpy.random.RandomState(utt.fetch_seed())
A = theano.tensor.matrix()
b = theano.tensor.matrix()
y = self.op(A, b)
gen_solve_func = theano.function([A,b],y)
cholesky_lower = Cholesky(lower=True)
L = cholesky_lower(A)
y_lower = self.op(L, b)
lower_solve_func = theano.function([L,b],y_lower)
cholesky_upper = Cholesky(lower=False)
U = cholesky_upper(A)
y_upper = self.op(U, b)
upper_solve_func = theano.function([U,b],y_upper)
b_val = numpy.asarray(rng.rand(5, 1), dtype=config.floatX)
# 1-test general case
A_val = numpy.asarray(rng.rand(5, 5), dtype=config.floatX)
# positive definite matrix:
A_val = numpy.dot(A_val.transpose(), A_val)
assert numpy.allclose(scipy.linalg.solve(A_val, b_val),
gen_solve_func(A_val, b_val))
# 2-test lower traingular case
L_val = scipy.linalg.cholesky(A_val, lower=True)
assert numpy.allclose(scipy.linalg.solve_triangular(L_val, b_val, lower=True),
lower_solve_func(L_val, b_val))
# 3-test upper traingular case
U_val = scipy.linalg.cholesky(A_val, lower=False)
assert numpy.allclose(scipy.linalg.solve_triangular(U_val, b_val, lower=False),
upper_solve_func(U_val, b_val))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论