提交 b6a15d9d authored 作者: carriepl's avatar carriepl

Merge pull request #2237 from aalmah/jaberg-may21

fixing #1868 Misc accumulated fixes
...@@ -256,6 +256,18 @@ def is_positive(v): ...@@ -256,6 +256,18 @@ def is_positive(v):
return False return False
@register_canonicalize
@local_optimizer([DimShuffle])
def transinv_to_invtrans(node):
    """Canonicalize transpose(inv(X)) into inv(transpose(X))."""
    if not isinstance(node.op, DimShuffle):
        return
    if node.op.new_order != (1, 0):
        # Only a plain 2D transpose qualifies.
        return
    A, = node.inputs
    if A.owner is None or not isinstance(A.owner.op, MatrixInverse):
        return
    X, = A.owner.inputs
    # Push the transpose below the inverse: inv(X).T -> inv(X.T).
    return [A.owner.op(node.op(X))]
@register_stabilize @register_stabilize
@local_optimizer([Dot, Dot22]) @local_optimizer([Dot, Dot22])
def inv_as_solve(node): def inv_as_solve(node):
...@@ -272,6 +284,32 @@ def inv_as_solve(node): ...@@ -272,6 +284,32 @@ def inv_as_solve(node):
return [solve(r.owner.inputs[0].T, l.T).T] return [solve(r.owner.inputs[0].T, l.T).T]
@register_stabilize
@register_canonicalize
@local_optimizer([Solve])
def tag_solve_triangular(node):
    """
    If a general solve() is applied to the output of a cholesky op, then
    replace it with a triangular solve.
    """
    if node.op != solve:
        return
    if node.op.A_structure != 'general':
        return
    A, b = node.inputs  # result is the solution of A x = b
    # Case 1: A is directly the output of a Cholesky factorization.
    if A.owner and isinstance(A.owner.op, type(cholesky)):
        if A.owner.op.lower:
            structure = 'lower_triangular'
        else:
            structure = 'upper_triangular'
        return [Solve(structure)(A, b)]
    # Case 2: A is the transpose of a Cholesky factor, so its
    # triangularity is flipped relative to the factor itself.
    if (A.owner and isinstance(A.owner.op, DimShuffle)
            and A.owner.op.new_order == (1, 0)):
        A_T, = A.owner.inputs
        if A_T.owner and isinstance(A_T.owner.op, type(cholesky)):
            if A_T.owner.op.lower:
                structure = 'upper_triangular'
            else:
                structure = 'lower_triangular'
            return [Solve(structure)(A, b)]
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@register_specialize @register_specialize
......
...@@ -12,6 +12,8 @@ from theano.tensor.basic import _allclose ...@@ -12,6 +12,8 @@ from theano.tensor.basic import _allclose
from theano.tests.test_rop import break_op from theano.tests.test_rop import break_op
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano import config from theano import config
from theano.tensor.nlinalg import MatrixInverse
from theano.tensor import DimShuffle
# The one in comment are not tested... # The one in comment are not tested...
from theano.sandbox.linalg.ops import (cholesky, from theano.sandbox.linalg.ops import (cholesky,
...@@ -20,6 +22,7 @@ from theano.sandbox.linalg.ops import (cholesky, ...@@ -20,6 +22,7 @@ from theano.sandbox.linalg.ops import (cholesky,
matrix_inverse, matrix_inverse,
pinv, pinv,
Solve, Solve,
solve,
diag, diag,
ExtractDiag, ExtractDiag,
extract_diag, extract_diag,
...@@ -137,3 +140,37 @@ def test_spectral_radius_bound(): ...@@ -137,3 +140,37 @@ def test_spectral_radius_bound():
except ValueError: except ValueError:
ok = True ok = True
assert ok assert ok
def test_transinv_to_invtrans():
    """The canonicalizer should rewrite inv(X).T into inv(X.T)."""
    X = tensor.matrix('X')
    Z = tensor.nlinalg.matrix_inverse(X).transpose()
    f = theano.function([X], Z)
    for node in f.maker.fgraph.toposort():
        op = node.op
        if isinstance(op, MatrixInverse):
            # After the rewrite the inverse sits on top of the transpose.
            assert isinstance(node.inputs[0].owner.op, DimShuffle)
        if isinstance(op, DimShuffle):
            # ... and the transpose acts directly on the input variable.
            assert node.inputs[0].name == 'X'
def test_tag_solve_triangular():
    """solve(cholesky(A), x) should be tagged as a triangular Solve."""
    A = tensor.matrix('A')
    x = tensor.vector('x')
    for lower, expected_structure in ((True, 'lower_triangular'),
                                      (False, 'upper_triangular')):
        chol = Cholesky(lower=lower)(A)
        f = theano.function([A, x], solve(chol, x))
        for node in f.maker.fgraph.toposort():
            if isinstance(node.op, Solve):
                assert node.op.A_structure == expected_structure
...@@ -110,15 +110,25 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -110,15 +110,25 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
{ {
float * zoutdata = (float*)PyArray_DATA(%(Z)s); float * zoutdata = (float*)PyArray_DATA(%(Z)s);
const float * zdata = (float*)PyArray_DATA(%(A)s); const float * zdata = (float*)PyArray_DATA(%(A)s);
const float * xdata = (float*)PyArray_DATA(%(x)s);
const float * ydata = (float*)PyArray_DATA(%(y)s);
const float * adata = (float*)PyArray_DATA(%(a)s);
const float alpha = adata[0];
float tmp, xx;
int Ai = PyArray_STRIDES(%(A)s)[0]/sizeof(float); int Ai = PyArray_STRIDES(%(A)s)[0]/sizeof(float);
int Aj = PyArray_STRIDES(%(A)s)[1]/sizeof(float); int Aj = PyArray_STRIDES(%(A)s)[1]/sizeof(float);
int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(float); int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(float);
int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(float); int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(float);
int xi = PyArray_STRIDES(%(x)s)[0]/sizeof(float);
int yj = PyArray_STRIDES(%(y)s)[0]/sizeof(float);
for (int i = 0; i < dims[0]; ++i) for (int i = 0; i < dims[0]; ++i)
{ {
xx = alpha * xdata[xi * i];
for (int j = 0; j < dims[1]; ++j) for (int j = 0; j < dims[1]; ++j)
{ {
zoutdata[Zi*i+Zj*j] = zdata[Ai*i+Aj*j]; tmp = zdata[Ai*i+Aj*j];
tmp += xx * ydata[yj * j];
zoutdata[Zi*i+Zj*j] = tmp;
} }
} }
} }
...@@ -126,15 +136,26 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -126,15 +136,26 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
{ {
double * zoutdata = (double*) PyArray_DATA(%(Z)s); double * zoutdata = (double*) PyArray_DATA(%(Z)s);
const double * zdata = (double*)PyArray_DATA(%(A)s); const double * zdata = (double*)PyArray_DATA(%(A)s);
const double * xdata = (double*)PyArray_DATA(%(x)s);
const double * ydata = (double*)PyArray_DATA(%(y)s);
const double * adata = (double*)PyArray_DATA(%(a)s);
const double alpha = adata[0];
double tmp, xx;
int Ai = PyArray_STRIDES(%(A)s)[0]/sizeof(double); int Ai = PyArray_STRIDES(%(A)s)[0]/sizeof(double);
int Aj = PyArray_STRIDES(%(A)s)[1]/sizeof(double); int Aj = PyArray_STRIDES(%(A)s)[1]/sizeof(double);
int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(double); int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(double);
int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(double); int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(double);
int xi = PyArray_STRIDES(%(x)s)[0]/sizeof(double);
int yj = PyArray_STRIDES(%(y)s)[0]/sizeof(double);
for (int i = 0; i < dims[0]; ++i) for (int i = 0; i < dims[0]; ++i)
{ {
xx = alpha * xdata[xi * i];
for (int j = 0; j < dims[1]; ++j) for (int j = 0; j < dims[1]; ++j)
{ {
zoutdata[Zi*i+Zj*j] = zdata[Ai*i+Aj*j]; tmp = zdata[Ai*i+Aj*j];
tmp += xx * ydata[yj * j];
zoutdata[Zi*i+Zj*j] = tmp;
} }
} }
} }
...@@ -154,93 +175,141 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -154,93 +175,141 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
%(Z)s = %(A)s; %(Z)s = %(A)s;
Py_INCREF(%(Z)s); Py_INCREF(%(Z)s);
} }
} npy_intp dims[2];
dims[0] = PyArray_DIMS(%(A)s)[0];
{ dims[1] = PyArray_DIMS(%(A)s)[1];
int Nz0 = PyArray_DIMS(%(Z)s)[0]; if ((dims[0] * dims[1]) < 100000)
int Nz1 = PyArray_DIMS(%(Z)s)[1];
int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
int Sy = PyArray_STRIDES(%(y)s)[0] / elemsize;
/* create appropriate strides for Z, if it is a row or column matrix.
* In that case, the value of the stride does not really matter, but
* some versions of BLAS insist that:
* - they are not smaller than the number of elements in the array,
* - they are not 0.
*/
int Sz0 = (Nz0 > 1) ? (PyArray_STRIDES(%(Z)s)[0] / elemsize) : (Nz1 + 1);
int Sz1 = (Nz1 > 1) ? (PyArray_STRIDES(%(Z)s)[1] / elemsize) : (Nz0 + 1);
dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
dtype_%(y)s* y_data = (dtype_%(y)s*) PyArray_DATA(%(y)s);
// gemv expects pointers to the beginning of memory arrays,
// but numpy provides provides a pointer to the first element,
// so when the stride is negative, we need to get the last one.
if (Sx < 0)
x_data += (Nz0 - 1) * Sx;
if (Sy < 0)
y_data += (Nz1 - 1) * Sy;
if (PyArray_STRIDES(%(Z)s)[0] == elemsize)
{ {
if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT) if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT)
{ {
//fprintf(stderr, "A\\n"); float * zoutdata = (float*)PyArray_DATA(%(Z)s);
float alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0]; const float * xdata = (float*)PyArray_DATA(%(x)s);
sger_(&Nz0, &Nz1, &alpha, const float * ydata = (float*)PyArray_DATA(%(y)s);
(float*)x_data, &Sx, const float * adata = (float*)PyArray_DATA(%(a)s);
(float*)y_data, &Sy, const float alpha = adata[0];
(float*)(PyArray_DATA(%(Z)s)), &Sz1); float tmp, axi;
int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(float);
int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(float);
int xi = PyArray_STRIDES(%(x)s)[0]/sizeof(float);
int yj = PyArray_STRIDES(%(y)s)[0]/sizeof(float);
for (int i = 0; i < dims[0]; ++i)
{
axi = alpha * xdata[xi * i];
for (int j = 0; j < dims[1]; ++j)
{
zoutdata[Zi*i+Zj*j] += axi * ydata[yj * j];
}
}
} }
else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE) else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE)
{ {
double alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0]; double * zoutdata = (double*) PyArray_DATA(%(Z)s);
dger_(&Nz0, &Nz1, &alpha, const double * xdata = (double*)PyArray_DATA(%(x)s);
(double*)x_data, &Sx, const double * ydata = (double*)PyArray_DATA(%(y)s);
(double*)y_data, &Sy, const double * adata = (double*)PyArray_DATA(%(a)s);
(double*)(PyArray_DATA(%(Z)s)), &Sz1); const double alpha = adata[0];
} double tmp, axi;
else {
PyErr_SetString(PyExc_NotImplementedError, int Zi = PyArray_STRIDES(%(Z)s)[0]/sizeof(double);
"not float nor double"); int Zj = PyArray_STRIDES(%(Z)s)[1]/sizeof(double);
%(fail)s int xi = PyArray_STRIDES(%(x)s)[0]/sizeof(double);
int yj = PyArray_STRIDES(%(y)s)[0]/sizeof(double);
for (int i = 0; i < dims[0]; ++i)
{
axi = alpha * xdata[xi * i];
for (int j = 0; j < dims[1]; ++j)
{
zoutdata[Zi*i+Zj*j] += axi * ydata[yj * j];
}
}
} }
} }
else if (PyArray_STRIDES(%(Z)s)[1] == elemsize) else
{ {
if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT) int Nz0 = PyArray_DIMS(%(Z)s)[0];
int Nz1 = PyArray_DIMS(%(Z)s)[1];
int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
int Sy = PyArray_STRIDES(%(y)s)[0] / elemsize;
/* create appropriate strides for Z, if it is a row or column matrix.
* In that case, the value of the stride does not really matter, but
* some versions of BLAS insist that:
* - they are not smaller than the number of elements in the array,
* - they are not 0.
*/
int Sz0 = (Nz0 > 1) ? (PyArray_STRIDES(%(Z)s)[0] / elemsize) : (Nz1 + 1);
int Sz1 = (Nz1 > 1) ? (PyArray_STRIDES(%(Z)s)[1] / elemsize) : (Nz0 + 1);
dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
dtype_%(y)s* y_data = (dtype_%(y)s*) PyArray_DATA(%(y)s);
// gemv expects pointers to the beginning of memory arrays,
// but numpy provides provides a pointer to the first element,
// so when the stride is negative, we need to get the last one.
if (Sx < 0)
x_data += (Nz0 - 1) * Sx;
if (Sy < 0)
y_data += (Nz1 - 1) * Sy;
if (PyArray_STRIDES(%(Z)s)[0] == elemsize)
{ {
//fprintf(stderr, "B %%i %%i %%i %%i\\n", Nz0, Nz1, Sz0, Sz1); if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT)
float alpha = ((dtype_%(a)s*)(PyArray_DATA(%(a)s)))[0]; {
//fprintf(stderr, "alpha=%%f\\n", alpha); //fprintf(stderr, "A\\n");
//fprintf(stderr, "sx sy %%i %%i\\n", Sx, Sy); float alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0];
sger_(&Nz1, &Nz0, &alpha, sger_(&Nz0, &Nz1, &alpha,
(float*)y_data, &Sy, (float*)x_data, &Sx,
(float*)x_data, &Sx, (float*)y_data, &Sy,
(float*)(PyArray_DATA(%(Z)s)), &Sz0); (float*)(PyArray_DATA(%(Z)s)), &Sz1);
}
else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE)
{
double alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0];
dger_(&Nz0, &Nz1, &alpha,
(double*)x_data, &Sx,
(double*)y_data, &Sy,
(double*)(PyArray_DATA(%(Z)s)), &Sz1);
}
else {
PyErr_SetString(PyExc_NotImplementedError,
"not float nor double");
%(fail)s
}
} }
else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE) else if (PyArray_STRIDES(%(Z)s)[1] == elemsize)
{ {
double alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0]; if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT)
dger_(&Nz1, &Nz0, &alpha, {
(double*)y_data, &Sy, float alpha = ((dtype_%(a)s*)(PyArray_DATA(%(a)s)))[0];
(double*)x_data, &Sx, sger_(&Nz1, &Nz0, &alpha,
(double*)(PyArray_DATA(%(Z)s)), &Sz0); (float*)y_data, &Sy,
(float*)x_data, &Sx,
(float*)(PyArray_DATA(%(Z)s)), &Sz0);
}
else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE)
{
double alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0];
dger_(&Nz1, &Nz0, &alpha,
(double*)y_data, &Sy,
(double*)x_data, &Sx,
(double*)(PyArray_DATA(%(Z)s)), &Sz0);
}
else
{
PyErr_SetString(PyExc_NotImplementedError,
"not float nor double");
%(fail)s
}
} }
else else
{ {
PyErr_SetString(PyExc_NotImplementedError, PyErr_SetString(PyExc_AssertionError,
"not float nor double"); "A is a double-strided matrix, and should have been copied "
"into a memory-contiguous one.");
%(fail)s %(fail)s
} }
} }
else
{
PyErr_SetString(PyExc_AssertionError,
"A is a double-strided matrix, and should have been copied "
"into a memory-contiguous one.");
%(fail)s
}
} }
""" % locals() """ % locals()
...@@ -256,7 +325,7 @@ class CGer(BaseBLAS, Ger): ...@@ -256,7 +325,7 @@ class CGer(BaseBLAS, Ger):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (8, blas_header_version()) return (9, blas_header_version())
cger_inplace = CGer(True) cger_inplace = CGer(True)
cger_no_inplace = CGer(False) cger_no_inplace = CGer(False)
......
...@@ -111,6 +111,9 @@ class MatrixInverse(Op): ...@@ -111,6 +111,9 @@ class MatrixInverse(Op):
return [None] return [None]
return [-matrix_dot(xi, ev, xi)] return [-matrix_dot(xi, ev, xi)]
def infer_shape(self, node, shapes):
    # The inverse of a (square) matrix has the same shape as its input,
    # so the input shapes can be returned unchanged.
    return shapes
