提交 1cd88000 authored 作者: Frederic's avatar Frederic

pep8

上级 aa024e24
import numpy import numpy
import scipy.sparse
from theano import gof, tensor, scalar from theano import gof, tensor, scalar
from theano.tensor import blas from theano.tensor import blas
from theano.sparse.basic import ( from theano.sparse.basic import (
as_sparse_variable, SparseType, add_s_s, neg, as_sparse_variable, SparseType, add_s_s, neg,
mul_s_s, mul_s_d, mul_s_s, mul_s_d, dot,
CSMProperties, CSM, register_specialize, CSMProperties, CSM, register_specialize,
_is_sparse_variable, CSC, CSR, _is_sparse_variable, CSC, CSR,
csm_data, csm_indices, csm_indptr, csm_shape, csm_properties, csm_data, csm_indices, csm_indptr, csm_shape,
_is_sparse) _is_sparse)
...@@ -183,10 +184,18 @@ class MulSDCSC(gof.Op): ...@@ -183,10 +184,18 @@ class MulSDCSC(gof.Op):
raise NotImplementedError('Complex types are not supported for b') raise NotImplementedError('Complex types are not supported for b')
return """ return """
if (%(_b)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); %(fail)s;} if (%(_b)s->nd != 2) {
if (%(_data)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(data) != 1"); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2");
if (%(_indices)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); %(fail)s;} %(fail)s;}
if (%(_indptr)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1"); %(fail)s;} if (%(_data)s->nd != 1) {
PyErr_SetString(PyExc_NotImplementedError, "rank(data) != 1");
%(fail)s;}
if (%(_indices)s->nd != 1) {
PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1");
%(fail)s;}
if (%(_indptr)s->nd != 1) {
PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1");
%(fail)s;}
if( %(_indices)s->descr->type_num != PyArray_INT32) { if( %(_indices)s->descr->type_num != PyArray_INT32) {
PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;}
...@@ -196,18 +205,21 @@ class MulSDCSC(gof.Op): ...@@ -196,18 +205,21 @@ class MulSDCSC(gof.Op):
if (!%(_zout)s) if (!%(_zout)s)
{ {
%(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1, %(_indices)s->dimensions, %(_b)s->descr->type_num); %(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1,
%(_indices)s->dimensions, %(_b)s->descr->type_num);
} }
if (%(_zout)s->dimensions[0] != %(_indices)s->dimensions[0]) if (%(_zout)s->dimensions[0] != %(_indices)s->dimensions[0])
{ {
PyErr_SetString(PyExc_NotImplementedError, "somehow _zout got the wrong size.. and I don't know how to resize it."); PyErr_SetString(PyExc_NotImplementedError,
"somehow _zout got the wrong size.. and I don't know how to resize it.");
%(fail)s; %(fail)s;
} }
{ //makes it compile even though labels jump over variable definitions. { //makes it compile even though labels jump over variable definitions.
const npy_intp nnz = %(_indices)s->dimensions[0]; const npy_intp nnz = %(_indices)s->dimensions[0];
const npy_intp N = %(_indptr)s->dimensions[0]-1; //TODO: error checking with this //TODO: error checking with this
const npy_intp N = %(_indptr)s->dimensions[0]-1;
const dtype_%(_data)s * const __restrict__ data = (dtype_%(_data)s*)%(_data)s->data; const dtype_%(_data)s * const __restrict__ data = (dtype_%(_data)s*)%(_data)s->data;
const npy_int32 * const __restrict__ indptr = (npy_int32 *)%(_indptr)s->data; const npy_int32 * const __restrict__ indptr = (npy_int32 *)%(_indptr)s->data;
...@@ -262,10 +274,18 @@ class MulSDCSR(gof.Op): ...@@ -262,10 +274,18 @@ class MulSDCSR(gof.Op):
raise NotImplementedError('Complex types are not supported for b') raise NotImplementedError('Complex types are not supported for b')
return """ return """
if (%(_b)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); %(fail)s;} if (%(_b)s->nd != 2) {
if (%(_data)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(data) != 1"); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2");
if (%(_indices)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); %(fail)s;} %(fail)s;}
if (%(_indptr)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1"); %(fail)s;} if (%(_data)s->nd != 1) {
PyErr_SetString(PyExc_NotImplementedError, "rank(data) != 1");
%(fail)s;}
if (%(_indices)s->nd != 1) {
PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1");
%(fail)s;}
if (%(_indptr)s->nd != 1) {
PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1");
%(fail)s;}
if( %(_indices)s->descr->type_num != PyArray_INT32) { if( %(_indices)s->descr->type_num != PyArray_INT32) {
PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;}
...@@ -275,18 +295,21 @@ class MulSDCSR(gof.Op): ...@@ -275,18 +295,21 @@ class MulSDCSR(gof.Op):
if (!%(_zout)s) if (!%(_zout)s)
{ {
%(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1, %(_indices)s->dimensions, %(_b)s->descr->type_num); %(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1,
%(_indices)s->dimensions, %(_b)s->descr->type_num);
} }
if (%(_zout)s->dimensions[0] != %(_indices)s->dimensions[0]) if (%(_zout)s->dimensions[0] != %(_indices)s->dimensions[0])
{ {
PyErr_SetString(PyExc_NotImplementedError, "somehow _zout got the wrong size.. and I don't know how to resize it."); PyErr_SetString(PyExc_NotImplementedError,
"somehow _zout got the wrong size.. and I don't know how to resize it.");
%(fail)s; %(fail)s;
} }
{ //makes it compile even though labels jump over variable definitions. { //makes it compile even though labels jump over variable definitions.
const npy_intp nnz = %(_indices)s->dimensions[0]; const npy_intp nnz = %(_indices)s->dimensions[0];
const npy_intp N = %(_indptr)s->dimensions[0]-1; //TODO: error checking with this //TODO: error checking with this
const npy_intp N = %(_indptr)s->dimensions[0]-1;
const dtype_%(_data)s * const __restrict__ data = (dtype_%(_data)s*)%(_data)s->data; const dtype_%(_data)s * const __restrict__ data = (dtype_%(_data)s*)%(_data)s->data;
const npy_int32 * const __restrict__ indptr = (npy_int32 *)%(_indptr)s->data; const npy_int32 * const __restrict__ indptr = (npy_int32 *)%(_indptr)s->data;
...@@ -550,10 +573,22 @@ class StrucutedAddSVCSR(gof.Op): ...@@ -550,10 +573,22 @@ class StrucutedAddSVCSR(gof.Op):
raise NotImplementedError('Complex types are not supported for b') raise NotImplementedError('Complex types are not supported for b')
return """ return """
if (%(_b)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1"); %(fail)s;} if (%(_b)s->nd != 1) {
if (%(_data)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(data) != 1"); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1");
if (%(_indices)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); %(fail)s;} %(fail)s;
if (%(_indptr)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1"); %(fail)s;} }
if (%(_data)s->nd != 1) {
PyErr_SetString(PyExc_NotImplementedError, "rank(data) != 1");
%(fail)s;
}
if (%(_indices)s->nd != 1) {
PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1");
%(fail)s;
}
if (%(_indptr)s->nd != 1) {
PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1");
%(fail)s;
}
if( %(_indices)s->descr->type_num != PyArray_INT32) { if( %(_indices)s->descr->type_num != PyArray_INT32) {
PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;}
...@@ -563,18 +598,21 @@ class StrucutedAddSVCSR(gof.Op): ...@@ -563,18 +598,21 @@ class StrucutedAddSVCSR(gof.Op):
if (!%(_zout)s) if (!%(_zout)s)
{ {
%(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1, %(_indices)s->dimensions, %(_b)s->descr->type_num); %(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1,
%(_indices)s->dimensions, %(_b)s->descr->type_num);
} }
if (%(_zout)s->dimensions[0] != %(_indices)s->dimensions[0]) if (%(_zout)s->dimensions[0] != %(_indices)s->dimensions[0])
{ {
PyErr_SetString(PyExc_NotImplementedError, "somehow _zout got the wrong size.. and I don't know how to resize it."); PyErr_SetString(PyExc_NotImplementedError,
"somehow _zout got the wrong size.. and I don't know how to resize it.");
%(fail)s; %(fail)s;
} }
{ //makes it compile even though labels jump over variable definitions. { //makes it compile even though labels jump over variable definitions.
const npy_intp nnz = %(_indices)s->dimensions[0]; const npy_intp nnz = %(_indices)s->dimensions[0];
const npy_intp N = %(_indptr)s->dimensions[0]-1; //TODO: error checking with this //TODO: error checking with this
const npy_intp N = %(_indptr)s->dimensions[0]-1;
const dtype_%(_data)s * const __restrict__ data = (dtype_%(_data)s*)%(_data)s->data; const dtype_%(_data)s * const __restrict__ data = (dtype_%(_data)s*)%(_data)s->data;
const npy_int32 * const __restrict__ indptr = (npy_int32 *)%(_indptr)s->data; const npy_int32 * const __restrict__ indptr = (npy_int32 *)%(_indptr)s->data;
...@@ -671,6 +709,7 @@ class SamplingDot(gof.op.Op): ...@@ -671,6 +709,7 @@ class SamplingDot(gof.op.Op):
if not _is_sparse_variable(p): if not _is_sparse_variable(p):
raise TypeError(p) raise TypeError(p)
#TODO: use it.
dtype_out = scalar.upcast(x.type.dtype, y.type.dtype, p.type.dtype) dtype_out = scalar.upcast(x.type.dtype, y.type.dtype, p.type.dtype)
return gof.Apply(self, [x, y, p], [p.type()]) return gof.Apply(self, [x, y, p], [p.type()])
...@@ -790,23 +829,35 @@ class SamplingDotCsr(gof.Op): ...@@ -790,23 +829,35 @@ class SamplingDotCsr(gof.Op):
[]).dtype_specs()[-1] []).dtype_specs()[-1]
rval = """ rval = """
if (%(x)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;} if (%(x)s->nd != 2) {
if (%(y)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;}
if (%(y)s->nd != 2) {
PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;}
if (%(x)s->descr->type_num != %(typenum_x)s) { if (%(x)s->descr->type_num != %(typenum_x)s) {
PyErr_SetString(PyExc_NotImplementedError, "Invalid type for x"); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError,
"Invalid type for x");
%(fail)s;}
if (%(y)s->descr->type_num != %(typenum_y)s) { if (%(y)s->descr->type_num != %(typenum_y)s) {
PyErr_SetString(PyExc_NotImplementedError, "Invalid type for y"); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError,
"Invalid type for y");
%(fail)s;}
if (%(p_data)s->descr->type_num != %(typenum_p)s) { if (%(p_data)s->descr->type_num != %(typenum_p)s) {
PyErr_SetString(PyExc_NotImplementedError, "Invalid type for pattern"); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError,
"Invalid type for pattern");
%(fail)s;}
if (%(x)s->dimensions[1] != %(y)s->dimensions[1]) if (%(x)s->dimensions[1] != %(y)s->dimensions[1]) {
{PyErr_SetString(PyExc_NotImplementedError, "x's number of columns doesn't match y's rows! Note: sampling_dot is different from dot because y is assumed to be transposed."); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError,
"x's number of columns doesn't match y's rows! Note: sampling_dot is different from dot because y is assumed to be transposed.");
%(fail)s;}
if (%(y)s->dimensions[0] != ((npy_int32 *)%(p_ncols)s->data)[0] || %(x)s->dimensions[0] != (%(p_ptr)s->dimensions[0] - 1)) if (%(y)s->dimensions[0] != ((npy_int32 *)%(p_ncols)s->data)[0] ||
{PyErr_SetString(PyExc_NotImplementedError, "The dimension of the pattern and the output must match"); %(fail)s;} %(x)s->dimensions[0] != (%(p_ptr)s->dimensions[0] - 1))
{PyErr_SetString(PyExc_NotImplementedError,
"The dimension of the pattern and the output must match"); %(fail)s;}
// Allocate output // Allocate output
if (!%(z_data)s if (!%(z_data)s
...@@ -815,7 +866,8 @@ class SamplingDotCsr(gof.Op): ...@@ -815,7 +866,8 @@ class SamplingDotCsr(gof.Op):
{Py_XDECREF(%(z_data)s);} {Py_XDECREF(%(z_data)s);}
npy_intp dims[] = {0}; npy_intp dims[] = {0};
dims[0] = %(p_data)s->dimensions[0]; dims[0] = %(p_data)s->dimensions[0];
%(z_data)s = (PyArrayObject*) PyArray_SimpleNew(1, dims, %(typenum_zd)s); %(z_data)s = (PyArrayObject*) PyArray_SimpleNew(1, dims,
%(typenum_zd)s);
} }
if (!%(z_ind)s if (!%(z_ind)s
|| (%(z_ind)s->dimensions[0] != %(p_ind)s->dimensions[0]) || (%(z_ind)s->dimensions[0] != %(p_ind)s->dimensions[0])
...@@ -823,7 +875,8 @@ class SamplingDotCsr(gof.Op): ...@@ -823,7 +875,8 @@ class SamplingDotCsr(gof.Op):
{Py_XDECREF(%(z_ind)s);} {Py_XDECREF(%(z_ind)s);}
npy_intp dims[] = {0}; npy_intp dims[] = {0};
dims[0] = %(p_ind)s->dimensions[0]; dims[0] = %(p_ind)s->dimensions[0];
%(z_ind)s = (PyArrayObject*) PyArray_SimpleNew(1, dims, %(typenum_zi)s); %(z_ind)s = (PyArrayObject*) PyArray_SimpleNew(1, dims,
%(typenum_zi)s);
} }
if (!%(z_ptr)s if (!%(z_ptr)s
|| (%(z_ptr)s->dimensions[0] != %(p_ptr)s->dimensions[0]) || (%(z_ptr)s->dimensions[0] != %(p_ptr)s->dimensions[0])
...@@ -831,7 +884,8 @@ class SamplingDotCsr(gof.Op): ...@@ -831,7 +884,8 @@ class SamplingDotCsr(gof.Op):
{Py_XDECREF(%(z_ptr)s);} {Py_XDECREF(%(z_ptr)s);}
npy_intp dims[] = {0}; npy_intp dims[] = {0};
dims[0] = %(p_ptr)s->dimensions[0]; dims[0] = %(p_ptr)s->dimensions[0];
%(z_ptr)s = (PyArrayObject*) PyArray_SimpleNew(1, dims, %(typenum_zp)s); %(z_ptr)s = (PyArrayObject*) PyArray_SimpleNew(1, dims,
%(typenum_zp)s);
} }
{ {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论