Commit 4b60641c, authored by Frédéric Bastien

Merge pull request #1884 from Hengjean/ShapeCheckInput

Shape check input
...@@ -141,8 +141,7 @@ The ``c_code`` method accepts variable names as arguments (``name``, ``inames``, ...@@ -141,8 +141,7 @@ The ``c_code`` method accepts variable names as arguments (``name``, ``inames``,
``onames``) and returns a C code fragment that computes the expression output. ``onames``) and returns a C code fragment that computes the expression output.
In case of error, the ``%(fail)s`` statement cleans up and returns properly. In case of error, the ``%(fail)s`` statement cleans up and returns properly.
The variables ``%(x)s`` and ``%(y)s`` are set up by the TensorType to be ``PyArrayObject`` pointers. The variables ``%(x)s`` and ``%(y)s`` are set up by the TensorType to be ``PyArrayObject`` pointers.
TensorType also sets up ``dtype_%(x)s`` to be a typedef to the C type for ``x``, TensorType also sets up ``dtype_%(x)s`` to be a typedef to the C type for ``x``.
``type_num_%(x)s`` is the corresponding NumPy type number.
In the first two lines of the C function, we make y point to a new array with In the first two lines of the C function, we make y point to a new array with
the correct size for the output. This is essentially simulating the line the correct size for the output. This is essentially simulating the line
......
...@@ -109,6 +109,8 @@ class OutputGuard(ViewOp): ...@@ -109,6 +109,8 @@ class OutputGuard(ViewOp):
""" """
destroy_map = {0: [0]} destroy_map = {0: [0]}
check_input = False
_output_guard = OutputGuard() _output_guard = OutputGuard()
...@@ -131,6 +133,8 @@ class DeepCopyOp(gof.Op): ...@@ -131,6 +133,8 @@ class DeepCopyOp(gof.Op):
# the output variable is %(oname)s. # the output variable is %(oname)s.
c_code_and_version = {} c_code_and_version = {}
check_input = False
def __init__(self): def __init__(self):
pass pass
...@@ -169,6 +173,8 @@ class DeepCopyOp(gof.Op): ...@@ -169,6 +173,8 @@ class DeepCopyOp(gof.Op):
return () return ()
version.append((str(t), v)) version.append((str(t), v))
if version:
version.append(1)
return tuple(version) return tuple(version)
def c_code(self, node, name, inames, onames, sub): def c_code(self, node, name, inames, onames, sub):
...@@ -213,6 +219,8 @@ class Shape(gof.Op): ...@@ -213,6 +219,8 @@ class Shape(gof.Op):
# the output variable is %(oname)s. # the output variable is %(oname)s.
c_code_and_version = {} c_code_and_version = {}
check_input = False
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
...@@ -282,6 +290,9 @@ class Shape(gof.Op): ...@@ -282,6 +290,9 @@ class Shape(gof.Op):
return () return ()
version.append((str(t), v)) version.append((str(t), v))
if version:
version.append(1)
return tuple(version) return tuple(version)
...@@ -289,6 +300,7 @@ shape = Shape() ...@@ -289,6 +300,7 @@ shape = Shape()
_shape = shape # was used in the past, now use shape directly. _shape = shape # was used in the past, now use shape directly.
#pprint.assign(_shape, printing.MemberPrinter('shape')) #pprint.assign(_shape, printing.MemberPrinter('shape'))
class Shape_i(gof.Op): class Shape_i(gof.Op):
""" """
L{Op} to return the shape of a matrix. L{Op} to return the shape of a matrix.
...@@ -300,6 +312,8 @@ class Shape_i(gof.Op): ...@@ -300,6 +312,8 @@ class Shape_i(gof.Op):
# the output variable is %(oname)s. # the output variable is %(oname)s.
c_code_and_version = {} c_code_and_version = {}
check_input = False
def __init__(self, i): def __init__(self, i):
self.i = i self.i = i
...@@ -345,6 +359,9 @@ class Shape_i(gof.Op): ...@@ -345,6 +359,9 @@ class Shape_i(gof.Op):
return () return ()
version.append((str(t), v)) version.append((str(t), v))
if version:
version.append(1)
return tuple(version) return tuple(version)
def c_code(self, node, name, inames, onames, sub): def c_code(self, node, name, inames, onames, sub):
...@@ -355,8 +372,14 @@ class Shape_i(gof.Op): ...@@ -355,8 +372,14 @@ class Shape_i(gof.Op):
itype = node.inputs[0].type.__class__ itype = node.inputs[0].type.__class__
if itype in self.c_code_and_version: if itype in self.c_code_and_version:
sc = """
if (%(i)s>=PyArray_NDIM(%(iname)s)){
PyErr_SetString(PyExc_TypeError, "Number of dimensions lower than expected");
%(fail)s
}
""" % locals()
code, version = self.c_code_and_version[itype] code, version = self.c_code_and_version[itype]
return code % locals() return sc + code % locals()
# Else, no C code # Else, no C code
return super(Shape_i, self).c_code(node, name, inames, onames, sub) return super(Shape_i, self).c_code(node, name, inames, onames, sub)
...@@ -517,6 +540,8 @@ class Rebroadcast(gof.Op): ...@@ -517,6 +540,8 @@ class Rebroadcast(gof.Op):
# the output variable is %(oname)s. # the output variable is %(oname)s.
c_code_and_version = {} c_code_and_version = {}
check_input = False
def __init__(self, *axis): def __init__(self, *axis):
self.axis = dict(axis) self.axis = dict(axis)
for axis, broad in self.axis.iteritems(): for axis, broad in self.axis.iteritems():
...@@ -618,6 +643,8 @@ class Rebroadcast(gof.Op): ...@@ -618,6 +643,8 @@ class Rebroadcast(gof.Op):
return () return ()
version.append((str(t), v)) version.append((str(t), v))
if version:
version.append(1)
return tuple(version) return tuple(version)
......
...@@ -495,3 +495,10 @@ AddConfigVar('openmp_elemwise_minsize', ...@@ -495,3 +495,10 @@ AddConfigVar('openmp_elemwise_minsize',
IntParam(200000), IntParam(200000),
in_c_key=False, in_c_key=False,
) )
# Global default for the input validation that types emit in their C code.
# The linker reads it as the fallback when an op does not define its own
# ``check_input`` attribute.
# NOTE: the help text is built from adjacent string literals, so every
# fragment except the last must end with a space to keep the words apart.
AddConfigVar('check_input',
             "Specify if types should check their input in their C code. "
             "It can be used to speed up compilation, reduce overhead "
             "(particularly for scalars) and reduce the number of generated C "
             "files.",
             BoolParam(True))