matrix_inverse = MatrixInverse() matrix_inverse = MatrixInverse()
......
...@@ -4309,6 +4309,11 @@ def local_log_add(node): ...@@ -4309,6 +4309,11 @@ def local_log_add(node):
z = node.inputs[0] z = node.inputs[0]
if z.owner and z.owner.op == T.add: if z.owner and z.owner.op == T.add:
zi = z.owner.inputs zi = z.owner.inputs
if len(zi) != 2:
# -- upgrading Maximum to handle multiple inputs wasn't trivial
# TODO
#raise NotImplementedError()
return
pre_exp = [x.owner.inputs[0] for x in zi pre_exp = [x.owner.inputs[0] for x in zi
if x.owner and x.owner.op == T.exp] if x.owner and x.owner.op == T.exp]
if len(pre_exp) == len(zi): if len(pre_exp) == len(zi):
......
...@@ -171,9 +171,16 @@ class Solve(Op): ...@@ -171,9 +171,16 @@ class Solve(Op):
def perform(self, node, inputs, output_storage):
    """Solve A x = b for x.

    Uses scipy's specialized triangular solver when ``self.A_structure``
    marks A as lower/upper triangular (faster and more stable than the
    general LU-based solver), and falls back to ``scipy.linalg.solve``
    otherwise.
    """
    A, b = inputs
    if self.A_structure == 'lower_triangular':
        rval = scipy.linalg.solve_triangular(A, b, lower=True)
    elif self.A_structure == 'upper_triangular':
        rval = scipy.linalg.solve_triangular(A, b, lower=False)
    else:
        # General (unstructured) matrix.
        rval = scipy.linalg.solve(A, b)
    output_storage[0][0] = rval
