提交 cd80466d authored 作者: Olivier Breuleux's avatar Olivier Breuleux

corrected nasty bug involving not doing Py_INCREF on Py_None

上级 29adf634
...@@ -189,9 +189,9 @@ def struct_gen(args, struct_builders, blocks, sub): ...@@ -189,9 +189,9 @@ def struct_gen(args, struct_builders, blocks, sub):
PyObject* err_msg = NULL; PyObject* err_msg = NULL;
PyObject* err_traceback = NULL; PyObject* err_traceback = NULL;
PyErr_Fetch(&err_type, &err_msg, &err_traceback); PyErr_Fetch(&err_type, &err_msg, &err_traceback);
if (!err_type) err_type = Py_None; if (!err_type) {err_type = Py_None; Py_XINCREF(Py_None);}
if (!err_msg) err_msg = Py_None; if (!err_msg) {err_msg = Py_None; Py_XINCREF(Py_None);}
if (!err_traceback) err_traceback = Py_None; if (!err_traceback) {err_traceback = Py_None; Py_XINCREF(Py_None);}
PyObject* old_err_type = PyList_GET_ITEM(__ERROR, 0); PyObject* old_err_type = PyList_GET_ITEM(__ERROR, 0);
PyObject* old_err_msg = PyList_GET_ITEM(__ERROR, 1); PyObject* old_err_msg = PyList_GET_ITEM(__ERROR, 1);
PyObject* old_err_traceback = PyList_GET_ITEM(__ERROR, 2); PyObject* old_err_traceback = PyList_GET_ITEM(__ERROR, 2);
...@@ -265,6 +265,7 @@ def get_c_declare(r, name, sub): ...@@ -265,6 +265,7 @@ def get_c_declare(r, name, sub):
def get_c_init(r, name, sub): def get_c_init(r, name, sub):
pre = "" """ pre = "" """
py_%(name)s = Py_None; py_%(name)s = Py_None;
Py_XINCREF(py_%(name)s);
""" % locals() """ % locals()
return pre + r.type.c_init(name, sub) return pre + r.type.c_init(name, sub)
...@@ -711,8 +712,11 @@ class CLinker(link.Linker): ...@@ -711,8 +712,11 @@ class CLinker(link.Linker):
void %(struct_name)s_destructor(void* executor, void* self) { void %(struct_name)s_destructor(void* executor, void* self) {
//printf("doing cleanup\\n"); //printf("doing cleanup\\n");
//fflush(stdout);
((%(struct_name)s*)self)->cleanup(); ((%(struct_name)s*)self)->cleanup();
free(self); free(self);
//printf("done cleanup\\n");
//fflush(stdout);
} }
""" % dict(struct_name = self.struct_name) """ % dict(struct_name = self.struct_name)
......
...@@ -171,7 +171,6 @@ class Tensor(Type): ...@@ -171,7 +171,6 @@ class Tensor(Type):
// with nasty segfaults, so this is public service. // with nasty segfaults, so this is public service.
PyErr_SetString(PyExc_ValueError, "expected an ndarray, not None"); PyErr_SetString(PyExc_ValueError, "expected an ndarray, not None");
%(fail)s %(fail)s
//%(name)s = NULL;
} }
else if (!PyArray_Check(py_%(name)s)) { else if (!PyArray_Check(py_%(name)s)) {
PyErr_SetString(PyExc_ValueError, "expected an ndarray"); PyErr_SetString(PyExc_ValueError, "expected an ndarray");
...@@ -196,15 +195,14 @@ class Tensor(Type): ...@@ -196,15 +195,14 @@ class Tensor(Type):
def c_sync(self, name, sub): def c_sync(self, name, sub):
return """ return """
if (!%(name)s) {
Py_XDECREF(py_%(name)s); Py_XDECREF(py_%(name)s);
if (!%(name)s) {
py_%(name)s = Py_None; py_%(name)s = Py_None;
} }
else if ((void*)py_%(name)s != (void*)%(name)s) { else if ((void*)py_%(name)s != (void*)%(name)s) {
Py_XDECREF(py_%(name)s);
py_%(name)s = (PyObject*)%(name)s; py_%(name)s = (PyObject*)%(name)s;
Py_XINCREF(py_%(name)s);
} }
Py_XINCREF(py_%(name)s);
""" % locals() """ % locals()
def c_headers(self): def c_headers(self):
...@@ -606,11 +604,11 @@ tanh, tanh_inplace = _elemwise(scal.tanh, 'tanh') ...@@ -606,11 +604,11 @@ tanh, tanh_inplace = _elemwise(scal.tanh, 'tanh')
fill, fill_inplace = _elemwise(scal.second, 'fill') fill, fill_inplace = _elemwise(scal.second, 'fill')
def ones_like(model): def ones_like(model):
return Ones(model.type.ndim)(shape(model)) #return Ones(model.type.ndim)(shape(model))
#return fill(model, 1.0) return fill(model, 1.0)
def zeros_like(model): def zeros_like(model):
return Zeros(model.type.ndim)(shape(model)) #return Zeros(model.type.ndim)(shape(model))
#return fill(model, 0.0) return fill(model, 0.0)
class Filler(gof.Op): class Filler(gof.Op):
def __init__(self, value, ndim, dtype = 'float64'): def __init__(self, value, ndim, dtype = 'float64'):
...@@ -916,6 +914,21 @@ class SetSubtensor(Subtensor): ...@@ -916,6 +914,21 @@ class SetSubtensor(Subtensor):
out[0] = x out[0] = x
class MakeVector(Op):
def __init__(self, stype):
self.stype = stype
def make_node(self, *inputs):
assert all(a.type == self.stype for a in inputs)
return Apply(self, inputs, [Tensor(broadcastable = (False,),
dtype = self.stype.dtype)()])
def perform(self, inputs, (out,)):
return numpy.asarray([i[0] for i in inputs])
def grad(self, inputs, (gout,)):
return [None]*len(inputs)
make_lvector = MakeVector(lscalar)
class VerticalStack(Op): class VerticalStack(Op):
""" """
Vertically stack two L{Tensor}s. Vertically stack two L{Tensor}s.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论