...@@ -308,10 +308,18 @@ def get_nothing(r, name, sub): ...@@ -308,10 +308,18 @@ def get_nothing(r, name, sub):
def get_c_declare(r, name, sub): def get_c_declare(r, name, sub):
"""Wrapper around c_declare that declares py_name""" """Wrapper around c_declare that declares py_name"""
if any([c != "output" and getattr(c.op, 'check_input',
config.check_input) for (c, _) in r.clients]) or (r.owner
and getattr(r.owner.op, 'check_input', True)):
c_declare = r.type.c_declare(name, sub, True)
else:
c_declare = r.type.c_declare(name, sub, False)
pre = """ pre = """
PyObject* py_%(name)s; PyObject* py_%(name)s;
""" % locals() """ % locals()
return pre + r.type.c_declare(name, sub) return pre + c_declare
def get_c_init(r, name, sub): def get_c_init(r, name, sub):
...@@ -325,20 +333,30 @@ def get_c_init(r, name, sub): ...@@ -325,20 +333,30 @@ def get_c_init(r, name, sub):
def get_c_extract(r, name, sub): def get_c_extract(r, name, sub):
"""Wrapper around c_extract that initializes py_name from storage.""" """Wrapper around c_extract that initializes py_name from storage."""
if any([getattr(c.op, 'check_input', config.check_input) for (c, _) in
r.clients]):
c_extract = r.type.c_extract(name, sub, True)
else:
c_extract = r.type.c_extract(name, sub, False)
pre = """ pre = """
py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0); py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0);
{Py_XINCREF(py_%(name)s);} {Py_XINCREF(py_%(name)s);}
""" % locals() """ % locals()
return pre + r.type.c_extract(name, sub) return pre + c_extract
def get_c_extract_out(r, name, sub): def get_c_extract_out(r, name, sub):
"""Wrapper around c_extract_out that initializes py_name from storage.""" """Wrapper around c_extract_out that initializes py_name from storage."""
c_extract = r.type.c_extract_out(name, sub,
getattr(r.owner.op, 'check_input', config.check_input))
pre = """ pre = """
py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0); py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0);
{Py_XINCREF(py_%(name)s);} {Py_XINCREF(py_%(name)s);}
""" % locals() """ % locals()
return pre + r.type.c_extract_out(name, sub) return pre + c_extract
def get_c_cleanup(r, name, sub): def get_c_cleanup(r, name, sub):
......
...@@ -22,7 +22,7 @@ class TDouble(Type): ...@@ -22,7 +22,7 @@ class TDouble(Type):
def filter(self, data): def filter(self, data):
return float(data) return float(data)
def c_declare(self, name, sub): def c_declare(self, name, sub, check_input=True):
return "double %(name)s; void* %(name)s_bad_thing;" % locals() return "double %(name)s; void* %(name)s_bad_thing;" % locals()
def c_init(self, name, sub): def c_init(self, name, sub):
...@@ -35,7 +35,7 @@ class TDouble(Type): ...@@ -35,7 +35,7 @@ class TDouble(Type):
def c_literal(self, data): def c_literal(self, data):
return str(data) return str(data)
def c_extract(self, name, sub): def c_extract(self, name, sub, check_input=True):
return """ return """
if (!PyFloat_Check(py_%(name)s)) { if (!PyFloat_Check(py_%(name)s)) {
PyErr_SetString(PyExc_TypeError, "not a double!"); PyErr_SetString(PyExc_TypeError, "not a double!");
......
...@@ -44,7 +44,7 @@ class CLinkerType(CLinkerObject): ...@@ -44,7 +44,7 @@ class CLinkerType(CLinkerObject):
""" """
raise MethodNotDefined("c_literal", type(self), self.__class__.__name__) raise MethodNotDefined("c_literal", type(self), self.__class__.__name__)
def c_declare(self, name, sub): def c_declare(self, name, sub, check_input=True):
"""Required: Return c code to declare variables that will be """Required: Return c code to declare variables that will be
instantiated by `c_extract`. instantiated by `c_extract`.
...@@ -96,7 +96,7 @@ class CLinkerType(CLinkerObject): ...@@ -96,7 +96,7 @@ class CLinkerType(CLinkerObject):
""" """
raise MethodNotDefined("c_init", type(self), self.__class__.__name__) raise MethodNotDefined("c_init", type(self), self.__class__.__name__)
def c_extract(self, name, sub): def c_extract(self, name, sub, check_input=True):
"""Required: Return c code to extract a PyObject * instance. """Required: Return c code to extract a PyObject * instance.
The code returned from this function must be templated using The code returned from this function must be templated using
...@@ -137,7 +137,7 @@ class CLinkerType(CLinkerObject): ...@@ -137,7 +137,7 @@ class CLinkerType(CLinkerObject):
""" """
raise MethodNotDefined("c_extract", type(self), self.__class__.__name__) raise MethodNotDefined("c_extract", type(self), self.__class__.__name__)
def c_extract_out(self, name, sub): def c_extract_out(self, name, sub, check_input=True):
"""Optional: C code to extract a PyObject * instance. """Optional: C code to extract a PyObject * instance.
Unlike c_extract, c_extract_out has to accept Py_None, Unlike c_extract, c_extract_out has to accept Py_None,
...@@ -155,7 +155,7 @@ class CLinkerType(CLinkerObject): ...@@ -155,7 +155,7 @@ class CLinkerType(CLinkerObject):
""" % dict( """ % dict(
name=name, name=name,
c_init_code=self.c_init(name, sub), c_init_code=self.c_init(name, sub),
c_extract_code=self.c_extract(name, sub)) c_extract_code=self.c_extract(name, sub, check_input))
def c_cleanup(self, name, sub): def c_cleanup(self, name, sub):
"""Optional: Return c code to clean up after `c_extract`. """Optional: Return c code to clean up after `c_extract`.
...@@ -434,7 +434,7 @@ class Generic(SingletonType): ...@@ -434,7 +434,7 @@ class Generic(SingletonType):
def is_valid_value(self, a): def is_valid_value(self, a):
return True return True
def c_declare(self, name, sub): def c_declare(self, name, sub, check_input=True):
return """ return """
PyObject* %(name)s; PyObject* %(name)s;
""" % locals() """ % locals()
...@@ -444,7 +444,7 @@ class Generic(SingletonType): ...@@ -444,7 +444,7 @@ class Generic(SingletonType):
%(name)s = NULL; %(name)s = NULL;
""" % locals() """ % locals()
def c_extract(self, name, sub): def c_extract(self, name, sub, check_input=True):
return """ return """
Py_INCREF(py_%(name)s); Py_INCREF(py_%(name)s);
%(name)s = py_%(name)s; %(name)s = py_%(name)s;
......
...@@ -55,6 +55,8 @@ class HostFromGpu(GpuOp): ...@@ -55,6 +55,8 @@ class HostFromGpu(GpuOp):
""" """
Implement the transfer from gpu to the cpu. Implement the transfer from gpu to the cpu.
""" """
check_input = False
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
...@@ -104,7 +106,7 @@ class HostFromGpu(GpuOp): ...@@ -104,7 +106,7 @@ class HostFromGpu(GpuOp):
""" % locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (3,)
host_from_gpu = HostFromGpu() host_from_gpu = HostFromGpu()
...@@ -112,6 +114,8 @@ class GpuFromHost(GpuOp): ...@@ -112,6 +114,8 @@ class GpuFromHost(GpuOp):
""" """
Implement the transfer from cpu to the gpu. Implement the transfer from cpu to the gpu.
""" """
check_input = False
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
...@@ -168,7 +172,7 @@ class GpuFromHost(GpuOp): ...@@ -168,7 +172,7 @@ class GpuFromHost(GpuOp):
""" % locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
gpu_from_host = GpuFromHost() gpu_from_host = GpuFromHost()
......
...@@ -274,13 +274,13 @@ class CudaNdarrayType(Type): ...@@ -274,13 +274,13 @@ class CudaNdarrayType(Type):
return str(self) return str(self)
#"CudaNdarrayType{%s, %s}" % (str(self.dtype), str(self.broadcastable)) #"CudaNdarrayType{%s, %s}" % (str(self.dtype), str(self.broadcastable))
def c_declare(self, name, sub): def c_declare(self, name, sub, check_input=True):
return """ CudaNdarray * %(name)s;""" % locals() return """ CudaNdarray * %(name)s;""" % locals()
def c_init(self, name, sub): def c_init(self, name, sub):
return "%(name)s = NULL;" % locals() return "%(name)s = NULL;" % locals()
def c_extract(self, name, sub): def c_extract(self, name, sub, check_input=True):
sio = StringIO() sio = StringIO()
fail = sub['fail'] fail = sub['fail']
nd = self.ndim nd = self.ndim
...@@ -293,6 +293,9 @@ class CudaNdarrayType(Type): ...@@ -293,6 +293,9 @@ class CudaNdarrayType(Type):
//fprintf(stderr, "c_extract CNDA object w refcnt %%p %%i\\n", py_%(name)s, (py_%(name)s->ob_refcnt)); //fprintf(stderr, "c_extract CNDA object w refcnt %%p %%i\\n", py_%(name)s, (py_%(name)s->ob_refcnt));
%(name)s = (CudaNdarray*)py_%(name)s; %(name)s = (CudaNdarray*)py_%(name)s;
//std::cerr << "c_extract " << %(name)s << '\\n'; //std::cerr << "c_extract " << %(name)s << '\\n';
""" % locals()
if(check_input):
print >> sio, """
if (%(name)s->nd != %(nd)s) if (%(name)s->nd != %(nd)s)
{ {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
...@@ -348,6 +351,12 @@ class CudaNdarrayType(Type): ...@@ -348,6 +351,12 @@ class CudaNdarrayType(Type):
} }
//std::cerr << "c_extract done " << %(name)s << '\\n'; //std::cerr << "c_extract done " << %(name)s << '\\n';
""" % locals() """ % locals()
else:
print >> sio, """
assert(%(name)s);
Py_INCREF(py_%(name)s);
}
""" % locals()
#print sio.getvalue() #print sio.getvalue()
return sio.getvalue() return sio.getvalue()
......
...@@ -155,7 +155,7 @@ class GpuArrayType(Type): ...@@ -155,7 +155,7 @@ class GpuArrayType(Type):
else: else:
return numpy.dtype(self.dtype).itemsize return numpy.dtype(self.dtype).itemsize
def c_declare(self, name, sub): def c_declare(self, name, sub, check_input=True):
return """ return """
PyGpuArrayObject *%(name)s; PyGpuArrayObject *%(name)s;
""" % locals() """ % locals()
...@@ -163,7 +163,7 @@ class GpuArrayType(Type): ...@@ -163,7 +163,7 @@ class GpuArrayType(Type):
def c_init(self, name, sub): def c_init(self, name, sub):
return "%s = NULL;" % (name,) return "%s = NULL;" % (name,)
def c_extract(self, name, sub): def c_extract(self, name, sub, check_input=True):
# TODO I don't check broadcast stuff for now. # TODO I don't check broadcast stuff for now.
return """ return """
%(name)s = NULL; %(name)s = NULL;
......
...@@ -59,7 +59,14 @@ class MultinomialFromUniform(Op): ...@@ -59,7 +59,14 @@ class MultinomialFromUniform(Op):
def c_code(self, node, name, ins, outs, sub): def c_code(self, node, name, ins, outs, sub):
(pvals, unis) = ins (pvals, unis) = ins
(z,) = outs (z,) = outs
if self.odtype == 'auto':
t = "PyArray_TYPE((PyArrayObject*) py_%(pvals)s)" % locals()
else:
t = theano.scalar.Scalar(self.odtype).dtype_specs()[1]
if t.startswith('theano_complex'):
t = t.replace('theano_complex', 'NPY_COMPLEX')
else:
t = t.upper()
fail = sub['fail'] fail = sub['fail']
return """ return """
if (PyArray_NDIM(%(pvals)s) != 2) if (PyArray_NDIM(%(pvals)s) != 2)
...@@ -87,7 +94,7 @@ class MultinomialFromUniform(Op): ...@@ -87,7 +94,7 @@ class MultinomialFromUniform(Op):
Py_XDECREF(%(z)s); Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) PyArray_ZEROS(2, %(z)s = (PyArrayObject*) PyArray_ZEROS(2,
PyArray_DIMS(%(pvals)s), PyArray_DIMS(%(pvals)s),
type_num_%(z)s, %(t)s,
0); 0);
if (!%(z)s) if (!%(z)s)
{ {
......
...@@ -335,7 +335,7 @@ class Images2Neibs(Op): ...@@ -335,7 +335,7 @@ class Images2Neibs(Op):
%(z)s = (PyArrayObject*) PyArray_EMPTY(2, %(z)s = (PyArrayObject*) PyArray_EMPTY(2,
dims, dims,
type_num_%(ten4)s, PyArray_TYPE((PyArrayObject*) py_%(ten4)s),
0); 0);
if (!%(z)s) if (!%(z)s)
......
...@@ -254,32 +254,42 @@ class Scalar(Type): ...@@ -254,32 +254,42 @@ class Scalar(Type):
raise NotImplementedError("No literal for complex values.") raise NotImplementedError("No literal for complex values.")
return str(data) return str(data)
def c_declare(self, name, sub): def c_declare(self, name, sub, check_input=True):
return """ if(check_input):
%(dtype)s %(name)s; pre = """
typedef %(dtype)s %(name)s_dtype; // Deprecated use dtype_%(name)s instead. typedef %(dtype)s %(name)s_dtype; // Deprecated use dtype_%(name)s instead.
typedef %(dtype)s dtype_%(name)s; typedef %(dtype)s dtype_%(name)s;
""" % dict(name=name, dtype=self.dtype_specs()[1]) """ % dict(name=name, dtype=self.dtype_specs()[1])
else:
pre = ""
return pre + """
%(dtype)s %(name)s;
""" % dict(name=name, dtype=self.dtype_specs()[1])
def c_init(self, name, sub): def c_init(self, name, sub):
return """ return """
%(name)s = 0; %(name)s = 0;
""" % locals() """ % locals()
def c_extract(self, name, sub): def c_extract(self, name, sub, check_input=True):
specs = self.dtype_specs() specs = self.dtype_specs()
return """ if(check_input):
pre = """
if (!PyObject_TypeCheck(py_%(name)s, &%(pyarr_type)s)) if (!PyObject_TypeCheck(py_%(name)s, &%(pyarr_type)s))
{ {
PyErr_Format(PyExc_ValueError, PyErr_Format(PyExc_ValueError,
"Scalar check failed (%(dtype)s)"); "Scalar check failed (%(dtype)s)");
%(fail)s %(fail)s
} }
PyArray_ScalarAsCtype(py_%(name)s, &%(name)s);
""" % dict(sub, """ % dict(sub,
name=name, name=name,
dtype=specs[1], dtype=specs[1],
pyarr_type='Py%sArrType_Type' % specs[2]) pyarr_type='Py%sArrType_Type' % specs[2])
else:
pre = ""
return pre + """
PyArray_ScalarAsCtype(py_%(name)s, &%(name)s);
""" % dict(sub, name=name)
def c_sync(self, name, sub): def c_sync(self, name, sub):
specs = self.dtype_specs() specs = self.dtype_specs()
...@@ -452,7 +462,7 @@ class Scalar(Type): ...@@ -452,7 +462,7 @@ class Scalar(Type):
return ["import_array();"] return ["import_array();"]
def c_code_cache_version(self): def c_code_cache_version(self):
return (12, numpy.__version__) return (13, numpy.__version__)
def get_shape_info(self, obj): def get_shape_info(self, obj):
return obj.itemsize return obj.itemsize
......
...@@ -2425,7 +2425,7 @@ class Alloc(gof.Op): ...@@ -2425,7 +2425,7 @@ class Alloc(gof.Op):
{ {
Py_XDECREF(%(zz)s); Py_XDECREF(%(zz)s);
%(zz)s = (PyArrayObject*) PyArray_SimpleNew(%(ndim)s, %(zz)s = (PyArrayObject*) PyArray_SimpleNew(%(ndim)s,
shape, type_num_%(vv)s); shape, PyArray_TYPE((PyArrayObject*) py_%(vv)s));
if (!%(zz)s) if (!%(zz)s)
{ {
PyErr_SetString(PyExc_MemoryError, "alloc failed"); PyErr_SetString(PyExc_MemoryError, "alloc failed");
...@@ -3261,6 +3261,8 @@ class Join(Op): ...@@ -3261,6 +3261,8 @@ class Join(Op):
join(2, x, y, z) # WRONG: the axis has to be an index into the shape join(2, x, y, z) # WRONG: the axis has to be an index into the shape
join(0, x, u) # WRONG: joined tensors must have the same rank join(0, x, u) # WRONG: joined tensors must have the same rank
""" """
check_input = False
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
...@@ -3372,14 +3374,14 @@ class Join(Op): ...@@ -3372,14 +3374,14 @@ class Join(Op):
dtype=node.outputs[0].type.dtype) dtype=node.outputs[0].type.dtype)
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
axis, tensors = inputs[0], inputs[1:] axis, tensors = inputs[0], inputs[1:]
l = len(tensors) l = len(tensors)
out, = outputs out, = outputs
fail = sub['fail'] fail = sub['fail']
adtype = node.inputs[0].type.dtype_specs()[1]
code = """ code = """
PyObject* list = PyList_New(%(l)s); PyObject* list = PyList_New(%(l)s);
""" % locals() """ % locals()
...@@ -3392,7 +3394,7 @@ class Join(Op): ...@@ -3392,7 +3394,7 @@ class Join(Op):
//PyObject* PyArray_Concatenate(PyObject* obj, int axis) //PyObject* PyArray_Concatenate(PyObject* obj, int axis)
Py_XDECREF(%(out)s); Py_XDECREF(%(out)s);
%(out)s = (PyArrayObject *)PyArray_Concatenate(list, %(out)s = (PyArrayObject *)PyArray_Concatenate(list,
((dtype_%(axis)s *)PyArray_DATA(%(axis)s))[0]); ((%(adtype)s *)PyArray_DATA(%(axis)s))[0]);
Py_DECREF(list); Py_DECREF(list);
if(!%(out)s){ if(!%(out)s){
...@@ -3685,6 +3687,8 @@ class Reshape(Op): ...@@ -3685,6 +3687,8 @@ class Reshape(Op):
known at graph build time.""" known at graph build time."""
view_map = {0: [0]} # output 0 is potentially aliased to inputs [0] view_map = {0: [0]} # output 0 is potentially aliased to inputs [0]
check_input = False
def __init__(self, ndim, name=None): def __init__(self, ndim, name=None):
self.ndim = ndim self.ndim = ndim
self.name = name self.name = name
...@@ -3814,13 +3818,14 @@ class Reshape(Op): ...@@ -3814,13 +3818,14 @@ class Reshape(Op):
return [tuple(oshape)] return [tuple(oshape)]
def c_code_cache_version(self): def c_code_cache_version(self):
return (5,) return (6,)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
if isinstance(node.inputs[0], TensorVariable): if isinstance(node.inputs[0], TensorVariable):
x, shp = inputs x, shp = inputs
z, = outputs z, = outputs
new_ndim = self.ndim new_ndim = self.ndim
sdtype = node.inputs[1].type.dtype_specs()[1]
fail = sub['fail'] fail = sub['fail']
return """ return """
assert (PyArray_NDIM(%(shp)s) == 1); assert (PyArray_NDIM(%(shp)s) == 1);
...@@ -3834,7 +3839,7 @@ class Reshape(Op): ...@@ -3834,7 +3839,7 @@ class Reshape(Op):
// -- int* dtype. The compiler will explicitly upcast it, but // -- int* dtype. The compiler will explicitly upcast it, but
// -- will err if this will downcast. This could happen if the // -- will err if this will downcast. This could happen if the
// -- user pass an int64 dtype, but npy_intp endup being int32. // -- user pass an int64 dtype, but npy_intp endup being int32.
new_dims[ii] = ((dtype_%(shp)s*)( new_dims[ii] = ((%(sdtype)s*)(
PyArray_BYTES(%(shp)s) + PyArray_BYTES(%(shp)s) +
ii * PyArray_STRIDES(%(shp)s)[0]))[0]; ii * PyArray_STRIDES(%(shp)s)[0]))[0];
} }
......
...@@ -1026,7 +1026,7 @@ class Gemm(GemmRelated): ...@@ -1026,7 +1026,7 @@ class Gemm(GemmRelated):
dims[0] = PyArray_DIMS(%(_z)s)[0]; dims[0] = PyArray_DIMS(%(_z)s)[0];
dims[1] = PyArray_DIMS(%(_z)s)[1]; dims[1] = PyArray_DIMS(%(_z)s)[1];
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, %(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims,
type_num_%(_z)s); PyArray_TYPE((PyArrayObject*) py_%(_z)s));
//fprintf(stderr, "Gemm Allocating %%i %%i\\n", dims[0], dims[1]); //fprintf(stderr, "Gemm Allocating %%i %%i\\n", dims[0], dims[1]);
if(!%(_zout)s) { if(!%(_zout)s) {
PyErr_SetString(PyExc_MemoryError, PyErr_SetString(PyExc_MemoryError,
...@@ -1627,7 +1627,7 @@ class Dot22(GemmRelated): ...@@ -1627,7 +1627,7 @@ class Dot22(GemmRelated):
dims[0] = PyArray_DIMS(%(_x)s)[0]; dims[0] = PyArray_DIMS(%(_x)s)[0];
dims[1] = PyArray_DIMS(%(_y)s)[1]; dims[1] = PyArray_DIMS(%(_y)s)[1];
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, %(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims,
type_num_%(_x)s); PyArray_TYPE((PyArrayObject*) py_%(_x)s));
//fprintf(stderr, "Dot Allocating %%i %%i\\n", dims[0], dims[1]); //fprintf(stderr, "Dot Allocating %%i %%i\\n", dims[0], dims[1]);
if(!%(_zout)s) { if(!%(_zout)s) {
PyErr_SetString(PyExc_MemoryError, PyErr_SetString(PyExc_MemoryError,
......
...@@ -353,7 +353,7 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -353,7 +353,7 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
{ {
if (%(zz)s) Py_XDECREF(%(zz)s); if (%(zz)s) Py_XDECREF(%(zz)s);
%(zz)s = (PyArrayObject*)PyArray_SimpleNew(1, %(zz)s = (PyArrayObject*)PyArray_SimpleNew(1,
PyArray_DIMS(%(aa)s), type_num_%(aa)s); PyArray_DIMS(%(aa)s), PyArray_TYPE((PyArrayObject*) py_%(aa)s));
if(!%(zz)s) { if(!%(zz)s) {
PyErr_SetString(PyExc_MemoryError, PyErr_SetString(PyExc_MemoryError,
"failed to alloc gemv output"); "failed to alloc gemv output");
......
...@@ -95,6 +95,8 @@ class DimShuffle(Op): ...@@ -95,6 +95,8 @@ class DimShuffle(Op):
Adding, subtracting dimensions can be done with reshape. Adding, subtracting dimensions can be done with reshape.
""" """
check_input = False
def __init__(self, input_broadcastable, new_order, inplace=False): def __init__(self, input_broadcastable, new_order, inplace=False):
""" """
Usage: DimShuffle(input_broadcastable, new_order, inplace = False) Usage: DimShuffle(input_broadcastable, new_order, inplace = False)
...@@ -369,7 +371,7 @@ PyArray_SetBaseObject(%(res)s, (PyObject*)%(basename)s); ...@@ -369,7 +371,7 @@ PyArray_SetBaseObject(%(res)s, (PyObject*)%(basename)s);
return full_code % dict(locals(), **sub) return full_code % dict(locals(), **sub)
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (3,)
def grad(self, inp, grads): def grad(self, inp, grads):
x, = inp x, = inp
......
...@@ -121,7 +121,9 @@ def make_alloc(loop_orders, dtype, sub, fortran='0'): ...@@ -121,7 +121,9 @@ def make_alloc(loop_orders, dtype, sub, fortran='0'):
created, otherwise it will be c order. created, otherwise it will be c order.
""" """
type = dtype.upper()
if type.startswith('THEANO_COMPLEX'):
type = type.replace('THEANO_COMPLEX', 'NPY_COMPLEX')
nd = len(loop_orders[0]) nd = len(loop_orders[0])
init_dims = "" init_dims = ""
# For each dimension, the tensors are either all broadcasted, in # For each dimension, the tensors are either all broadcasted, in
...@@ -142,7 +144,6 @@ def make_alloc(loop_orders, dtype, sub, fortran='0'): ...@@ -142,7 +144,6 @@ def make_alloc(loop_orders, dtype, sub, fortran='0'):
# way that its contiguous dimensions match one of the input's # way that its contiguous dimensions match one of the input's
# contiguous dimensions, or the dimension with the smallest # contiguous dimensions, or the dimension with the smallest
# stride. Right now, it is allocated to be C_CONTIGUOUS. # stride. Right now, it is allocated to be C_CONTIGUOUS.
return """ return """
{ {
npy_intp dims[%(nd)s]; npy_intp dims[%(nd)s];
...@@ -150,7 +151,7 @@ def make_alloc(loop_orders, dtype, sub, fortran='0'): ...@@ -150,7 +151,7 @@ def make_alloc(loop_orders, dtype, sub, fortran='0'):
%(init_dims)s %(init_dims)s
if (!%(olv)s) { if (!%(olv)s) {
%(olv)s = (PyArrayObject*)PyArray_EMPTY(%(nd)s, dims, %(olv)s = (PyArrayObject*)PyArray_EMPTY(%(nd)s, dims,
type_num_%(olv)s, %(type)s,
%(fortran)s); %(fortran)s);
} }
else { else {
...@@ -162,7 +163,7 @@ def make_alloc(loop_orders, dtype, sub, fortran='0'): ...@@ -162,7 +163,7 @@ def make_alloc(loop_orders, dtype, sub, fortran='0'):
// If we can't resize the ndarray we have we can allocate a new one. // If we can't resize the ndarray we have we can allocate a new one.
PyErr_Clear(); PyErr_Clear();
Py_XDECREF(%(olv)s); Py_XDECREF(%(olv)s);
%(olv)s = (PyArrayObject*)PyArray_EMPTY(%(nd)s, dims, type_num_%(olv)s, 0); %(olv)s = (PyArrayObject*)PyArray_EMPTY(%(nd)s, dims, %(type)s, 0);
} }
} }
if (!%(olv)s) { if (!%(olv)s) {
......
...@@ -68,13 +68,13 @@ class CumsumOp(theano.Op): ...@@ -68,13 +68,13 @@ class CumsumOp(theano.Op):
if(!(%(z)s && PyArray_DIMS(%(z)s)[0] == shape[0])) if(!(%(z)s && PyArray_DIMS(%(z)s)[0] == shape[0]))
{ {
Py_XDECREF(%(z)s); Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) PyArray_SimpleNew(1, shape, type_num_%(x)s); %(z)s = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE((PyArrayObject*) py_%(x)s));
} }
if (!%(z)s) if (!%(z)s)
%(fail)s; %(fail)s;
{ {
PyArray_CumSum(%(x)s, NPY_MAXDIMS, type_num_%(x)s, %(z)s); PyArray_CumSum(%(x)s, NPY_MAXDIMS, PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s);
Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s. Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s.
} }
""" % locals() """ % locals()
...@@ -83,13 +83,13 @@ class CumsumOp(theano.Op): ...@@ -83,13 +83,13 @@ class CumsumOp(theano.Op):
if(!(%(z)s && PyArray_CompareLists(PyArray_DIMS(%(z)s), PyArray_DIMS(%(x)s), PyArray_NDIM(%(x)s)) )) if(!(%(z)s && PyArray_CompareLists(PyArray_DIMS(%(z)s), PyArray_DIMS(%(x)s), PyArray_NDIM(%(x)s)) ))
{ {
Py_XDECREF(%(z)s); Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM(%(x)s), PyArray_DIMS(%(x)s), type_num_%(x)s); %(z)s = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM(%(x)s), PyArray_DIMS(%(x)s), PyArray_TYPE((PyArrayObject*) py_%(x)s));
} }
if (!%(z)s) if (!%(z)s)
%(fail)s; %(fail)s;
{ {
PyArray_CumSum(%(x)s, %(axis)s, type_num_%(x)s, %(z)s); PyArray_CumSum(%(x)s, %(axis)s, PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s);
Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s. Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s.
} }
""" % locals() """ % locals()
...@@ -177,13 +177,13 @@ class CumprodOp(theano.Op): ...@@ -177,13 +177,13 @@ class CumprodOp(theano.Op):
if(!(%(z)s && PyArray_DIMS(%(z)s)[0] == shape[0])) if(!(%(z)s && PyArray_DIMS(%(z)s)[0] == shape[0]))
{ {
Py_XDECREF(%(z)s); Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) PyArray_SimpleNew(1, shape, type_num_%(x)s); %(z)s = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE((PyArrayObject*) py_%(x)s));
} }
if (!%(z)s) if (!%(z)s)
%(fail)s; %(fail)s;
{ {
PyArray_CumProd(%(x)s, NPY_MAXDIMS, type_num_%(x)s, %(z)s); PyArray_CumProd(%(x)s, NPY_MAXDIMS, PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s);
Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s. Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s.
} }
""" % locals() """ % locals()
...@@ -192,13 +192,13 @@ class CumprodOp(theano.Op): ...@@ -192,13 +192,13 @@ class CumprodOp(theano.Op):
if(!(%(z)s && PyArray_CompareLists(PyArray_DIMS(%(z)s), PyArray_DIMS(%(x)s), PyArray_NDIM(%(x)s)) )) if(!(%(z)s && PyArray_CompareLists(PyArray_DIMS(%(z)s), PyArray_DIMS(%(x)s), PyArray_NDIM(%(x)s)) ))
{ {
Py_XDECREF(%(z)s); Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM(%(x)s), PyArray_DIMS(%(x)s), type_num_%(x)s); %(z)s = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM(%(x)s), PyArray_DIMS(%(x)s), PyArray_TYPE((PyArrayObject*) py_%(x)s));
} }
if (!%(z)s) if (!%(z)s)
%(fail)s; %(fail)s;
{ {
PyArray_CumProd(%(x)s, %(axis)s, type_num_%(x)s, %(z)s); PyArray_CumProd(%(x)s, %(axis)s, PyArray_TYPE((PyArrayObject*) py_%(x)s), %(z)s);
Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s. Py_XDECREF(%(z)s); // Because PyArray_CumSum returns a newly created reference on %(z)s.
} }
""" % locals() """ % locals()
......
...@@ -148,7 +148,7 @@ class SoftmaxWithBias(gof.Op): ...@@ -148,7 +148,7 @@ class SoftmaxWithBias(gof.Op):
{ {
if (NULL != %(sm)s) Py_XDECREF(%(sm)s); if (NULL != %(sm)s) Py_XDECREF(%(sm)s);
%(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s), %(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s),
type_num_%(x)s); PyArray_TYPE((PyArrayObject*) py_%(x)s));
if(!%(sm)s) { if(!%(sm)s) {
PyErr_SetString(PyExc_MemoryError, PyErr_SetString(PyExc_MemoryError,
"failed to alloc sm output"); "failed to alloc sm output");
...@@ -342,7 +342,7 @@ class SoftmaxGrad(gof.Op): ...@@ -342,7 +342,7 @@ class SoftmaxGrad(gof.Op):
Py_XDECREF(%(dx)s); Py_XDECREF(%(dx)s);
%(dx)s = (PyArrayObject*) PyArray_SimpleNew(2, %(dx)s = (PyArrayObject*) PyArray_SimpleNew(2,
PyArray_DIMS(%(sm)s), PyArray_DIMS(%(sm)s),
type_num_%(sm)s); PyArray_TYPE((PyArrayObject*) py_%(sm)s));
if (!%(dx)s) if (!%(dx)s)
{ {
PyErr_SetString(PyExc_MemoryError, PyErr_SetString(PyExc_MemoryError,
...@@ -463,7 +463,7 @@ class Softmax(gof.Op): ...@@ -463,7 +463,7 @@ class Softmax(gof.Op):
{ {
Py_XDECREF(%(sm)s); Py_XDECREF(%(sm)s);
%(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s), %(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s),
type_num_%(x)s); PyArray_TYPE((PyArrayObject*) py_%(x)s));
if(!%(sm)s) { if(!%(sm)s) {
PyErr_SetString(PyExc_MemoryError, PyErr_SetString(PyExc_MemoryError,
"failed to alloc sm output"); "failed to alloc sm output");
...@@ -977,7 +977,7 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op): ...@@ -977,7 +977,7 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
{ {
if (NULL != %(nll)s) Py_XDECREF(%(nll)s); if (NULL != %(nll)s) Py_XDECREF(%(nll)s);
%(nll)s = (PyArrayObject*)PyArray_SimpleNew(1, %(nll)s = (PyArrayObject*)PyArray_SimpleNew(1,
PyArray_DIMS(%(y_idx)s), type_num_%(x)s); PyArray_DIMS(%(y_idx)s), PyArray_TYPE((PyArrayObject*) py_%(x)s));
if(!%(nll)s) if(!%(nll)s)
{ {
PyErr_SetString(PyExc_MemoryError, PyErr_SetString(PyExc_MemoryError,
...@@ -990,7 +990,7 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op): ...@@ -990,7 +990,7 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
{ {
Py_XDECREF(%(am)s); Py_XDECREF(%(am)s);
%(am)s = (PyArrayObject*) PyArray_SimpleNew(1, %(am)s = (PyArrayObject*) PyArray_SimpleNew(1,
PyArray_DIMS(%(y_idx)s), type_num_%(y_idx)s); PyArray_DIMS(%(y_idx)s), PyArray_TYPE((PyArrayObject*) py_%(y_idx)s));
if(!%(am)s) if(!%(am)s)
{ {
PyErr_SetString(PyExc_MemoryError, PyErr_SetString(PyExc_MemoryError,
...@@ -1144,7 +1144,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op): ...@@ -1144,7 +1144,7 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
if (NULL != %(dx)s) Py_XDECREF(%(dx)s); if (NULL != %(dx)s) Py_XDECREF(%(dx)s);
%(dx)s = (PyArrayObject*) PyArray_SimpleNew(2, %(dx)s = (PyArrayObject*) PyArray_SimpleNew(2,
PyArray_DIMS(%(sm)s), PyArray_DIMS(%(sm)s),
type_num_%(sm)s); PyArray_TYPE((PyArrayObject*) py_%(sm)s));
if(!%(dx)s) { if(!%(dx)s) {
PyErr_SetString(PyExc_MemoryError, PyErr_SetString(PyExc_MemoryError,
"failed to alloc dx output"); "failed to alloc dx output");
......
...@@ -1363,6 +1363,8 @@ class Assert(T.Op): ...@@ -1363,6 +1363,8 @@ class Assert(T.Op):
""" """
view_map = {0: [0]} view_map = {0: [0]}
check_input = False
def __init__(self, msg="Theano Assert failed!"): def __init__(self, msg="Theano Assert failed!"):
self.msg = msg self.msg = msg
...@@ -1415,7 +1417,7 @@ class Assert(T.Op): ...@@ -1415,7 +1417,7 @@ class Assert(T.Op):
""" % locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (1, 1) return (3, 0)
def infer_shape(self, node, input_shapes): def infer_shape(self, node, input_shapes):
return [input_shapes[0]] return [input_shapes[0]]
......
...@@ -277,7 +277,7 @@ class Subtensor(Op): ...@@ -277,7 +277,7 @@ class Subtensor(Op):
e_subslice = 'nested slicing is not supported' e_subslice = 'nested slicing is not supported'
e_indextype = "Invalid index type or slice for Subtensor" e_indextype = "Invalid index type or slice for Subtensor"
debug = 0 debug = 0
check_input = False
view_map = {0: [0]} view_map = {0: [0]}
@staticmethod @staticmethod
...@@ -892,7 +892,7 @@ class Subtensor(Op): ...@@ -892,7 +892,7 @@ class Subtensor(Op):
# have a versioned version of this op's C code. # have a versioned version of this op's C code.
if len(hv) == 0: if len(hv) == 0:
return () return ()
return (2, hv) return (3, hv)
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
# Subtensor is not differentiable wrt to its indices, therefore we # Subtensor is not differentiable wrt to its indices, therefore we
...@@ -1074,6 +1074,8 @@ class IncSubtensor(Op): ...@@ -1074,6 +1074,8 @@ class IncSubtensor(Op):
of incrementing it by that value. of incrementing it by that value.
""" """
check_input = False
def __init__(self, idx_list, inplace=False, set_instead_of_inc=False, def __init__(self, idx_list, inplace=False, set_instead_of_inc=False,
destroyhandler_tolerate_aliased=None): destroyhandler_tolerate_aliased=None):
if destroyhandler_tolerate_aliased is None: if destroyhandler_tolerate_aliased is None:
......
...@@ -416,24 +416,30 @@ class TensorType(Type): ...@@ -416,24 +416,30 @@ class TensorType(Type):
return str(self) return str(self)
#"TensorType{%s, %s}" % (str(self.dtype), str(self.broadcastable)) #"TensorType{%s, %s}" % (str(self.dtype), str(self.broadcastable))
def c_declare(self, name, sub): def c_declare(self, name, sub, check_input=True):
"""Override `CLinkerType.c_declare` """ """Override `CLinkerType.c_declare` """
return """ if(check_input):
PyArrayObject* %(name)s; check = """
int type_num_%(name)s;
typedef %(dtype)s dtype_%(name)s; typedef %(dtype)s dtype_%(name)s;
""" % dict(sub, name=name, dtype=self.dtype_specs()[1]) """ % dict(sub, name=name, dtype=self.dtype_specs()[1])
else:
check = ""
declaration = """
PyArrayObject* %(name)s;
""" % dict(sub, name=name, dtype=self.dtype_specs()[1])
return declaration + check
def c_init(self, name, sub): def c_init(self, name, sub):
"""Override `CLinkerType.c_init` """ """Override `CLinkerType.c_init` """
return """ return """
%(name)s = NULL; %(name)s = NULL;
type_num_%(name)s = %(type_num)s;
""" % dict(sub, name=name, type_num=self.dtype_specs()[2]) """ % dict(sub, name=name, type_num=self.dtype_specs()[2])
def c_extract(self, name, sub): def c_extract(self, name, sub, check_input=True):
"""Override `CLinkerType.c_extract` """ """Override `CLinkerType.c_extract` """
return """ if(check_input):
check = """
%(name)s = NULL; %(name)s = NULL;
if (py_%(name)s == Py_None) { if (py_%(name)s == Py_None) {
// We can either fail here or set %(name)s to NULL and rely on Ops // We can either fail here or set %(name)s to NULL and rely on Ops
...@@ -447,7 +453,6 @@ class TensorType(Type): ...@@ -447,7 +453,6 @@ class TensorType(Type):
%(fail)s %(fail)s
} }
// We expect %(type_num)s // We expect %(type_num)s
type_num_%(name)s = PyArray_TYPE((PyArrayObject*) py_%(name)s);
if (!PyArray_ISALIGNED((PyArrayObject*) py_%(name)s)) { if (!PyArray_ISALIGNED((PyArrayObject*) py_%(name)s)) {
PyArrayObject * tmp = (PyArrayObject*) py_%(name)s; PyArrayObject * tmp = (PyArrayObject*) py_%(name)s;
PyErr_Format(PyExc_NotImplementedError, PyErr_Format(PyExc_NotImplementedError,
...@@ -457,7 +462,7 @@ class TensorType(Type): ...@@ -457,7 +462,7 @@ class TensorType(Type):
"%%ld, %%ld, %%ld" "%%ld, %%ld, %%ld"
" and 3 last strides %%ld %%ld, %%ld.", " and 3 last strides %%ld %%ld, %%ld.",
(long int) %(type_num)s, (long int) %(type_num)s,
(long int) type_num_%(name)s, (long int) PyArray_TYPE((PyArrayObject*) py_%(name)s),
(long int) PyArray_NDIM(tmp), (long int) PyArray_NDIM(tmp),
(long int) PyArray_NDIM(tmp) >= 3 ? (long int) PyArray_NDIM(tmp) >= 3 ?
PyArray_DIMS(tmp)[PyArray_NDIM(tmp)-3] : -1, PyArray_DIMS(tmp)[PyArray_NDIM(tmp)-3] : -1,
...@@ -476,12 +481,16 @@ class TensorType(Type): ...@@ -476,12 +481,16 @@ class TensorType(Type):
} }
// This is a TypeError to be consistent with DEBUG_MODE // This is a TypeError to be consistent with DEBUG_MODE
// Note: DEBUG_MODE also tells the name of the container // Note: DEBUG_MODE also tells the name of the container
if (type_num_%(name)s != %(type_num)s) { if (PyArray_TYPE((PyArrayObject*) py_%(name)s) != %(type_num)s) {
PyErr_Format(PyExc_TypeError, PyErr_Format(PyExc_TypeError,
"expected type_num %%d (%(type_num)s) got %%d", "expected type_num %%d (%(type_num)s) got %%d",
%(type_num)s, type_num_%(name)s); %(type_num)s, PyArray_TYPE((PyArrayObject*) py_%(name)s));
%(fail)s %(fail)s
} }
""" % dict(sub, name=name, type_num=self.dtype_specs()[2])
else:
check = ""
return check + """
%(name)s = (PyArrayObject*)(py_%(name)s); %(name)s = (PyArrayObject*)(py_%(name)s);
Py_XINCREF(%(name)s); Py_XINCREF(%(name)s);
""" % dict(sub, name=name, type_num=self.dtype_specs()[2]) """ % dict(sub, name=name, type_num=self.dtype_specs()[2])
...@@ -512,13 +521,11 @@ class TensorType(Type): ...@@ -512,13 +521,11 @@ class TensorType(Type):
if (%(name)s && !PyArray_ISALIGNED((PyArrayObject*) py_%(name)s)) { if (%(name)s && !PyArray_ISALIGNED((PyArrayObject*) py_%(name)s)) {
PyErr_Format(PyExc_NotImplementedError, PyErr_Format(PyExc_NotImplementedError,
"c_sync: expected an aligned array of type %%ld " "c_sync: expected an aligned array, got non-aligned array of type %%ld"
"(%(type_num)s), got non-aligned array of type %%ld"
" with %%ld dimensions, with 3 last dims " " with %%ld dimensions, with 3 last dims "
"%%ld, %%ld, %%ld" "%%ld, %%ld, %%ld"
" and 3 last strides %%ld %%ld, %%ld.", " and 3 last strides %%ld %%ld, %%ld.",
(long int) %(type_num)s, (long int) PyArray_TYPE((PyArrayObject*) py_%(name)s),
(long int) type_num_%(name)s,
(long int) PyArray_NDIM(%(name)s), (long int) PyArray_NDIM(%(name)s),
(long int) PyArray_NDIM(%(name)s) >= 3 ? (long int) PyArray_NDIM(%(name)s) >= 3 ?
PyArray_DIMS(%(name)s)[PyArray_NDIM(%(name)s)-3] : -1, PyArray_DIMS(%(name)s)[PyArray_NDIM(%(name)s)-3] : -1,
......
...@@ -232,7 +232,7 @@ class T_extending(unittest.TestCase): ...@@ -232,7 +232,7 @@ class T_extending(unittest.TestCase):
div = BinaryDoubleOp(name = 'div', div = BinaryDoubleOp(name = 'div',
fn = lambda x, y: x / y) fn = lambda x, y: x / y)
def c_declare(name, sub): def c_declare(name, sub, check_input=True):
return """ return """
double %(name)s; double %(name)s;
""" % dict(name = name) """ % dict(name = name)
...@@ -245,19 +245,20 @@ class T_extending(unittest.TestCase): ...@@ -245,19 +245,20 @@ class T_extending(unittest.TestCase):
""" % dict(name = name) """ % dict(name = name)
double.c_init = c_init double.c_init = c_init
def c_extract(name, sub, check_input=True):
if(check_input):
def c_extract(name, sub): pre = """
return """
if (!PyFloat_Check(py_%(name)s)) { if (!PyFloat_Check(py_%(name)s)) {
PyErr_SetString(PyExc_TypeError, "expected a float"); PyErr_SetString(PyExc_TypeError, "expected a float");
%(fail)s %(fail)s
} }""" % dict(name = name, fail = sub['fail'])
else:
pre = ""
return pre + """
%(name)s = PyFloat_AsDouble(py_%(name)s); %(name)s = PyFloat_AsDouble(py_%(name)s);
""" % dict(name = name, fail = sub['fail']) """ % dict(name = name, fail = sub['fail'])
double.c_extract = c_extract double.c_extract = c_extract
def c_sync( name, sub): def c_sync( name, sub):
return """ return """
Py_XDECREF(py_%(name)s); Py_XDECREF(py_%(name)s);
...@@ -298,7 +299,7 @@ class T_extending(unittest.TestCase): ...@@ -298,7 +299,7 @@ class T_extending(unittest.TestCase):
def __str__(self): def __str__(self):
return "double" return "double"
def c_declare(self, name, sub): def c_declare(self, name, sub, check_input=True):
return """ return """
double %(name)s; double %(name)s;
""" % dict(name = name) """ % dict(name = name)
...@@ -308,14 +309,19 @@ class T_extending(unittest.TestCase): ...@@ -308,14 +309,19 @@ class T_extending(unittest.TestCase):
%(name)s = 0.0; %(name)s = 0.0;
""" % dict(name = name) """ % dict(name = name)
def c_extract(self, name, sub): def c_extract(self, name, sub, check_input=True):
return """ if(check_input):
pre = """
if (!PyFloat_Check(py_%(name)s)) { if (!PyFloat_Check(py_%(name)s)) {
PyErr_SetString(PyExc_TypeError, "expected a float"); PyErr_SetString(PyExc_TypeError, "expected a float");
%(fail)s %(fail)s
} }
""" % dict(sub, name=name)
else:
pre = ""
return pre + """
%(name)s = PyFloat_AsDouble(py_%(name)s); %(name)s = PyFloat_AsDouble(py_%(name)s);
""" % dict(sub, name = name) """ % dict(sub, name=name)
def c_sync(self, name, sub): def c_sync(self, name, sub):
return """ return """
......
...@@ -76,7 +76,7 @@ class TypedListType(gof.Type): ...@@ -76,7 +76,7 @@ class TypedListType(gof.Type):
return True return True
def c_declare(self, name, sub): def c_declare(self, name, sub, check_input=True):
return """ return """
PyListObject* %(name)s; PyListObject* %(name)s;
""" % dict(name=name) """ % dict(name=name)
...@@ -86,12 +86,16 @@ class TypedListType(gof.Type): ...@@ -86,12 +86,16 @@ class TypedListType(gof.Type):
%(name)s = NULL; %(name)s = NULL;
""" % dict(name=name) """ % dict(name=name)
def c_extract(self, name, sub): def c_extract(self, name, sub, check_input=True):
return """ if check_input:
pre = """
if (!PyList_Check(py_%(name)s)) { if (!PyList_Check(py_%(name)s)) {
PyErr_SetString(PyExc_TypeError, "expected a list"); PyErr_SetString(PyExc_TypeError, "expected a list");
%(fail)s %(fail)s
} }"""
else:
pre = ""
return pre + """
%(name)s = (PyListObject*) (py_%(name)s); %(name)s = (PyListObject*) (py_%(name)s);
""" % dict(name=name, fail=sub['fail']) """ % dict(name=name, fail=sub['fail'])
...@@ -107,4 +111,4 @@ class TypedListType(gof.Type): ...@@ -107,4 +111,4 @@ class TypedListType(gof.Type):
return "" return ""
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论