# computes shape of x where x = inv(A) * b # computes shape of x where x = inv(A) * b
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
Ashape, Bshape = shapes Ashape, Bshape = shapes
......
...@@ -61,23 +61,39 @@ def test_pseudoinverse_correctness(): ...@@ -61,23 +61,39 @@ def test_pseudoinverse_correctness():
assert _allclose(ri, numpy.linalg.pinv(r)) assert _allclose(ri, numpy.linalg.pinv(r))
class test_MatrixInverse(utt.InferShapeTester):
    """Correctness and infer_shape tests for the MatrixInverse op."""

    def setUp(self):
        super(test_MatrixInverse, self).setUp()
        self.op_class = MatrixInverse
        self.op = matrix_inverse
        self.rng = numpy.random.RandomState(utt.fetch_seed())

    def test_inverse_correctness(self):
        # inv(r) must be a left and right inverse of r.
        r = self.rng.randn(4, 4).astype(theano.config.floatX)

        x = tensor.matrix()
        xi = self.op(x)

        ri = function([x], xi)(r)
        assert ri.shape == r.shape
        assert ri.dtype == r.dtype

        rir = numpy.dot(ri, r)
        rri = numpy.dot(r, ri)
        assert _allclose(numpy.identity(4), rir), rir
        assert _allclose(numpy.identity(4), rri), rri

    def test_infer_shape(self):
        # The inferred shape of inv(x) must match the computed one.
        r = self.rng.randn(4, 4).astype(theano.config.floatX)

        x = tensor.matrix()
        xi = self.op(x)

        self._compile_and_check([x], [xi], [r],
                                self.op_class, warn=False)
