提交 1cd88000 authored 作者: Frederic's avatar Frederic

pep8

上级 aa024e24
import numpy import numpy
import scipy.sparse
from theano import gof, tensor, scalar from theano import gof, tensor, scalar
from theano.tensor import blas from theano.tensor import blas
from theano.sparse.basic import ( from theano.sparse.basic import (
as_sparse_variable, SparseType, add_s_s, neg, as_sparse_variable, SparseType, add_s_s, neg,
mul_s_s, mul_s_d, mul_s_s, mul_s_d, dot,
CSMProperties, CSM, register_specialize, CSMProperties, CSM, register_specialize,
_is_sparse_variable, CSC, CSR, _is_sparse_variable, CSC, CSR,
csm_data, csm_indices, csm_indptr, csm_shape, csm_properties, csm_data, csm_indices, csm_indptr, csm_shape,
_is_sparse) _is_sparse)
...@@ -183,40 +184,51 @@ class MulSDCSC(gof.Op): ...@@ -183,40 +184,51 @@ class MulSDCSC(gof.Op):
raise NotImplementedError('Complex types are not supported for b') raise NotImplementedError('Complex types are not supported for b')
return """ return """
if (%(_b)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); %(fail)s;} if (%(_b)s->nd != 2) {
if (%(_data)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(data) != 1"); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2");
if (%(_indices)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); %(fail)s;} %(fail)s;}
if (%(_indptr)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1"); %(fail)s;} if (%(_data)s->nd != 1) {
PyErr_SetString(PyExc_NotImplementedError, "rank(data) != 1");
%(fail)s;}
if (%(_indices)s->nd != 1) {
PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1");
%(fail)s;}
if (%(_indptr)s->nd != 1) {
PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1");
%(fail)s;}
if( %(_indices)s->descr->type_num != PyArray_INT32) { if( %(_indices)s->descr->type_num != PyArray_INT32) {
PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;}
if( %(_indptr)s->descr->type_num != PyArray_INT32) if( %(_indptr)s->descr->type_num != PyArray_INT32)
{PyErr_SetString(PyExc_NotImplementedError, "D"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "D"); %(fail)s;}
if (!%(_zout)s) if (!%(_zout)s)
{ {
%(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1, %(_indices)s->dimensions, %(_b)s->descr->type_num); %(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1,
%(_indices)s->dimensions, %(_b)s->descr->type_num);
} }
if (%(_zout)s->dimensions[0] != %(_indices)s->dimensions[0]) if (%(_zout)s->dimensions[0] != %(_indices)s->dimensions[0])
{ {
PyErr_SetString(PyExc_NotImplementedError, "somehow _zout got the wrong size.. and I don't know how to resize it."); PyErr_SetString(PyExc_NotImplementedError,
"somehow _zout got the wrong size.. and I don't know how to resize it.");
%(fail)s; %(fail)s;
} }
{ //makes it compile even though labels jump over variable definitions. { //makes it compile even though labels jump over variable definitions.
const npy_intp nnz = %(_indices)s->dimensions[0]; const npy_intp nnz = %(_indices)s->dimensions[0];
const npy_intp N = %(_indptr)s->dimensions[0]-1; //TODO: error checking with this //TODO: error checking with this
const npy_intp N = %(_indptr)s->dimensions[0]-1;
const dtype_%(_data)s * const __restrict__ data = (dtype_%(_data)s*)%(_data)s->data; const dtype_%(_data)s * const __restrict__ data = (dtype_%(_data)s*)%(_data)s->data;
const npy_int32 * const __restrict__ indptr = (npy_int32 *)%(_indptr)s->data; const npy_int32 * const __restrict__ indptr = (npy_int32 *)%(_indptr)s->data;
const npy_int32 * const __restrict__ indices = (npy_int32 *)%(_indices)s->data; const npy_int32 * const __restrict__ indices = (npy_int32 *)%(_indices)s->data;
dtype_%(_zout)s * const __restrict__ zout = (dtype_%(_zout)s*)%(_zout)s->data; dtype_%(_zout)s * const __restrict__ zout = (dtype_%(_zout)s*)%(_zout)s->data;
const npy_intp Sb = %(_b)s->strides[0]; const npy_intp Sb = %(_b)s->strides[0];
// loop over columns // loop over columns
for (npy_int32 j = 0; j < N; ++j) for (npy_int32 j = 0; j < N; ++j)
{ {
...@@ -225,10 +237,10 @@ class MulSDCSC(gof.Op): ...@@ -225,10 +237,10 @@ class MulSDCSC(gof.Op):
{ {
// extract row index of non-null value // extract row index of non-null value
npy_int32 i = indices[i_idx]; npy_int32 i = indices[i_idx];
// extract i-th row of dense matrix // extract i-th row of dense matrix
const dtype_%(_b)s* __restrict__ b_row = (dtype_%(_b)s*)(%(_b)s->data + Sb * i); const dtype_%(_b)s* __restrict__ b_row = (dtype_%(_b)s*)(%(_b)s->data + Sb * i);
// write resulting gradient to sparse output // write resulting gradient to sparse output
zout[i_idx] = data[i_idx] * b_row[j]; zout[i_idx] = data[i_idx] * b_row[j];
} }
...@@ -262,52 +274,63 @@ class MulSDCSR(gof.Op): ...@@ -262,52 +274,63 @@ class MulSDCSR(gof.Op):
raise NotImplementedError('Complex types are not supported for b') raise NotImplementedError('Complex types are not supported for b')
return """ return """
if (%(_b)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); %(fail)s;} if (%(_b)s->nd != 2) {
if (%(_data)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(data) != 1"); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2");
if (%(_indices)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); %(fail)s;} %(fail)s;}
if (%(_indptr)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1"); %(fail)s;} if (%(_data)s->nd != 1) {
PyErr_SetString(PyExc_NotImplementedError, "rank(data) != 1");
%(fail)s;}
if (%(_indices)s->nd != 1) {
PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1");
%(fail)s;}
if (%(_indptr)s->nd != 1) {
PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1");
%(fail)s;}
if( %(_indices)s->descr->type_num != PyArray_INT32) { if( %(_indices)s->descr->type_num != PyArray_INT32) {
PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;}
if( %(_indptr)s->descr->type_num != PyArray_INT32) if( %(_indptr)s->descr->type_num != PyArray_INT32)
{PyErr_SetString(PyExc_NotImplementedError, "D"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "D"); %(fail)s;}
if (!%(_zout)s) if (!%(_zout)s)
{ {
%(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1, %(_indices)s->dimensions, %(_b)s->descr->type_num); %(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1,
%(_indices)s->dimensions, %(_b)s->descr->type_num);
} }
if (%(_zout)s->dimensions[0] != %(_indices)s->dimensions[0]) if (%(_zout)s->dimensions[0] != %(_indices)s->dimensions[0])
{ {
PyErr_SetString(PyExc_NotImplementedError, "somehow _zout got the wrong size.. and I don't know how to resize it."); PyErr_SetString(PyExc_NotImplementedError,
"somehow _zout got the wrong size.. and I don't know how to resize it.");
%(fail)s; %(fail)s;
} }
{ //makes it compile even though labels jump over variable definitions. { //makes it compile even though labels jump over variable definitions.
const npy_intp nnz = %(_indices)s->dimensions[0]; const npy_intp nnz = %(_indices)s->dimensions[0];
const npy_intp N = %(_indptr)s->dimensions[0]-1; //TODO: error checking with this //TODO: error checking with this
const npy_intp N = %(_indptr)s->dimensions[0]-1;
const dtype_%(_data)s * const __restrict__ data = (dtype_%(_data)s*)%(_data)s->data; const dtype_%(_data)s * const __restrict__ data = (dtype_%(_data)s*)%(_data)s->data;
const npy_int32 * const __restrict__ indptr = (npy_int32 *)%(_indptr)s->data; const npy_int32 * const __restrict__ indptr = (npy_int32 *)%(_indptr)s->data;
const npy_int32 * const __restrict__ indices = (npy_int32 *)%(_indices)s->data; const npy_int32 * const __restrict__ indices = (npy_int32 *)%(_indices)s->data;
dtype_%(_zout)s * const __restrict__ zout = (dtype_%(_zout)s*)%(_zout)s->data; dtype_%(_zout)s * const __restrict__ zout = (dtype_%(_zout)s*)%(_zout)s->data;
const npy_intp Sb = %(_b)s->strides[0]; const npy_intp Sb = %(_b)s->strides[0];
// loop over columns // loop over columns
for (npy_int32 j = 0; j < N; ++j) for (npy_int32 j = 0; j < N; ++j)
{ {
// extract i-th row of dense matrix // extract i-th row of dense matrix
const dtype_%(_b)s* __restrict__ b_row = (dtype_%(_b)s*)(%(_b)s->data + Sb * j); const dtype_%(_b)s* __restrict__ b_row = (dtype_%(_b)s*)(%(_b)s->data + Sb * j);
// for each non-null value in the sparse column // for each non-null value in the sparse column
for (npy_int32 i_idx = indptr[j]; i_idx < indptr[j+1]; ++i_idx) for (npy_int32 i_idx = indptr[j]; i_idx < indptr[j+1]; ++i_idx)
{ {
// extract row index of non-null value // extract row index of non-null value
npy_int32 i = indices[i_idx]; npy_int32 i = indices[i_idx];
// write resulting gradient to sparse output // write resulting gradient to sparse output
zout[i_idx] = data[i_idx] * b_row[i]; zout[i_idx] = data[i_idx] * b_row[i];
} }
...@@ -550,42 +573,57 @@ class StrucutedAddSVCSR(gof.Op): ...@@ -550,42 +573,57 @@ class StrucutedAddSVCSR(gof.Op):
raise NotImplementedError('Complex types are not supported for b') raise NotImplementedError('Complex types are not supported for b')
return """ return """
if (%(_b)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1"); %(fail)s;} if (%(_b)s->nd != 1) {
if (%(_data)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(data) != 1"); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1");
if (%(_indices)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); %(fail)s;} %(fail)s;
if (%(_indptr)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1"); %(fail)s;} }
if (%(_data)s->nd != 1) {
PyErr_SetString(PyExc_NotImplementedError, "rank(data) != 1");
%(fail)s;
}
if (%(_indices)s->nd != 1) {
PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1");
%(fail)s;
}
if (%(_indptr)s->nd != 1) {
PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1");
%(fail)s;
}
if( %(_indices)s->descr->type_num != PyArray_INT32) { if( %(_indices)s->descr->type_num != PyArray_INT32) {
PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;}
if( %(_indptr)s->descr->type_num != PyArray_INT32) if( %(_indptr)s->descr->type_num != PyArray_INT32)
{PyErr_SetString(PyExc_NotImplementedError, "D"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "D"); %(fail)s;}
if (!%(_zout)s) if (!%(_zout)s)
{ {
%(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1, %(_indices)s->dimensions, %(_b)s->descr->type_num); %(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1,
%(_indices)s->dimensions, %(_b)s->descr->type_num);
} }
if (%(_zout)s->dimensions[0] != %(_indices)s->dimensions[0]) if (%(_zout)s->dimensions[0] != %(_indices)s->dimensions[0])
{ {
PyErr_SetString(PyExc_NotImplementedError, "somehow _zout got the wrong size.. and I don't know how to resize it."); PyErr_SetString(PyExc_NotImplementedError,
"somehow _zout got the wrong size.. and I don't know how to resize it.");
%(fail)s; %(fail)s;
} }
{ //makes it compile even though labels jump over variable definitions. { //makes it compile even though labels jump over variable definitions.
const npy_intp nnz = %(_indices)s->dimensions[0]; const npy_intp nnz = %(_indices)s->dimensions[0];
const npy_intp N = %(_indptr)s->dimensions[0]-1; //TODO: error checking with this //TODO: error checking with this
const npy_intp N = %(_indptr)s->dimensions[0]-1;
const dtype_%(_data)s * const __restrict__ data = (dtype_%(_data)s*)%(_data)s->data; const dtype_%(_data)s * const __restrict__ data = (dtype_%(_data)s*)%(_data)s->data;
const npy_int32 * const __restrict__ indptr = (npy_int32 *)%(_indptr)s->data; const npy_int32 * const __restrict__ indptr = (npy_int32 *)%(_indptr)s->data;
const npy_int32 * const __restrict__ indices = (npy_int32 *)%(_indices)s->data; const npy_int32 * const __restrict__ indices = (npy_int32 *)%(_indices)s->data;
const dtype_%(_b)s* __restrict__ Db = (dtype_%(_b)s*)%(_b)s->data; const dtype_%(_b)s* __restrict__ Db = (dtype_%(_b)s*)%(_b)s->data;
dtype_%(_zout)s * const __restrict__ zout = (dtype_%(_zout)s*)%(_zout)s->data; dtype_%(_zout)s * const __restrict__ zout = (dtype_%(_zout)s*)%(_zout)s->data;
const npy_intp Sb = %(_b)s->strides[0] / %(_b)s->descr->elsize; const npy_intp Sb = %(_b)s->strides[0] / %(_b)s->descr->elsize;
// loop over columns // loop over columns
for (npy_int32 j = 0; j < N; ++j) for (npy_int32 j = 0; j < N; ++j)
{ {
...@@ -594,7 +632,7 @@ class StrucutedAddSVCSR(gof.Op): ...@@ -594,7 +632,7 @@ class StrucutedAddSVCSR(gof.Op):
{ {
// extract row index of non-null value // extract row index of non-null value
npy_int32 i = indices[i_idx]; npy_int32 i = indices[i_idx];
// write resulting gradient to sparse output // write resulting gradient to sparse output
zout[i_idx] = data[i_idx] + Db[i * Sb]; zout[i_idx] = data[i_idx] + Db[i * Sb];
} }
...@@ -671,6 +709,7 @@ class SamplingDot(gof.op.Op): ...@@ -671,6 +709,7 @@ class SamplingDot(gof.op.Op):
if not _is_sparse_variable(p): if not _is_sparse_variable(p):
raise TypeError(p) raise TypeError(p)
#TODO: use it.
dtype_out = scalar.upcast(x.type.dtype, y.type.dtype, p.type.dtype) dtype_out = scalar.upcast(x.type.dtype, y.type.dtype, p.type.dtype)
return gof.Apply(self, [x, y, p], [p.type()]) return gof.Apply(self, [x, y, p], [p.type()])
...@@ -790,24 +829,36 @@ class SamplingDotCsr(gof.Op): ...@@ -790,24 +829,36 @@ class SamplingDotCsr(gof.Op):
[]).dtype_specs()[-1] []).dtype_specs()[-1]
rval = """ rval = """
if (%(x)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;} if (%(x)s->nd != 2) {
if (%(y)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;}
if (%(y)s->nd != 2) {
PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;}
if (%(x)s->descr->type_num != %(typenum_x)s) { if (%(x)s->descr->type_num != %(typenum_x)s) {
PyErr_SetString(PyExc_NotImplementedError, "Invalid type for x"); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError,
"Invalid type for x");
%(fail)s;}
if (%(y)s->descr->type_num != %(typenum_y)s) { if (%(y)s->descr->type_num != %(typenum_y)s) {
PyErr_SetString(PyExc_NotImplementedError, "Invalid type for y"); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError,
"Invalid type for y");
%(fail)s;}
if (%(p_data)s->descr->type_num != %(typenum_p)s) { if (%(p_data)s->descr->type_num != %(typenum_p)s) {
PyErr_SetString(PyExc_NotImplementedError, "Invalid type for pattern"); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError,
"Invalid type for pattern");
if (%(x)s->dimensions[1] != %(y)s->dimensions[1]) %(fail)s;}
{PyErr_SetString(PyExc_NotImplementedError, "x's number of columns doesn't match y's rows! Note: sampling_dot is different from dot because y is assumed to be transposed."); %(fail)s;}
if (%(x)s->dimensions[1] != %(y)s->dimensions[1]) {
if (%(y)s->dimensions[0] != ((npy_int32 *)%(p_ncols)s->data)[0] || %(x)s->dimensions[0] != (%(p_ptr)s->dimensions[0] - 1)) PyErr_SetString(PyExc_NotImplementedError,
{PyErr_SetString(PyExc_NotImplementedError, "The dimension of the pattern and the output must match"); %(fail)s;} "x's number of columns doesn't match y's rows! Note: sampling_dot is different from dot because y is assumed to be transposed.");
%(fail)s;}
if (%(y)s->dimensions[0] != ((npy_int32 *)%(p_ncols)s->data)[0] ||
%(x)s->dimensions[0] != (%(p_ptr)s->dimensions[0] - 1))
{PyErr_SetString(PyExc_NotImplementedError,
"The dimension of the pattern and the output must match"); %(fail)s;}
// Allocate output // Allocate output
if (!%(z_data)s if (!%(z_data)s
|| (%(z_data)s->dimensions[0] != %(p_data)s->dimensions[0]) || (%(z_data)s->dimensions[0] != %(p_data)s->dimensions[0])
...@@ -815,7 +866,8 @@ class SamplingDotCsr(gof.Op): ...@@ -815,7 +866,8 @@ class SamplingDotCsr(gof.Op):
{Py_XDECREF(%(z_data)s);} {Py_XDECREF(%(z_data)s);}
npy_intp dims[] = {0}; npy_intp dims[] = {0};
dims[0] = %(p_data)s->dimensions[0]; dims[0] = %(p_data)s->dimensions[0];
%(z_data)s = (PyArrayObject*) PyArray_SimpleNew(1, dims, %(typenum_zd)s); %(z_data)s = (PyArrayObject*) PyArray_SimpleNew(1, dims,
%(typenum_zd)s);
} }
if (!%(z_ind)s if (!%(z_ind)s
|| (%(z_ind)s->dimensions[0] != %(p_ind)s->dimensions[0]) || (%(z_ind)s->dimensions[0] != %(p_ind)s->dimensions[0])
...@@ -823,7 +875,8 @@ class SamplingDotCsr(gof.Op): ...@@ -823,7 +875,8 @@ class SamplingDotCsr(gof.Op):
{Py_XDECREF(%(z_ind)s);} {Py_XDECREF(%(z_ind)s);}
npy_intp dims[] = {0}; npy_intp dims[] = {0};
dims[0] = %(p_ind)s->dimensions[0]; dims[0] = %(p_ind)s->dimensions[0];
%(z_ind)s = (PyArrayObject*) PyArray_SimpleNew(1, dims, %(typenum_zi)s); %(z_ind)s = (PyArrayObject*) PyArray_SimpleNew(1, dims,
%(typenum_zi)s);
} }
if (!%(z_ptr)s if (!%(z_ptr)s
|| (%(z_ptr)s->dimensions[0] != %(p_ptr)s->dimensions[0]) || (%(z_ptr)s->dimensions[0] != %(p_ptr)s->dimensions[0])
...@@ -831,15 +884,16 @@ class SamplingDotCsr(gof.Op): ...@@ -831,15 +884,16 @@ class SamplingDotCsr(gof.Op):
{Py_XDECREF(%(z_ptr)s);} {Py_XDECREF(%(z_ptr)s);}
npy_intp dims[] = {0}; npy_intp dims[] = {0};
dims[0] = %(p_ptr)s->dimensions[0]; dims[0] = %(p_ptr)s->dimensions[0];
%(z_ptr)s = (PyArrayObject*) PyArray_SimpleNew(1, dims, %(typenum_zp)s); %(z_ptr)s = (PyArrayObject*) PyArray_SimpleNew(1, dims,
%(typenum_zp)s);
} }
{ {
// Product of MxK and NxK, output MxN // Product of MxK and NxK, output MxN
npy_intp M = %(x)s->dimensions[0]; npy_intp M = %(x)s->dimensions[0];
npy_intp N = %(y)s->dimensions[0]; npy_intp N = %(y)s->dimensions[0];
npy_intp K = %(y)s->dimensions[1]; npy_intp K = %(y)s->dimensions[1];
// pointers to access actual data in the arrays passed as params. // pointers to access actual data in the arrays passed as params.
const dtype_%(x)s* __restrict__ Dx = (dtype_%(x)s*)%(x)s->data; const dtype_%(x)s* __restrict__ Dx = (dtype_%(x)s*)%(x)s->data;
const dtype_%(y)s* __restrict__ Dy = (dtype_%(y)s*)%(y)s->data; const dtype_%(y)s* __restrict__ Dy = (dtype_%(y)s*)%(y)s->data;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论