提交 a924f75b authored 作者: Frederic's avatar Frederic

use PyArray_DESCR instead of var->descr

上级 98f066da
...@@ -59,10 +59,10 @@ class BROKEN_ON_PURPOSE_Add(gof.Op): ...@@ -59,10 +59,10 @@ class BROKEN_ON_PURPOSE_Add(gof.Op):
if (PyArray_NDIM(%(a)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 1"); %(fail)s;} if (PyArray_NDIM(%(a)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 1"); %(fail)s;}
if (PyArray_NDIM(%(b)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1"); %(fail)s;} if (PyArray_NDIM(%(b)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1"); %(fail)s;}
if (%(a)s->descr->type_num != NPY_DOUBLE) if (PyArray_DESCR(%(a)s)->type_num != NPY_DOUBLE)
{PyErr_SetString(PyExc_NotImplementedError, "a dtype not NPY_DOUBLE"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "a dtype not NPY_DOUBLE"); %(fail)s;}
if (%(b)s->descr->type_num != NPY_DOUBLE) if (PyArray_DESCR(%(b)s)->type_num != NPY_DOUBLE)
{PyErr_SetString(PyExc_NotImplementedError, "b's dtype not NPY_DOUBLE"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "b's dtype not NPY_DOUBLE"); %(fail)s;}
if (PyArray_DIMS(%(a)s)[0] != PyArray_DIMS(%(b)s)[0]) if (PyArray_DIMS(%(a)s)[0] != PyArray_DIMS(%(b)s)[0])
...@@ -75,7 +75,7 @@ class BROKEN_ON_PURPOSE_Add(gof.Op): ...@@ -75,7 +75,7 @@ class BROKEN_ON_PURPOSE_Add(gof.Op):
{Py_XDECREF(%(z)s);} {Py_XDECREF(%(z)s);}
npy_intp dims[] = {0}; npy_intp dims[] = {0};
dims[0] = PyArray_DIMS(%(b)s)[0]; dims[0] = PyArray_DIMS(%(b)s)[0];
%(z)s = (PyArrayObject*) PyArray_SimpleNew(1, dims, %(b)s->descr->type_num); %(z)s = (PyArrayObject*) PyArray_SimpleNew(1, dims, PyArray_DESCR(%(b)s)->type_num);
} }
{ {
...@@ -150,13 +150,13 @@ class WeirdBrokenOp(gof.Op): ...@@ -150,13 +150,13 @@ class WeirdBrokenOp(gof.Op):
else: else:
z_code = """ z_code = """
{Py_XDECREF(%(z)s);} {Py_XDECREF(%(z)s);}
%(z)s = (PyArrayObject*) PyArray_SimpleNew(1, PyArray_DIMS(%(a)s), %(a)s->descr->type_num); %(z)s = (PyArrayObject*) PyArray_SimpleNew(1, PyArray_DIMS(%(a)s), PyArray_DESCR(%(a)s)->type_num);
""" """
prep_vars = """ prep_vars = """
//the output array has size M x N //the output array has size M x N
npy_intp M = PyArray_DIMS(%(a)s)[0]; npy_intp M = PyArray_DIMS(%(a)s)[0];
npy_intp Sa = %(a)s->strides[0] / %(a)s->descr->elsize; npy_intp Sa = %(a)s->strides[0] / PyArray_DESCR(%(a)s)->elsize;
npy_intp Sz = %(z)s->strides[0] / %(z)s->descr->elsize; npy_intp Sz = %(z)s->strides[0] / PyArray_DESCR(%(z)s)->elsize;
npy_double * Da = (npy_double*)%(a)s->data; npy_double * Da = (npy_double*)%(a)s->data;
npy_double * Dz = (npy_double*)%(z)s->data; npy_double * Dz = (npy_double*)%(z)s->data;
...@@ -606,10 +606,10 @@ class BrokenCImplementationAdd(gof.Op): ...@@ -606,10 +606,10 @@ class BrokenCImplementationAdd(gof.Op):
if (PyArray_NDIM(%(a)s) != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 2"); %(fail)s;} if (PyArray_NDIM(%(a)s) != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 2"); %(fail)s;}
if (PyArray_NDIM(%(b)s) != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); %(fail)s;} if (PyArray_NDIM(%(b)s) != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); %(fail)s;}
if (%(a)s->descr->type_num != NPY_FLOAT) if (PyArray_DESCR(%(a)s)->type_num != NPY_FLOAT)
{PyErr_SetString(PyExc_NotImplementedError, "a dtype not NPY_FLOAT"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "a dtype not NPY_FLOAT"); %(fail)s;}
if (%(b)s->descr->type_num != NPY_FLOAT) if (PyArray_DESCR(%(b)s)->type_num != NPY_FLOAT)
{PyErr_SetString(PyExc_NotImplementedError, "b's dtype not NPY_FLOAT"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "b's dtype not NPY_FLOAT"); %(fail)s;}
if (PyArray_DIMS(%(a)s)[0] != PyArray_DIMS(%(a)s)[1]) if (PyArray_DIMS(%(a)s)[0] != PyArray_DIMS(%(a)s)[1])
...@@ -643,7 +643,7 @@ class BrokenCImplementationAdd(gof.Op): ...@@ -643,7 +643,7 @@ class BrokenCImplementationAdd(gof.Op):
npy_intp dims[] = {0, 0}; npy_intp dims[] = {0, 0};
dims[0] = PyArray_DIMS(%(b)s)[0]; dims[0] = PyArray_DIMS(%(b)s)[0];
dims[1] = PyArray_DIMS(%(b)s)[1]; dims[1] = PyArray_DIMS(%(b)s)[1];
%(z)s = (PyArrayObject*) PyArray_SimpleNew(2, dims, %(b)s->descr->type_num); %(z)s = (PyArrayObject*) PyArray_SimpleNew(2, dims, PyArray_DESCR(%(b)s)->type_num);
} }
// Let us assume that %(z)s is c_contiguous // Let us assume that %(z)s is c_contiguous
......
...@@ -123,7 +123,7 @@ class GpuDot22Scalar(GpuOp): ...@@ -123,7 +123,7 @@ class GpuDot22Scalar(GpuOp):
fail = sub['fail'] fail = sub['fail']
return """ return """
#define REAL float #define REAL float
float %(name)s_a = (%(a)s->descr->type_num == NPY_FLOAT) float %(name)s_a = (PyArray_TYPE(%(a)s) == NPY_FLOAT)
? (REAL)(((float*)%(a)s->data)[0]) ? (REAL)(((float*)%(a)s->data)[0])
: (REAL)(((double*)%(a)s->data)[0]); : (REAL)(((double*)%(a)s->data)[0]);
#undef REAL #undef REAL
...@@ -231,11 +231,11 @@ class GpuGemm(GpuOp): ...@@ -231,11 +231,11 @@ class GpuGemm(GpuOp):
print >> sio, """ print >> sio, """
#define REAL float #define REAL float
float %(name)s_a = (%(a)s->descr->type_num == NPY_FLOAT) float %(name)s_a = (PyArray_TYPE(%(a)s) == NPY_FLOAT)
? (REAL)(((float*)%(a)s->data)[0]) ? (REAL)(((float*)%(a)s->data)[0])
: (REAL)(((double*)%(a)s->data)[0]); : (REAL)(((double*)%(a)s->data)[0]);
float %(name)s_b = (%(b)s->descr->type_num == NPY_FLOAT) ? float %(name)s_b = (PyArray_TYPE(%(b)s) == NPY_FLOAT) ?
(REAL)(((float*)%(b)s->data)[0]) (REAL)(((float*)%(b)s->data)[0])
: (REAL)(((double*)%(b)s->data)[0]); : (REAL)(((double*)%(b)s->data)[0]);
#undef REAL #undef REAL
......
...@@ -153,7 +153,7 @@ class CURAND_Base(GpuOp): ...@@ -153,7 +153,7 @@ class CURAND_Base(GpuOp):
%(ndim)s, %(size)s->dimensions[0]); %(ndim)s, %(size)s->dimensions[0]);
%(fail)s %(fail)s
} }
if (%(size)s->descr->type_num != NPY_INT32) if (PyArray_DESCR(%(size)s)->type_num != NPY_INT32)
{ {
PyErr_SetString(PyExc_ValueError, "size must be int32"); PyErr_SetString(PyExc_ValueError, "size must be int32");
%(fail)s %(fail)s
......
...@@ -272,7 +272,7 @@ class mrg_uniform(mrg_uniform_base): ...@@ -272,7 +272,7 @@ class mrg_uniform(mrg_uniform_base):
%(ndim)s, int(PyArray_DIMS(%(size)s)[0])); %(ndim)s, int(PyArray_DIMS(%(size)s)[0]));
%(fail)s %(fail)s
} }
if (%(size)s->descr->type_num != NPY_INT32) if (PyArray_DESCR(%(size)s)->type_num != NPY_INT32)
{ {
PyErr_SetString(PyExc_ValueError, "size must be int32"); PyErr_SetString(PyExc_ValueError, "size must be int32");
%(fail)s %(fail)s
...@@ -306,7 +306,7 @@ class mrg_uniform(mrg_uniform_base): ...@@ -306,7 +306,7 @@ class mrg_uniform(mrg_uniform_base):
PyErr_Format(PyExc_ValueError, "rstate must have 6 columns"); PyErr_Format(PyExc_ValueError, "rstate must have 6 columns");
%(fail)s %(fail)s
} }
if (%(o_rstate)s->descr->type_num != NPY_INT32) if (PyArray_DESCR(%(o_rstate)s)->type_num != NPY_INT32)
{ {
PyErr_SetString(PyExc_ValueError, "rstate must be int32"); PyErr_SetString(PyExc_ValueError, "rstate must be int32");
%(fail)s %(fail)s
...@@ -514,7 +514,7 @@ class GPU_mrg_uniform(mrg_uniform_base, GpuOp): ...@@ -514,7 +514,7 @@ class GPU_mrg_uniform(mrg_uniform_base, GpuOp):
%(ndim)s, PyArray_DIMS(%(size)s)[0]); %(ndim)s, PyArray_DIMS(%(size)s)[0]);
%(fail)s %(fail)s
} }
if (%(size)s->descr->type_num != NPY_INT32) if (PyArray_DESCR(%(size)s)->type_num != NPY_INT32)
{ {
PyErr_SetString(PyExc_ValueError, "size must be int32"); PyErr_SetString(PyExc_ValueError, "size must be int32");
%(fail)s %(fail)s
......
...@@ -3011,10 +3011,10 @@ class StructuredDotGradCSC(gof.Op): ...@@ -3011,10 +3011,10 @@ class StructuredDotGradCSC(gof.Op):
if (PyArray_NDIM(%(_indices)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); %(fail)s;} if (PyArray_NDIM(%(_indices)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); %(fail)s;}
if (PyArray_NDIM(%(_indptr)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1"); %(fail)s;} if (PyArray_NDIM(%(_indptr)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1"); %(fail)s;}
if( %(_indices)s->descr->type_num != NPY_INT32) { if( PyArray_DESCR(%(_indices)s)->type_num != NPY_INT32) {
PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;}
if( %(_indptr)s->descr->type_num != NPY_INT32) if( PyArray_DESCR(%(_indptr)s)->type_num != NPY_INT32)
{PyErr_SetString(PyExc_NotImplementedError, "D"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "D"); %(fail)s;}
if( PyArray_DIMS(%(_d)s)[1] != PyArray_DIMS(%(_g)s)[1]) if( PyArray_DIMS(%(_d)s)[1] != PyArray_DIMS(%(_g)s)[1])
...@@ -3024,18 +3024,18 @@ class StructuredDotGradCSC(gof.Op): ...@@ -3024,18 +3024,18 @@ class StructuredDotGradCSC(gof.Op):
|| (PyArray_DIMS(%(_zout)s)[0] != PyArray_DIMS(%(_indices)s)[0])) || (PyArray_DIMS(%(_zout)s)[0] != PyArray_DIMS(%(_indices)s)[0]))
{ {
Py_XDECREF(%(_zout)s); Py_XDECREF(%(_zout)s);
%(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1, PyArray_DIMS(%(_indices)s), %(_g)s->descr->type_num); %(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1, PyArray_DIMS(%(_indices)s), PyArray_DESCR(%(_g)s)->type_num);
} }
{ //makes it compile even though labels jump over variable definitions. { //makes it compile even though labels jump over variable definitions.
npy_intp nnz = PyArray_DIMS(%(_indices)s)[0]; npy_intp nnz = PyArray_DIMS(%(_indices)s)[0];
npy_intp N = PyArray_DIMS(%(_indptr)s)[0]-1; //TODO: error checking with this npy_intp N = PyArray_DIMS(%(_indptr)s)[0]-1; //TODO: error checking with this
npy_intp Sindices = %(_indices)s->strides[0]/%(_indices)s->descr->elsize; npy_intp Sindices = %(_indices)s->strides[0]/PyArray_DESCR(%(_indices)s)->elsize;
npy_intp Sindptr = %(_indptr)s->strides[0]/%(_indptr)s->descr->elsize; npy_intp Sindptr = %(_indptr)s->strides[0]/PyArray_DESCR(%(_indptr)s)->elsize;
const npy_intp Sd1 = %(_d)s->strides[1]/%(_d)s->descr->elsize; const npy_intp Sd1 = %(_d)s->strides[1]/PyArray_DESCR(%(_d)s)->elsize;
const npy_intp Sg1 = %(_g)s->strides[1]/%(_g)s->descr->elsize; const npy_intp Sg1 = %(_g)s->strides[1]/PyArray_DESCR(%(_g)s)->elsize;
const npy_intp K = PyArray_DIMS(%(_d)s)[1]; const npy_intp K = PyArray_DIMS(%(_d)s)[1];
...@@ -3147,10 +3147,10 @@ class StructuredDotGradCSR(gof.Op): ...@@ -3147,10 +3147,10 @@ class StructuredDotGradCSR(gof.Op):
if (PyArray_NDIM(%(_indices)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); %(fail)s;} if (PyArray_NDIM(%(_indices)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); %(fail)s;}
if (PyArray_NDIM(%(_indptr)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1"); %(fail)s;} if (PyArray_NDIM(%(_indptr)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1"); %(fail)s;}
if( %(_indices)s->descr->type_num != NPY_INT32) { if( PyArray_DESCR(%(_indices)s)->type_num != NPY_INT32) {
PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;} PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;}
if( %(_indptr)s->descr->type_num != NPY_INT32) if( PyArray_DESCR(%(_indptr)s)->type_num != NPY_INT32)
{PyErr_SetString(PyExc_NotImplementedError, "D"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "D"); %(fail)s;}
if( PyArray_DIMS(%(_d)s)[1] != PyArray_DIMS(%(_g)s)[1]) if( PyArray_DIMS(%(_d)s)[1] != PyArray_DIMS(%(_g)s)[1])
...@@ -3160,7 +3160,7 @@ class StructuredDotGradCSR(gof.Op): ...@@ -3160,7 +3160,7 @@ class StructuredDotGradCSR(gof.Op):
|| (PyArray_DIMS(%(_zout)s)[0] != PyArray_DIMS(%(_indices)s)[0])) || (PyArray_DIMS(%(_zout)s)[0] != PyArray_DIMS(%(_indices)s)[0]))
{ {
Py_XDECREF(%(_zout)s); Py_XDECREF(%(_zout)s);
%(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1, PyArray_DIMS(%(_indices)s), %(_g)s->descr->type_num); %(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1, PyArray_DIMS(%(_indices)s), PyArray_DESCR(%(_g)s)->type_num);
} }
{ //makes it compile even though labels jump over variable definitions. { //makes it compile even though labels jump over variable definitions.
...@@ -3168,11 +3168,11 @@ class StructuredDotGradCSR(gof.Op): ...@@ -3168,11 +3168,11 @@ class StructuredDotGradCSR(gof.Op):
// extract number of rows // extract number of rows
npy_intp N = PyArray_DIMS(%(_indptr)s)[0]-1; //TODO: error checking with this npy_intp N = PyArray_DIMS(%(_indptr)s)[0]-1; //TODO: error checking with this
npy_intp Sindices = %(_indices)s->strides[0]/%(_indices)s->descr->elsize; npy_intp Sindices = %(_indices)s->strides[0]/PyArray_DESCR(%(_indices)s)->elsize;
npy_intp Sindptr = %(_indptr)s->strides[0]/%(_indptr)s->descr->elsize; npy_intp Sindptr = %(_indptr)s->strides[0]/PyArray_DESCR(%(_indptr)s)->elsize;
const npy_intp Sd1 = %(_d)s->strides[1]/%(_d)s->descr->elsize; const npy_intp Sd1 = %(_d)s->strides[1]/PyArray_DESCR(%(_d)s)->elsize;
const npy_intp Sg1 = %(_g)s->strides[1]/%(_g)s->descr->elsize; const npy_intp Sg1 = %(_g)s->strides[1]/PyArray_DESCR(%(_g)s)->elsize;
const npy_intp K = PyArray_DIMS(%(_d)s)[1]; const npy_intp K = PyArray_DIMS(%(_d)s)[1];
......
差异被折叠。
...@@ -4042,10 +4042,10 @@ class Subtensor(Op): ...@@ -4042,10 +4042,10 @@ class Subtensor(Op):
//TODO: give this Op a second output so that this view can be cached //TODO: give this Op a second output so that this view can be cached
//TODO: alternatively, fix the memory leak on failure //TODO: alternatively, fix the memory leak on failure
Py_INCREF(%(x)s->descr); Py_INCREF(PyArray_DESCR(%(x)s));
PyArrayObject * xview = (PyArrayObject*)PyArray_NewFromDescr( PyArrayObject * xview = (PyArrayObject*)PyArray_NewFromDescr(
&PyArray_Type, &PyArray_Type,
%(x)s->descr, PyArray_DESCR(%(x)s),
%(view_ndim)s, %(view_ndim)s,
PyArray_DIMS(%(x)s), PyArray_DIMS(%(x)s),
PyArray_STRIDES(%(x)s), PyArray_STRIDES(%(x)s),
......
...@@ -493,8 +493,8 @@ class GemmRelated(Op): ...@@ -493,8 +493,8 @@ class GemmRelated(Op):
declare_NS = """ declare_NS = """
int unit = 0; int unit = 0;
int type_num = %(_x)s->descr->type_num; int type_num = PyArray_DESCR(%(_x)s)->type_num;
int type_size = %(_x)s->descr->elsize; // in bytes int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes
npy_intp* Nx = PyArray_DIMS(%(_x)s); npy_intp* Nx = PyArray_DIMS(%(_x)s);
npy_intp* Ny = PyArray_DIMS(%(_y)s); npy_intp* Ny = PyArray_DIMS(%(_y)s);
...@@ -529,31 +529,31 @@ class GemmRelated(Op): ...@@ -529,31 +529,31 @@ class GemmRelated(Op):
} }
""" """
check_xyz_double_or_float = """ check_xyz_double_or_float = """
if ((%(_x)s->descr->type_num != NPY_DOUBLE) if ((PyArray_DESCR(%(_x)s)->type_num != NPY_DOUBLE)
&& (%(_x)s->descr->type_num != NPY_FLOAT)) && (PyArray_DESCR(%(_x)s)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}
if ((%(_y)s->descr->type_num != NPY_DOUBLE) if ((PyArray_DESCR(%(_y)s)->type_num != NPY_DOUBLE)
&& (%(_y)s->descr->type_num != NPY_FLOAT)) && (PyArray_DESCR(%(_y)s)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}
if ((%(_zout)s->descr->type_num != NPY_DOUBLE) if ((PyArray_DESCR(%(_zout)s)->type_num != NPY_DOUBLE)
&& (%(_zout)s->descr->type_num != NPY_FLOAT)) && (PyArray_DESCR(%(_zout)s)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}
if ((%(_x)s->descr->type_num != %(_y)s->descr->type_num) if ((PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_y)s)->type_num)
||(%(_x)s->descr->type_num != %(_zout)s->descr->type_num)) ||(PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_zout)s)->type_num))
{ PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); %(fail)s; } { PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); %(fail)s; }
""" """
#it is not necessary that a or b have the same type as x,y,z #it is not necessary that a or b have the same type as x,y,z
check_ab_double_or_float = """ check_ab_double_or_float = """
if ((%(_a)s->descr->type_num != NPY_DOUBLE) if ((PyArray_DESCR(%(_a)s)->type_num != NPY_DOUBLE)
&& (%(_a)s->descr->type_num != NPY_FLOAT)) && (PyArray_DESCR(%(_a)s)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(a) is not double or float"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "type(a) is not double or float"); %(fail)s;}
if ((%(_b)s->descr->type_num != NPY_DOUBLE) if ((PyArray_DESCR(%(_b)s)->type_num != NPY_DOUBLE)
&& (%(_b)s->descr->type_num != NPY_FLOAT)) && (PyArray_DESCR(%(_b)s)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;}
""" """
...@@ -919,7 +919,7 @@ class Gemm(GemmRelated): ...@@ -919,7 +919,7 @@ class Gemm(GemmRelated):
Nz = PyArray_DIMS(%(_zout)s); Nz = PyArray_DIMS(%(_zout)s);
Sz = PyArray_STRIDES(%(_zout)s); Sz = PyArray_STRIDES(%(_zout)s);
if (%(_zout)s->descr->type_num == NPY_FLOAT) if (PyArray_DESCR(%(_zout)s)->type_num == NPY_FLOAT)
{ {
float * zoutdata = (float*)PyArray_DATA(%(_zout)s); float * zoutdata = (float*)PyArray_DATA(%(_zout)s);
int zoi = Sz[0] / sizeof(float); int zoi = Sz[0] / sizeof(float);
...@@ -935,7 +935,7 @@ class Gemm(GemmRelated): ...@@ -935,7 +935,7 @@ class Gemm(GemmRelated):
} }
} }
} }
else if (%(_zout)s->descr->type_num == NPY_DOUBLE) else if (PyArray_DESCR(%(_zout)s)->type_num == NPY_DOUBLE)
{ {
double * zoutdata = (double*) PyArray_DATA(%(_zout)s); double * zoutdata = (double*) PyArray_DATA(%(_zout)s);
int zoi = Sz[0] / sizeof(double); int zoi = Sz[0] / sizeof(double);
...@@ -961,20 +961,20 @@ class Gemm(GemmRelated): ...@@ -961,20 +961,20 @@ class Gemm(GemmRelated):
case_float_ab_constants = """ case_float_ab_constants = """
#define REAL float #define REAL float
float a = (%(_a)s->descr->type_num == NPY_FLOAT) float a = (PyArray_DESCR(%(_a)s)->type_num == NPY_FLOAT)
? (REAL)(((float*)PyArray_DATA(%(_a)s))[0]) ? (REAL)(((float*)PyArray_DATA(%(_a)s))[0])
: (REAL)(((double*)PyArray_DATA(%(_a)s))[0]); : (REAL)(((double*)PyArray_DATA(%(_a)s))[0]);
float b = (%(_b)s->descr->type_num == NPY_FLOAT) ? float b = (PyArray_DESCR(%(_b)s)->type_num == NPY_FLOAT) ?
(REAL)(((float*)PyArray_DATA(%(_b)s))[0]) (REAL)(((float*)PyArray_DATA(%(_b)s))[0])
: (REAL)(((double*)PyArray_DATA(%(_b)s))[0]); : (REAL)(((double*)PyArray_DATA(%(_b)s))[0]);
#undef REAL #undef REAL
""" """
case_double_ab_constants = """ case_double_ab_constants = """
#define REAL double #define REAL double
double a = (%(_a)s->descr->type_num == NPY_FLOAT) double a = (PyArray_DESCR(%(_a)s)->type_num == NPY_FLOAT)
? (REAL)(((float*)PyArray_DATA(%(_a)s))[0]) ? (REAL)(((float*)PyArray_DATA(%(_a)s))[0])
: (REAL)(((double*)PyArray_DATA(%(_a)s))[0]); : (REAL)(((double*)PyArray_DATA(%(_a)s))[0]);
double b = (%(_b)s->descr->type_num == NPY_FLOAT) ? double b = (PyArray_DESCR(%(_b)s)->type_num == NPY_FLOAT) ?
(REAL)(((float*)PyArray_DATA(%(_b)s))[0]) (REAL)(((float*)PyArray_DATA(%(_b)s))[0])
: (REAL)(((double*)PyArray_DATA(%(_b)s))[0]); : (REAL)(((double*)PyArray_DATA(%(_b)s))[0]);
#undef REAL #undef REAL
...@@ -1753,15 +1753,15 @@ class Dot22Scalar(GemmRelated): ...@@ -1753,15 +1753,15 @@ class Dot22Scalar(GemmRelated):
setup_z_Nz_Sz = Dot22.setup_z_Nz_Sz setup_z_Nz_Sz = Dot22.setup_z_Nz_Sz
check_ab_double_or_float = """ check_ab_double_or_float = """
if ((%(_a)s->descr->type_num != NPY_DOUBLE) if ((PyArray_DESCR(%(_a)s)->type_num != NPY_DOUBLE)
&& (%(_a)s->descr->type_num != NPY_FLOAT)) && (PyArray_DESCR(%(_a)s)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, {PyErr_SetString(PyExc_NotImplementedError,
"type(a) is not double or float"); %(fail)s;} "type(a) is not double or float"); %(fail)s;}
""" """
case_float_ab_constants = """ case_float_ab_constants = """
#define REAL float #define REAL float
float a = (%(_a)s->descr->type_num == NPY_FLOAT) float a = (PyArray_DESCR(%(_a)s)->type_num == NPY_FLOAT)
? (REAL)(((float*)PyArray_DATA(%(_a)s))[0]) ? (REAL)(((float*)PyArray_DATA(%(_a)s))[0])
: (REAL)(((double*)PyArray_DATA(%(_a)s))[0]); : (REAL)(((double*)PyArray_DATA(%(_a)s))[0]);
#undef REAL #undef REAL
...@@ -1770,7 +1770,7 @@ class Dot22Scalar(GemmRelated): ...@@ -1770,7 +1770,7 @@ class Dot22Scalar(GemmRelated):
case_double_ab_constants = """ case_double_ab_constants = """
#define REAL double #define REAL double
double a = (%(_a)s->descr->type_num == NPY_FLOAT) double a = (PyArray_DESCR(%(_a)s)->type_num == NPY_FLOAT)
? (REAL)(((float*)PyArray_DATA(%(_a)s))[0]) ? (REAL)(((float*)PyArray_DATA(%(_a)s))[0])
: (REAL)(((double*)PyArray_DATA(%(_a)s))[0]); : (REAL)(((double*)PyArray_DATA(%(_a)s))[0]);
#undef REAL #undef REAL
......
...@@ -42,9 +42,9 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -42,9 +42,9 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
if (PyArray_NDIM(%(a)s) != 0) if (PyArray_NDIM(%(a)s) != 0)
{PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 0"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 0"); %(fail)s;}
if (%(A)s->descr->type_num != %(x)s->descr->type_num) if (PyArray_DESCR(%(A)s)->type_num != PyArray_DESCR(%(x)s)->type_num)
{ PyErr_SetString(PyExc_TypeError, "A vs. x"); %(fail)s; } { PyErr_SetString(PyExc_TypeError, "A vs. x"); %(fail)s; }
if (%(A)s->descr->type_num != %(y)s->descr->type_num) if (PyArray_DESCR(%(A)s)->type_num != PyArray_DESCR(%(y)s)->type_num)
{ PyErr_SetString(PyExc_TypeError, "A vs. y"); %(fail)s; } { PyErr_SetString(PyExc_TypeError, "A vs. y"); %(fail)s; }
if (PyArray_DIMS(%(A)s)[0] != PyArray_DIMS(%(x)s)[0]) if (PyArray_DIMS(%(A)s)[0] != PyArray_DIMS(%(x)s)[0])
...@@ -60,8 +60,8 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -60,8 +60,8 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
%(fail)s; %(fail)s;
} }
if (%(A)s->descr->type_num == NPY_DOUBLE) { elemsize = 8; } if (PyArray_DESCR(%(A)s)->type_num == NPY_DOUBLE) { elemsize = 8; }
else if (%(A)s->descr->type_num == NPY_FLOAT) { elemsize = 4;} else if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT) { elemsize = 4;}
else else
{ {
PyErr_SetString(PyExc_NotImplementedError, "complex CGer"); PyErr_SetString(PyExc_NotImplementedError, "complex CGer");
...@@ -101,7 +101,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -101,7 +101,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
PyErr_SetString(PyExc_AssertionError, "%(Z)s != %(A)s"); PyErr_SetString(PyExc_AssertionError, "%(Z)s != %(A)s");
%(fail)s %(fail)s
} }
if (%(Z)s->descr->type_num == NPY_FLOAT) if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT)
{ {
float * zoutdata = (float*)PyArray_DATA(%(Z)s); float * zoutdata = (float*)PyArray_DATA(%(Z)s);
const float * zdata = (float*)PyArray_DATA(%(A)s); const float * zdata = (float*)PyArray_DATA(%(A)s);
...@@ -117,7 +117,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -117,7 +117,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
} }
} }
} }
else if (%(Z)s->descr->type_num == NPY_DOUBLE) else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE)
{ {
double * zoutdata = (double*) PyArray_DATA(%(Z)s); double * zoutdata = (double*) PyArray_DATA(%(Z)s);
const double * zdata = (double*)PyArray_DATA(%(A)s); const double * zdata = (double*)PyArray_DATA(%(A)s);
...@@ -178,7 +178,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -178,7 +178,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
if (PyArray_STRIDES(%(Z)s)[0] == elemsize) if (PyArray_STRIDES(%(Z)s)[0] == elemsize)
{ {
if (%(Z)s->descr->type_num == NPY_FLOAT) if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT)
{ {
//fprintf(stderr, "A\\n"); //fprintf(stderr, "A\\n");
float alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0]; float alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0];
...@@ -187,7 +187,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -187,7 +187,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
(float*)y_data, &Sy, (float*)y_data, &Sy,
(float*)(PyArray_DATA(%(Z)s)), &Sz1); (float*)(PyArray_DATA(%(Z)s)), &Sz1);
} }
else if (%(Z)s->descr->type_num == NPY_DOUBLE) else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE)
{ {
double alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0]; double alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0];
dger_(&Nz0, &Nz1, &alpha, dger_(&Nz0, &Nz1, &alpha,
...@@ -203,7 +203,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -203,7 +203,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
} }
else if (PyArray_STRIDES(%(Z)s)[1] == elemsize) else if (PyArray_STRIDES(%(Z)s)[1] == elemsize)
{ {
if (%(Z)s->descr->type_num == NPY_FLOAT) if (PyArray_DESCR(%(Z)s)->type_num == NPY_FLOAT)
{ {
//fprintf(stderr, "B %%i %%i %%i %%i\\n", Nz0, Nz1, Sz0, Sz1); //fprintf(stderr, "B %%i %%i %%i %%i\\n", Nz0, Nz1, Sz0, Sz1);
float alpha = ((dtype_%(a)s*)(PyArray_DATA(%(a)s)))[0]; float alpha = ((dtype_%(a)s*)(PyArray_DATA(%(a)s)))[0];
...@@ -214,7 +214,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -214,7 +214,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
(float*)x_data, &Sx, (float*)x_data, &Sx,
(float*)(PyArray_DATA(%(Z)s)), &Sz0); (float*)(PyArray_DATA(%(Z)s)), &Sz0);
} }
else if (%(Z)s->descr->type_num == NPY_DOUBLE) else if (PyArray_DESCR(%(Z)s)->type_num == NPY_DOUBLE)
{ {
double alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0]; double alpha = ((dtype_%(a)s*)PyArray_DATA(%(a)s))[0];
dger_(&Nz1, &Nz0, &alpha, dger_(&Nz1, &Nz0, &alpha,
...@@ -316,9 +316,9 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -316,9 +316,9 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
%(fail)s; %(fail)s;
} }
if (%(aa)s->descr->type_num != %(xx)s->descr->type_num) if (PyArray_DESCR(%(aa)s)->type_num != PyArray_DESCR(%(xx)s)->type_num)
{ PyErr_SetString(PyExc_TypeError, "Gemv: aa vs. xx"); %(fail)s; } { PyErr_SetString(PyExc_TypeError, "Gemv: aa vs. xx"); %(fail)s; }
if (%(aa)s->descr->type_num != %(yy)s->descr->type_num) if (PyArray_DESCR(%(aa)s)->type_num != PyArray_DESCR(%(yy)s)->type_num)
{ PyErr_SetString(PyExc_TypeError, "Gemv: aa vs. yy"); %(fail)s; } { PyErr_SetString(PyExc_TypeError, "Gemv: aa vs. yy"); %(fail)s; }
if (PyArray_DIMS(%(xx)s)[0] != PyArray_DIMS(%(aa)s)[0]) if (PyArray_DIMS(%(xx)s)[0] != PyArray_DIMS(%(aa)s)[0])
...@@ -334,8 +334,8 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -334,8 +334,8 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
%(fail)s; %(fail)s;
} }
if (%(aa)s->descr->type_num == NPY_DOUBLE) { elemsize = 8; } if (PyArray_DESCR(%(aa)s)->type_num == NPY_DOUBLE) { elemsize = 8; }
else if (%(aa)s->descr->type_num == NPY_FLOAT) { elemsize = 4;} else if (PyArray_DESCR(%(aa)s)->type_num == NPY_FLOAT) { elemsize = 4;}
else { else {
PyErr_SetString(PyExc_NotImplementedError, "complex Gemv"); PyErr_SetString(PyExc_NotImplementedError, "complex Gemv");
%(fail)s; %(fail)s;
...@@ -365,7 +365,7 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -365,7 +365,7 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
} }
if (dbeta != 0) if (dbeta != 0)
{ {
if (%(zz)s->descr->type_num == NPY_FLOAT) if (PyArray_DESCR(%(zz)s)->type_num == NPY_FLOAT)
{ {
float * zoutdata = (float*)PyArray_DATA(%(zz)s); float * zoutdata = (float*)PyArray_DATA(%(zz)s);
const float * zdata = (float*)PyArray_DATA(%(aa)s); const float * zdata = (float*)PyArray_DATA(%(aa)s);
...@@ -376,7 +376,7 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -376,7 +376,7 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
zoutdata[Zi*i] = fbeta * zdata[Ai*i]; zoutdata[Zi*i] = fbeta * zdata[Ai*i];
} }
} }
else if (%(xx)s->descr->type_num == NPY_DOUBLE) else if (PyArray_DESCR(%(xx)s)->type_num == NPY_DOUBLE)
{ {
double * zoutdata = (double*) PyArray_DATA(%(zz)s); double * zoutdata = (double*) PyArray_DATA(%(zz)s);
const double * zdata = (double*)PyArray_DATA(%(aa)s); const double * zdata = (double*)PyArray_DATA(%(aa)s);
...@@ -460,7 +460,7 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -460,7 +460,7 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
if (PyArray_STRIDES(%(xx)s)[0] == elemsize) if (PyArray_STRIDES(%(xx)s)[0] == elemsize)
{ {
if (%(xx)s->descr->type_num == NPY_FLOAT) if (PyArray_DESCR(%(xx)s)->type_num == NPY_FLOAT)
{ {
//fprintf(stderr, "A\\n"); //fprintf(stderr, "A\\n");
float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
...@@ -471,7 +471,7 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -471,7 +471,7 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
&fbeta, &fbeta,
(float*)zz_data, &Sz); (float*)zz_data, &Sz);
} }
else if (%(xx)s->descr->type_num == NPY_DOUBLE) else if (PyArray_DESCR(%(xx)s)->type_num == NPY_DOUBLE)
{ {
double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
dgemv_(&NOTRANS, &Nx0, &Nx1, dgemv_(&NOTRANS, &Nx0, &Nx1,
...@@ -490,7 +490,7 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -490,7 +490,7 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
} }
else if (PyArray_STRIDES(%(xx)s)[1] == elemsize) else if (PyArray_STRIDES(%(xx)s)[1] == elemsize)
{ {
if (%(xx)s->descr->type_num == NPY_FLOAT) if (PyArray_DESCR(%(xx)s)->type_num == NPY_FLOAT)
{ {
//fprintf(stderr, "B %%i %%i %%i %%i\\n", //fprintf(stderr, "B %%i %%i %%i %%i\\n",
// Nz0, Nz1, Sz0, Sz1); // Nz0, Nz1, Sz0, Sz1);
...@@ -504,7 +504,7 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -504,7 +504,7 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
&fbeta, &fbeta,
(float*)zz_data, &Sz); (float*)zz_data, &Sz);
} }
else if (%(xx)s->descr->type_num == NPY_DOUBLE) else if (PyArray_DESCR(%(xx)s)->type_num == NPY_DOUBLE)
{ {
double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
dgemv_(&TRANS, &Nx1, &Nx0, dgemv_(&TRANS, &Nx1, &Nx0,
......
...@@ -794,8 +794,8 @@ def ____gemm_code(check_ab, a_init, b_init): ...@@ -794,8 +794,8 @@ def ____gemm_code(check_ab, a_init, b_init):
return """ return """
const char * error_string = NULL; const char * error_string = NULL;
int type_num = _x->descr->type_num; int type_num = PyArray_DESCR(_x)->type_num;
int type_size = _x->descr->elsize; // in bytes int type_size = PyArray_DESCR(_x)->elsize; // in bytes
npy_intp* Nx = PyArray_DIMS(_x); npy_intp* Nx = PyArray_DIMS(_x);
npy_intp* Ny = PyArray_DIMS(_y); npy_intp* Ny = PyArray_DIMS(_y);
...@@ -815,20 +815,20 @@ def ____gemm_code(check_ab, a_init, b_init): ...@@ -815,20 +815,20 @@ def ____gemm_code(check_ab, a_init, b_init):
%(check_ab)s %(check_ab)s
if ((_x->descr->type_num != NPY_DOUBLE) if ((PyArray_DESCR(_x)->type_num != NPY_DOUBLE)
&& (_x->descr->type_num != NPY_FLOAT)) && (PyArray_DESCR(_x)->type_num != NPY_FLOAT))
goto _dot_execute_fallback; goto _dot_execute_fallback;
if ((_y->descr->type_num != NPY_DOUBLE) if ((PyArray_DESCR(_y)->type_num != NPY_DOUBLE)
&& (_y->descr->type_num != NPY_FLOAT)) && (PyArray_DESCR(_y)->type_num != NPY_FLOAT))
goto _dot_execute_fallback; goto _dot_execute_fallback;
if ((_y->descr->type_num != NPY_DOUBLE) if ((PyArray_DESCR(_y)->type_num != NPY_DOUBLE)
&& (_y->descr->type_num != NPY_FLOAT)) && (PyArray_DESCR(_y)->type_num != NPY_FLOAT))
goto _dot_execute_fallback; goto _dot_execute_fallback;
if ((_x->descr->type_num != _y->descr->type_num) if ((PyArray_DESCR(_x)->type_num != PyArray_DESCR(_y)->type_num)
||(_x->descr->type_num != _z->descr->type_num)) ||(PyArray_DESCR(_x)->type_num != PyArray_DESCR(_z)->type_num))
goto _dot_execute_fallback; goto _dot_execute_fallback;
......
...@@ -311,7 +311,7 @@ class DimShuffle(Op): ...@@ -311,7 +311,7 @@ class DimShuffle(Op):
str(nd_out) + str(nd_out) +
'-1] == 0) strides[' + '-1] == 0) strides[' +
str(nd_out) + str(nd_out) +
'-1] = %(basename)s->descr->elsize' '-1] = PyArray_DESCR(%(basename)s)->elsize'
) )
for i in xrange(nd_out - 2, -1, -1): for i in xrange(nd_out - 2, -1, -1):
strides_statements.append( strides_statements.append(
...@@ -333,7 +333,13 @@ class DimShuffle(Op): ...@@ -333,7 +333,13 @@ class DimShuffle(Op):
#recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED #recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED
'PyArray_UpdateFlags(%(res)s, NPY_ARRAY_UPDATE_ALL)', 'PyArray_UpdateFlags(%(res)s, NPY_ARRAY_UPDATE_ALL)',
#we are making a view in both inplace and non-inplace cases #we are making a view in both inplace and non-inplace cases
'PyArray_BASE(%(res)s) = (PyObject*)%(basename)s', """
#if NPY_VERSION < 0x01000009
PyArray_BASE(%(res)s) = (PyObject*)%(basename)s;
#else
PyArray_SetBaseObject(%(res)s, (PyObject*)%(basename)s);
#endif
"""
'}'] '}']
full_code = statements(check_input_nd full_code = statements(check_input_nd
......
...@@ -296,7 +296,7 @@ class Conv3D(theano.Op): ...@@ -296,7 +296,7 @@ class Conv3D(theano.Op):
PyArray_DIMS(%(H)s)[3]!=dims[3] || PyArray_DIMS(%(H)s)[3]!=dims[3] ||
PyArray_DIMS(%(H)s)[4]!=dims[4]){ PyArray_DIMS(%(H)s)[4]!=dims[4]){
Py_XDECREF(%(H)s); Py_XDECREF(%(H)s);
%(H)s = (PyArrayObject *) PyArray_SimpleNew(5, dims, %(V)s->descr->type_num); %(H)s = (PyArrayObject *) PyArray_SimpleNew(5, dims, PyArray_DESCR(%(V)s)->type_num);
if (!(%(H)s)) { if (!(%(H)s)) {
PyErr_Format(PyExc_MemoryError,"Conv3D: Could not allocate output."); PyErr_Format(PyExc_MemoryError,"Conv3D: Could not allocate output.");
%(fail)s %(fail)s
......
...@@ -217,7 +217,7 @@ class ConvGrad3D(theano.Op): ...@@ -217,7 +217,7 @@ class ConvGrad3D(theano.Op):
PyArray_DIMS(%(dCdW)s)[3]!=dims[3] || PyArray_DIMS(%(dCdW)s)[3]!=dims[3] ||
PyArray_DIMS(%(dCdW)s)[4]!=dims[4] ){ PyArray_DIMS(%(dCdW)s)[4]!=dims[4] ){
Py_XDECREF(%(dCdW)s); Py_XDECREF(%(dCdW)s);
%(dCdW)s = (PyArrayObject *) PyArray_SimpleNew(5, dims, %(V)s->descr->type_num); %(dCdW)s = (PyArrayObject *) PyArray_SimpleNew(5, dims, PyArray_DESCR(%(V)s)->type_num);
if (!(%(dCdW)s)) { if (!(%(dCdW)s)) {
PyErr_Format(PyExc_MemoryError,"ConvGrad3D: Could not allocate dCdW"); PyErr_Format(PyExc_MemoryError,"ConvGrad3D: Could not allocate dCdW");
......
...@@ -228,7 +228,7 @@ class ConvTransp3D(theano.Op): ...@@ -228,7 +228,7 @@ class ConvTransp3D(theano.Op):
PyArray_DIMS(%(R)s)[4]!=dims[4]) PyArray_DIMS(%(R)s)[4]!=dims[4])
{ {
Py_XDECREF(%(R)s); Py_XDECREF(%(R)s);
%(R)s = (PyArrayObject *) PyArray_SimpleNew(5, dims, %(H)s->descr->type_num); %(R)s = (PyArrayObject *) PyArray_SimpleNew(5, dims, PyArray_DESCR(%(H)s)->type_num);
if (!(%(R)s)) { if (!(%(R)s)) {
PyErr_Format(PyExc_MemoryError, "ConvTransp3D: could not allocate R"); PyErr_Format(PyExc_MemoryError, "ConvTransp3D: could not allocate R");
%(fail)s %(fail)s
......
...@@ -117,14 +117,14 @@ class SoftmaxWithBias(gof.Op): ...@@ -117,14 +117,14 @@ class SoftmaxWithBias(gof.Op):
PyErr_SetString(PyExc_ValueError, "b not 1d tensor"); PyErr_SetString(PyExc_ValueError, "b not 1d tensor");
%(fail)s; %(fail)s;
} }
if ((%(x)s->descr->type_num != NPY_DOUBLE) && if ((PyArray_DESCR(%(x)s)->type_num != NPY_DOUBLE) &&
(%(x)s->descr->type_num != NPY_FLOAT)) (PyArray_DESCR(%(x)s)->type_num != NPY_FLOAT))
{ {
PyErr_SetString(PyExc_TypeError, "a not float"); PyErr_SetString(PyExc_TypeError, "a not float");
%(fail)s; %(fail)s;
} }
if ((%(b)s->descr->type_num != NPY_DOUBLE) && if ((PyArray_DESCR(%(b)s)->type_num != NPY_DOUBLE) &&
(%(b)s->descr->type_num != NPY_FLOAT)) (PyArray_DESCR(%(b)s)->type_num != NPY_FLOAT))
{ {
PyErr_SetString(PyExc_TypeError, "b not float"); PyErr_SetString(PyExc_TypeError, "b not float");
%(fail)s; %(fail)s;
...@@ -263,15 +263,15 @@ class SoftmaxGrad(gof.Op): ...@@ -263,15 +263,15 @@ class SoftmaxGrad(gof.Op):
dy, sm = inp dy, sm = inp
dx, = out dx, = out
return ''' return '''
if ((%(dy)s->descr->type_num != NPY_DOUBLE) && if ((PyArray_DESCR(%(dy)s)->type_num != NPY_DOUBLE) &&
(%(dy)s->descr->type_num != NPY_FLOAT)) (PyArray_DESCR(%(dy)s)->type_num != NPY_FLOAT))
{ {
PyErr_SetString(PyExc_TypeError, PyErr_SetString(PyExc_TypeError,
"types should be float or float64"); "types should be float or float64");
%(fail)s; %(fail)s;
} }
if ((%(sm)s->descr->type_num != NPY_DOUBLE) && if ((PyArray_DESCR(%(sm)s)->type_num != NPY_DOUBLE) &&
(%(sm)s->descr->type_num != NPY_FLOAT)) (PyArray_DESCR(%(sm)s)->type_num != NPY_FLOAT))
{ {
PyErr_SetString(PyExc_TypeError, PyErr_SetString(PyExc_TypeError,
"types should be float or float64"); "types should be float or float64");
...@@ -778,10 +778,10 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op): ...@@ -778,10 +778,10 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
PyErr_SetString(PyExc_ValueError, "y_idx not 1d tensor"); PyErr_SetString(PyExc_ValueError, "y_idx not 1d tensor");
%(fail)s; %(fail)s;
} }
if ((%(y_idx)s->descr->type_num != NPY_INT64) if ((PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT64)
&& (%(y_idx)s->descr->type_num != NPY_INT32) && (PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT32)
&& (%(y_idx)s->descr->type_num != NPY_INT16) && (PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT16)
&& (%(y_idx)s->descr->type_num != NPY_INT8)) && (PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT8))
{ {
PyErr_SetString(PyExc_TypeError, PyErr_SetString(PyExc_TypeError,
"y_idx not int8, int16, int32, or int64"); "y_idx not int8, int16, int32, or int64");
...@@ -914,24 +914,24 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op): ...@@ -914,24 +914,24 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
y_idx_type = node.inputs[2].type.dtype_specs()[1] y_idx_type = node.inputs[2].type.dtype_specs()[1]
return """ return """
if ((%(dnll)s->descr->type_num != NPY_DOUBLE) && if ((PyArray_DESCR(%(dnll)s)->type_num != NPY_DOUBLE) &&
(%(dnll)s->descr->type_num != NPY_FLOAT)) (PyArray_DESCR(%(dnll)s)->type_num != NPY_FLOAT))
{ {
PyErr_SetString(PyExc_TypeError, PyErr_SetString(PyExc_TypeError,
"dnll type should be float32 or float64"); "dnll type should be float32 or float64");
%(fail)s; %(fail)s;
} }
if ((%(sm)s->descr->type_num != NPY_DOUBLE) && if ((PyArray_DESCR(%(sm)s)->type_num != NPY_DOUBLE) &&
(%(sm)s->descr->type_num != NPY_FLOAT)) (PyArray_DESCR(%(sm)s)->type_num != NPY_FLOAT))
{ {
PyErr_SetString(PyExc_TypeError, PyErr_SetString(PyExc_TypeError,
"sm type should be float32 or float64"); "sm type should be float32 or float64");
%(fail)s; %(fail)s;
} }
if ((%(y_idx)s->descr->type_num != NPY_INT64) if ((PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT64)
&& (%(y_idx)s->descr->type_num != NPY_INT32) && (PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT32)
&& (%(y_idx)s->descr->type_num != NPY_INT16) && (PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT16)
&& (%(y_idx)s->descr->type_num != NPY_INT8)) && (PyArray_DESCR(%(y_idx)s)->type_num != NPY_INT8))
{ {
PyErr_SetString(PyExc_TypeError, PyErr_SetString(PyExc_TypeError,
"y_idx not int8, int16, int32, or int64"); "y_idx not int8, int16, int32, or int64");
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论