def test_matrix_dot(): def test_matrix_dot():
...@@ -490,4 +506,4 @@ class T_NormTests(unittest.TestCase): ...@@ -490,4 +506,4 @@ class T_NormTests(unittest.TestCase):
f = function([A[1][i]], norm(A[1][i], A[0][i])) f = function([A[1][i]], norm(A[1][i], A[0][i]))
t_n = f(A[2][i]) t_n = f(A[2][i])
n_n = numpy.linalg.norm(A[2][i], A[3][i]) n_n = numpy.linalg.norm(A[2][i], A[3][i])
assert _allclose(n_n, t_n) assert _allclose(n_n, t_n)
\ No newline at end of file
...@@ -189,3 +189,41 @@ class test_Solve(utt.InferShapeTester): ...@@ -189,3 +189,41 @@ class test_Solve(utt.InferShapeTester):
dtype=config.floatX)], dtype=config.floatX)],
self.op_class, self.op_class,
warn=False) warn=False)
def test_solve_correctness(self):
    """Check Solve against scipy on general and triangular systems."""
    if not imported_scipy:
        raise SkipTest("Scipy needed for the Cholesky op.")
    rng = numpy.random.RandomState(utt.fetch_seed())

    A = theano.tensor.matrix()
    b = theano.tensor.matrix()
    gen_solve_func = theano.function([A, b], self.op(A, b))

    L = Cholesky(lower=True)(A)
    lower_solve_func = theano.function([L, b], self.op(L, b))

    U = Cholesky(lower=False)(A)
    upper_solve_func = theano.function([U, b], self.op(U, b))

    b_val = numpy.asarray(rng.rand(5, 1), dtype=config.floatX)

    # 1 - general case: a random symmetric positive definite matrix.
    M = numpy.asarray(rng.rand(5, 5), dtype=config.floatX)
    A_val = numpy.dot(M.transpose(), M)
    assert numpy.allclose(scipy.linalg.solve(A_val, b_val),
                          gen_solve_func(A_val, b_val))

    # 2 - lower triangular case.
    L_val = scipy.linalg.cholesky(A_val, lower=True)
    assert numpy.allclose(
        scipy.linalg.solve_triangular(L_val, b_val, lower=True),
        lower_solve_func(L_val, b_val))

    # 3 - upper triangular case.
    U_val = scipy.linalg.cholesky(A_val, lower=False)
    assert numpy.allclose(
        scipy.linalg.solve_triangular(U_val, b_val, lower=False),
        upper_solve_func(U_val, b_val))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论