提交 17f1b0fa authored 作者: Hengjean's avatar Hengjean

Fixed all bugs.

上级 2774de58
...@@ -308,7 +308,10 @@ def get_nothing(r, name, sub): ...@@ -308,7 +308,10 @@ def get_nothing(r, name, sub):
def get_c_declare(r, name, sub): def get_c_declare(r, name, sub):
"""Wrapper around c_declare that declares py_name""" """Wrapper around c_declare that declares py_name"""
if any([c == 'output' or getattr(c.op, 'check_input', True) for (c, _)
class helper:
check_input = False
if any([getattr(getattr(c, 'op', helper), 'check_input', True) for (c, _)
in r.clients]) or (r.owner and getattr(r.owner.op, in r.clients]) or (r.owner and getattr(r.owner.op,
'check_input', True)): 'check_input', True)):
......
...@@ -59,7 +59,14 @@ class MultinomialFromUniform(Op): ...@@ -59,7 +59,14 @@ class MultinomialFromUniform(Op):
def c_code(self, node, name, ins, outs, sub): def c_code(self, node, name, ins, outs, sub):
(pvals, unis) = ins (pvals, unis) = ins
(z,) = outs (z,) = outs
if self.odtype == 'auto':
t = "PyArray_TYPE((PyArrayObject*) py_%(pvals)s)" % locals()
else:
t = theano.scalar.Scalar(self.odtype).dtype_specs()[1]
if t.startswith('theano_complex'):
t = t.replace('theano_complex', 'NPY_COMPLEX')
else:
t = t.upper()
fail = sub['fail'] fail = sub['fail']
return """ return """
if (PyArray_NDIM(%(pvals)s) != 2) if (PyArray_NDIM(%(pvals)s) != 2)
...@@ -87,7 +94,7 @@ class MultinomialFromUniform(Op): ...@@ -87,7 +94,7 @@ class MultinomialFromUniform(Op):
Py_XDECREF(%(z)s); Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) PyArray_ZEROS(2, %(z)s = (PyArrayObject*) PyArray_ZEROS(2,
PyArray_DIMS(%(pvals)s), PyArray_DIMS(%(pvals)s),
type_num_%(z)s, %(t)s,
0); 0);
if (!%(z)s) if (!%(z)s)
{ {
......
...@@ -335,7 +335,7 @@ class Images2Neibs(Op): ...@@ -335,7 +335,7 @@ class Images2Neibs(Op):
%(z)s = (PyArrayObject*) PyArray_EMPTY(2, %(z)s = (PyArrayObject*) PyArray_EMPTY(2,
dims, dims,
type_num_%(ten4)s, PyArray_TYPE((PyArrayObject*) py_%(ten4)s),
0); 0);
if (!%(z)s) if (!%(z)s)
......
...@@ -257,13 +257,13 @@ class Scalar(Type): ...@@ -257,13 +257,13 @@ class Scalar(Type):
def c_declare(self, name, sub, check_input=True): def c_declare(self, name, sub, check_input=True):
if(check_input): if(check_input):
pre = """ pre = """
%(dtype)s %(name)s; typedef %(dtype)s %(name)s_dtype; // Deprecated use dtype_%(name)s instead.
typedef %(dtype)s dtype_%(name)s;
""" % dict(name=name, dtype=self.dtype_specs()[1]) """ % dict(name=name, dtype=self.dtype_specs()[1])
else: else:
pre = "" pre = ""
return pre + """ return pre + """
typedef %(dtype)s %(name)s_dtype; // Deprecated use dtype_%(name)s instead. %(dtype)s %(name)s;
typedef %(dtype)s dtype_%(name)s;
""" % dict(name=name, dtype=self.dtype_specs()[1]) """ % dict(name=name, dtype=self.dtype_specs()[1])
def c_init(self, name, sub): def c_init(self, name, sub):
...@@ -462,7 +462,7 @@ class Scalar(Type): ...@@ -462,7 +462,7 @@ class Scalar(Type):
return ["import_array();"] return ["import_array();"]
def c_code_cache_version(self): def c_code_cache_version(self):
return (12, numpy.__version__) return (13, numpy.__version__)
def get_shape_info(self, obj): def get_shape_info(self, obj):
return obj.itemsize return obj.itemsize
......
...@@ -2425,7 +2425,7 @@ class Alloc(gof.Op): ...@@ -2425,7 +2425,7 @@ class Alloc(gof.Op):
{ {
Py_XDECREF(%(zz)s); Py_XDECREF(%(zz)s);
%(zz)s = (PyArrayObject*) PyArray_SimpleNew(%(ndim)s, %(zz)s = (PyArrayObject*) PyArray_SimpleNew(%(ndim)s,
shape, type_num_%(vv)s); shape, PyArray_TYPE((PyArrayObject*) py_%(vv)s));
if (!%(zz)s) if (!%(zz)s)
{ {
PyErr_SetString(PyExc_MemoryError, "alloc failed"); PyErr_SetString(PyExc_MemoryError, "alloc failed");
......
...@@ -1026,7 +1026,7 @@ class Gemm(GemmRelated): ...@@ -1026,7 +1026,7 @@ class Gemm(GemmRelated):
dims[0] = PyArray_DIMS(%(_z)s)[0]; dims[0] = PyArray_DIMS(%(_z)s)[0];
dims[1] = PyArray_DIMS(%(_z)s)[1]; dims[1] = PyArray_DIMS(%(_z)s)[1];
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, %(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims,
type_num_%(_z)s); PyArray_TYPE((PyArrayObject*) py_%(_z)s));
//fprintf(stderr, "Gemm Allocating %%i %%i\\n", dims[0], dims[1]); //fprintf(stderr, "Gemm Allocating %%i %%i\\n", dims[0], dims[1]);
if(!%(_zout)s) { if(!%(_zout)s) {
PyErr_SetString(PyExc_MemoryError, PyErr_SetString(PyExc_MemoryError,
...@@ -1627,7 +1627,7 @@ class Dot22(GemmRelated): ...@@ -1627,7 +1627,7 @@ class Dot22(GemmRelated):
dims[0] = PyArray_DIMS(%(_x)s)[0]; dims[0] = PyArray_DIMS(%(_x)s)[0];
dims[1] = PyArray_DIMS(%(_y)s)[1]; dims[1] = PyArray_DIMS(%(_y)s)[1];
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, %(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims,
type_num_%(_x)s); PyArray_TYPE((PyArrayObject*) py_%(_x)s));
//fprintf(stderr, "Dot Allocating %%i %%i\\n", dims[0], dims[1]); //fprintf(stderr, "Dot Allocating %%i %%i\\n", dims[0], dims[1]);
if(!%(_zout)s) { if(!%(_zout)s) {
PyErr_SetString(PyExc_MemoryError, PyErr_SetString(PyExc_MemoryError,
......
...@@ -353,7 +353,7 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -353,7 +353,7 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
{ {
if (%(zz)s) Py_XDECREF(%(zz)s); if (%(zz)s) Py_XDECREF(%(zz)s);
%(zz)s = (PyArrayObject*)PyArray_SimpleNew(1, %(zz)s = (PyArrayObject*)PyArray_SimpleNew(1,
PyArray_DIMS(%(aa)s), type_num_%(aa)s); PyArray_DIMS(%(aa)s), PyArray_TYPE((PyArrayObject*) py_%(aa)s));
if(!%(zz)s) { if(!%(zz)s) {
PyErr_SetString(PyExc_MemoryError, PyErr_SetString(PyExc_MemoryError,
"failed to alloc gemv output"); "failed to alloc gemv output");
......
...@@ -121,7 +121,9 @@ def make_alloc(loop_orders, dtype, sub, fortran='0'): ...@@ -121,7 +121,9 @@ def make_alloc(loop_orders, dtype, sub, fortran='0'):
created, otherwise it will be c order. created, otherwise it will be c order.
""" """
type = dtype.upper()
if type.startswith('THEANO_COMPLEX'):
type = type.replace('THEANO_COMPLEX', 'NPY_COMPLEX')
nd = len(loop_orders[0]) nd = len(loop_orders[0])
init_dims = "" init_dims = ""
# For each dimension, the tensors are either all broadcasted, in # For each dimension, the tensors are either all broadcasted, in
...@@ -142,7 +144,6 @@ def make_alloc(loop_orders, dtype, sub, fortran='0'): ...@@ -142,7 +144,6 @@ def make_alloc(loop_orders, dtype, sub, fortran='0'):
# way that its contiguous dimensions match one of the input's # way that its contiguous dimensions match one of the input's
# contiguous dimensions, or the dimension with the smallest # contiguous dimensions, or the dimension with the smallest
# stride. Right now, it is allocated to be C_CONTIGUOUS. # stride. Right now, it is allocated to be C_CONTIGUOUS.
return """ return """
{ {
npy_intp dims[%(nd)s]; npy_intp dims[%(nd)s];
...@@ -150,7 +151,7 @@ def make_alloc(loop_orders, dtype, sub, fortran='0'): ...@@ -150,7 +151,7 @@ def make_alloc(loop_orders, dtype, sub, fortran='0'):
%(init_dims)s %(init_dims)s
if (!%(olv)s) { if (!%(olv)s) {
%(olv)s = (PyArrayObject*)PyArray_EMPTY(%(nd)s, dims, %(olv)s = (PyArrayObject*)PyArray_EMPTY(%(nd)s, dims,
type_num_%(olv)s, %(type)s,
%(fortran)s); %(fortran)s);
} }
else { else {
...@@ -162,7 +163,7 @@ def make_alloc(loop_orders, dtype, sub, fortran='0'): ...@@ -162,7 +163,7 @@ def make_alloc(loop_orders, dtype, sub, fortran='0'):
// If we can't resize the ndarray we have we can allocate a new one. // If we can't resize the ndarray we have we can allocate a new one.
PyErr_Clear(); PyErr_Clear();
Py_XDECREF(%(olv)s); Py_XDECREF(%(olv)s);
%(olv)s = (PyArrayObject*)PyArray_EMPTY(%(nd)s, dims, type_num_%(olv)s, 0); %(olv)s = (PyArrayObject*)PyArray_EMPTY(%(nd)s, dims, %(type)s, 0);
} }
} }
if (!%(olv)s) { if (!%(olv)s) {
......
...@@ -68,13 +68,13 @@ class CumsumOp(theano.Op): ...@@ -68,13 +68,13 @@ class CumsumOp(theano.Op):
if(!(%(z)s && PyArray_DIMS(%(z)s)[0] == shape[0])) if(!(%(z)s && PyArray_DIMS(%(z)s)[0] == shape[0]))
{ {
Py_XDECREF(%(z)s); Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) PyArray_SimpleNew(1, shape, type_num_%(x)s); %(z)s = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE((PyArrayObject*) py_%(x)s));
} }
if (!%(z)s) if (!%(z)s)
%(fail)s; %(fail)s;
{ {
PyArray_CumSum(%(x)s, NPY_MAXDIMS, type_num_%(x)s, %(z)s); PyArray_CumSum(%(x)s, NPY_MAXDIMS, PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s);
Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s. Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s.
} }
""" % locals() """ % locals()
...@@ -83,13 +83,13 @@ class CumsumOp(theano.Op): ...@@ -83,13 +83,13 @@ class CumsumOp(theano.Op):
if(!(%(z)s && PyArray_CompareLists(PyArray_DIMS(%(z)s), PyArray_DIMS(%(x)s), PyArray_NDIM(%(x)s)) )) if(!(%(z)s && PyArray_CompareLists(PyArray_DIMS(%(z)s), PyArray_DIMS(%(x)s), PyArray_NDIM(%(x)s)) ))
{ {
Py_XDECREF(%(z)s); Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM(%(x)s), PyArray_DIMS(%(x)s), type_num_%(x)s); %(z)s = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM(%(x)s), PyArray_DIMS(%(x)s), PyArray_TYPE((PyArrayObject*) py_%(x)s));
} }
if (!%(z)s) if (!%(z)s)
%(fail)s; %(fail)s;
{ {
PyArray_CumSum(%(x)s, %(axis)s, type_num_%(x)s, %(z)s); PyArray_CumSum(%(x)s, %(axis)s, PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s);
Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s. Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s.
} }
""" % locals() """ % locals()
...@@ -177,13 +177,13 @@ class CumprodOp(theano.Op): ...@@ -177,13 +177,13 @@ class CumprodOp(theano.Op):
if(!(%(z)s && PyArray_DIMS(%(z)s)[0] == shape[0])) if(!(%(z)s && PyArray_DIMS(%(z)s)[0] == shape[0]))
{ {
Py_XDECREF(%(z)s); Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) PyArray_SimpleNew(1, shape, type_num_%(x)s); %(z)s = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE((PyArrayObject*) py_%(x)s));
} }
if (!%(z)s) if (!%(z)s)
%(fail)s; %(fail)s;
{ {
PyArray_CumProd(%(x)s, NPY_MAXDIMS, type_num_%(x)s, %(z)s); PyArray_CumProd(%(x)s, NPY_MAXDIMS, PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s);
Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s. Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s.
} }
""" % locals() """ % locals()
...@@ -192,13 +192,13 @@ class CumprodOp(theano.Op): ...@@ -192,13 +192,13 @@ class CumprodOp(theano.Op):
if(!(%(z)s && PyArray_CompareLists(PyArray_DIMS(%(z)s), PyArray_DIMS(%(x)s), PyArray_NDIM(%(x)s)) )) if(!(%(z)s && PyArray_CompareLists(PyArray_DIMS(%(z)s), PyArray_DIMS(%(x)s), PyArray_NDIM(%(x)s)) ))
{ {
Py_XDECREF(%(z)s); Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM(%(x)s), PyArray_DIMS(%(x)s), type_num_%(x)s); %(z)s = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM(%(x)s), PyArray_DIMS(%(x)s), PyArray_TYPE((PyArrayObject*) py_%(x)s));
} }
if (!%(z)s) if (!%(z)s)
%(fail)s; %(fail)s;
{ {
PyArray_CumProd(%(x)s, %(axis)s, type_num_%(x)s, %(z)s); PyArray_CumProd(%(x)s, %(axis)s, PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s);
Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s. Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s.
} }
""" % locals() """ % locals()
......
...@@ -148,7 +148,7 @@ class SoftmaxWithBias(gof.Op): ...@@ -148,7 +148,7 @@ class SoftmaxWithBias(gof.Op):
{ {
if (NULL != %(sm)s) Py_XDECREF(%(sm)s); if (NULL != %(sm)s) Py_XDECREF(%(sm)s);
%(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s), %(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s),
type_num_%(x)s); PyArray_TYPE((PyArrayObject*) py_%(x)s));
if(!%(sm)s) { if(!%(sm)s) {
PyErr_SetString(PyExc_MemoryError, PyErr_SetString(PyExc_MemoryError,
"failed to alloc sm output"); "failed to alloc sm output");
...@@ -342,7 +342,7 @@ class SoftmaxGrad(gof.Op): ...@@ -342,7 +342,7 @@ class SoftmaxGrad(gof.Op):
Py_XDECREF(%(dx)s); Py_XDECREF(%(dx)s);
%(dx)s = (PyArrayObject*) PyArray_SimpleNew(2, %(dx)s = (PyArrayObject*) PyArray_SimpleNew(2,
PyArray_DIMS(%(sm)s), PyArray_DIMS(%(sm)s),
type_num_%(sm)s); PyArray_TYPE((PyArrayObject*) py_%(sm)s));
if (!%(dx)s) if (!%(dx)s)
{ {
PyErr_SetString(PyExc_MemoryError, PyErr_SetString(PyExc_MemoryError,
...@@ -463,7 +463,7 @@ class Softmax(gof.Op): ...@@ -463,7 +463,7 @@ class Softmax(gof.Op):
{ {
Py_XDECREF(%(sm)s); Py_XDECREF(%(sm)s);
%(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s), %(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s),
type_num_%(x)s); PyArray_TYPE((PyArrayObject*) py_%(x)s));
if(!%(sm)s) { if(!%(sm)s) {
PyErr_SetString(PyExc_MemoryError, PyErr_SetString(PyExc_MemoryError,
"failed to alloc sm output"); "failed to alloc sm output");
...@@ -977,7 +977,7 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op): ...@@ -977,7 +977,7 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
{ {
if (NULL != %(nll)s) Py_XDECREF(%(nll)s); if (NULL != %(nll)s) Py_XDECREF(%(nll)s);
%(nll)s = (PyArrayObject*)PyArray_SimpleNew(1, %(nll)s = (PyArrayObject*)PyArray_SimpleNew(1,
PyArray_DIMS(%(y_idx)s), type_num_%(x)s); PyArray_DIMS(%(y_idx)s), PyArray_TYPE((PyArrayObject*) py_%(x)s));
if(!%(nll)s) if(!%(nll)s)
{ {
PyErr_SetString(PyExc_MemoryError, PyErr_SetString(PyExc_MemoryError,
...@@ -990,7 +990,7 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op): ...@@ -990,7 +990,7 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
{ {
Py_XDECREF(%(am)s); Py_XDECREF(%(am)s);
%(am)s = (PyArrayObject*) PyArray_SimpleNew(1, %(am)s = (PyArrayObject*) PyArray_SimpleNew(1,
PyArray_DIMS(%(y_idx)s), type_num_%(y_idx)s); PyArray_DIMS(%(y_idx)s), PyArray_TYPE((PyArrayObject*) py_%(y_idx)s));
if(!%(am)s) if(!%(am)s)
{ {
PyErr_SetString(PyExc_MemoryError, PyErr_SetString(PyExc_MemoryError,
...@@ -1144,7 +1144,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op): ...@@ -1144,7 +1144,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
if (NULL != %(dx)s) Py_XDECREF(%(dx)s); if (NULL != %(dx)s) Py_XDECREF(%(dx)s);
%(dx)s = (PyArrayObject*) PyArray_SimpleNew(2, %(dx)s = (PyArrayObject*) PyArray_SimpleNew(2,
PyArray_DIMS(%(sm)s), PyArray_DIMS(%(sm)s),
type_num_%(sm)s); PyArray_TYPE((PyArrayObject*) py_%(sm)s));
if(!%(dx)s) { if(!%(dx)s) {
PyErr_SetString(PyExc_MemoryError, PyErr_SetString(PyExc_MemoryError,
"failed to alloc dx output"); "failed to alloc dx output");
......
...@@ -426,7 +426,6 @@ class TensorType(Type): ...@@ -426,7 +426,6 @@ class TensorType(Type):
check = "" check = ""
declaration = """ declaration = """
PyArrayObject* %(name)s; PyArrayObject* %(name)s;
int type_num_%(name)s;
""" % dict(sub, name=name, dtype=self.dtype_specs()[1]) """ % dict(sub, name=name, dtype=self.dtype_specs()[1])
return declaration + check return declaration + check
...@@ -435,7 +434,6 @@ class TensorType(Type): ...@@ -435,7 +434,6 @@ class TensorType(Type):
"""Override `CLinkerType.c_init` """ """Override `CLinkerType.c_init` """
return """ return """
%(name)s = NULL; %(name)s = NULL;
type_num_%(name)s = %(type_num)s;
""" % dict(sub, name=name, type_num=self.dtype_specs()[2]) """ % dict(sub, name=name, type_num=self.dtype_specs()[2])
def c_extract(self, name, sub, check_input=True): def c_extract(self, name, sub, check_input=True):
...@@ -455,7 +453,6 @@ class TensorType(Type): ...@@ -455,7 +453,6 @@ class TensorType(Type):
%(fail)s %(fail)s
} }
// We expect %(type_num)s // We expect %(type_num)s
type_num_%(name)s = PyArray_TYPE((PyArrayObject*) py_%(name)s);
if (!PyArray_ISALIGNED((PyArrayObject*) py_%(name)s)) { if (!PyArray_ISALIGNED((PyArrayObject*) py_%(name)s)) {
PyArrayObject * tmp = (PyArrayObject*) py_%(name)s; PyArrayObject * tmp = (PyArrayObject*) py_%(name)s;
PyErr_Format(PyExc_NotImplementedError, PyErr_Format(PyExc_NotImplementedError,
...@@ -465,7 +462,7 @@ class TensorType(Type): ...@@ -465,7 +462,7 @@ class TensorType(Type):
"%%ld, %%ld, %%ld" "%%ld, %%ld, %%ld"
" and 3 last strides %%ld %%ld, %%ld.", " and 3 last strides %%ld %%ld, %%ld.",
(long int) %(type_num)s, (long int) %(type_num)s,
(long int) type_num_%(name)s, (long int) PyArray_TYPE((PyArrayObject*) py_%(name)s),
(long int) PyArray_NDIM(tmp), (long int) PyArray_NDIM(tmp),
(long int) PyArray_NDIM(tmp) >= 3 ? (long int) PyArray_NDIM(tmp) >= 3 ?
PyArray_DIMS(tmp)[PyArray_NDIM(tmp)-3] : -1, PyArray_DIMS(tmp)[PyArray_NDIM(tmp)-3] : -1,
...@@ -484,17 +481,15 @@ class TensorType(Type): ...@@ -484,17 +481,15 @@ class TensorType(Type):
} }
// This is a TypeError to be consistent with DEBUG_MODE // This is a TypeError to be consistent with DEBUG_MODE
// Note: DEBUG_MODE also tells the name of the container // Note: DEBUG_MODE also tells the name of the container
if (type_num_%(name)s != %(type_num)s) { if (PyArray_TYPE((PyArrayObject*) py_%(name)s) != %(type_num)s) {
PyErr_Format(PyExc_TypeError, PyErr_Format(PyExc_TypeError,
"expected type_num %%d (%(type_num)s) got %%d", "expected type_num %%d (%(type_num)s) got %%d",
%(type_num)s, type_num_%(name)s); %(type_num)s, PyArray_TYPE((PyArrayObject*) py_%(name)s));
%(fail)s %(fail)s
} }
""" % dict(sub, name=name, type_num=self.dtype_specs()[2]) """ % dict(sub, name=name, type_num=self.dtype_specs()[2])
else: else:
check = """ check = ""
type_num_%(name)s = PyArray_TYPE((PyArrayObject*) py_%(name)s);
""" % dict(sub, name=name, type_num=self.dtype_specs()[2])
return check + """ return check + """
%(name)s = (PyArrayObject*)(py_%(name)s); %(name)s = (PyArrayObject*)(py_%(name)s);
Py_XINCREF(%(name)s); Py_XINCREF(%(name)s);
...@@ -526,13 +521,11 @@ class TensorType(Type): ...@@ -526,13 +521,11 @@ class TensorType(Type):
if (%(name)s && !PyArray_ISALIGNED((PyArrayObject*) py_%(name)s)) { if (%(name)s && !PyArray_ISALIGNED((PyArrayObject*) py_%(name)s)) {
PyErr_Format(PyExc_NotImplementedError, PyErr_Format(PyExc_NotImplementedError,
"c_sync: expected an aligned array of type %%ld " "c_sync: expected an aligned array, got non-aligned array of type %%ld"
"(%(type_num)s), got non-aligned array of type %%ld"
" with %%ld dimensions, with 3 last dims " " with %%ld dimensions, with 3 last dims "
"%%ld, %%ld, %%ld" "%%ld, %%ld, %%ld"
" and 3 last strides %%ld %%ld, %%ld.", " and 3 last strides %%ld %%ld, %%ld.",
(long int) %(type_num)s, (long int) PyArray_TYPE((PyArrayObject*) py_%(name)s),
(long int) type_num_%(name)s,
(long int) PyArray_NDIM(%(name)s), (long int) PyArray_NDIM(%(name)s),
(long int) PyArray_NDIM(%(name)s) >= 3 ? (long int) PyArray_NDIM(%(name)s) >= 3 ?
PyArray_DIMS(%(name)s)[PyArray_NDIM(%(name)s)-3] : -1, PyArray_DIMS(%(name)s)[PyArray_NDIM(%(name)s)-3] : -1,
......
...@@ -76,7 +76,7 @@ class TypedListType(gof.Type): ...@@ -76,7 +76,7 @@ class TypedListType(gof.Type):
return True return True
def c_declare(self, name, sub): def c_declare(self, name, sub, check_input=True):
return """ return """
PyListObject* %(name)s; PyListObject* %(name)s;
""" % dict(name=name) """ % dict(name=name)
...@@ -86,12 +86,16 @@ class TypedListType(gof.Type): ...@@ -86,12 +86,16 @@ class TypedListType(gof.Type):
%(name)s = NULL; %(name)s = NULL;
""" % dict(name=name) """ % dict(name=name)
def c_extract(self, name, sub): def c_extract(self, name, sub, check_input=True):
return """ if check_input:
pre = """
if (!PyList_Check(py_%(name)s)) { if (!PyList_Check(py_%(name)s)) {
PyErr_SetString(PyExc_TypeError, "expected a list"); PyErr_SetString(PyExc_TypeError, "expected a list");
%(fail)s %(fail)s
} }"""
else:
pre = ""
return pre + """
%(name)s = (PyListObject*) (py_%(name)s); %(name)s = (PyListObject*) (py_%(name)s);
""" % dict(name=name, fail=sub['fail']) """ % dict(name=name, fail=sub['fail'])
...@@ -107,4 +111,4 @@ class TypedListType(gof.Type): ...@@ -107,4 +111,4 @@ class TypedListType(gof.Type):
return "" return ""
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论