提交 5ad51d46 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Allocate output memory if size is wrong

上级 589f0869
...@@ -1847,6 +1847,9 @@ class StructuredDotGradCSC(gof.Op): ...@@ -1847,6 +1847,9 @@ class StructuredDotGradCSC(gof.Op):
g_a_data[i_idx] = dot_val g_a_data[i_idx] = dot_val
out[0] = g_a_data out[0] = g_a_data
def c_code_cache_version(self):
return (1,)
def c_code(self, node, name, (_indices, _indptr, _d, _g), (_zout, ), sub): def c_code(self, node, name, (_indices, _indptr, _d, _g), (_zout, ), sub):
if node.inputs[2].type.dtype in ('complex64', 'complex128'): if node.inputs[2].type.dtype in ('complex64', 'complex128'):
...@@ -1870,17 +1873,13 @@ class StructuredDotGradCSC(gof.Op): ...@@ -1870,17 +1873,13 @@ class StructuredDotGradCSC(gof.Op):
if( %(_d)s->dimensions[1] != %(_g)s->dimensions[1]) if( %(_d)s->dimensions[1] != %(_g)s->dimensions[1])
{PyErr_SetString(PyExc_NotImplementedError, "d and g have different numbers of columns"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "d and g have different numbers of columns"); %(fail)s;}
if (!%(_zout)s) if (!%(_zout)s
|| (%(_zout)s->dimensions[0] != %(_indices)s->dimensions[0]))
{ {
Py_XDECREF(%(_zout)s);
%(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1, %(_indices)s->dimensions, %(_g)s->descr->type_num); %(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1, %(_indices)s->dimensions, %(_g)s->descr->type_num);
} }
if (%(_zout)s->dimensions[0] != %(_indices)s->dimensions[0])
{
PyErr_SetString(PyExc_NotImplementedError, "somehow _zout got the wrong size.. and I don't know how to resize it.");
%(fail)s;
}
{ //makes it compile even though labels jump over variable definitions. { //makes it compile even though labels jump over variable definitions.
npy_intp nnz = %(_indices)s->dimensions[0]; npy_intp nnz = %(_indices)s->dimensions[0];
npy_intp N = %(_indptr)s->dimensions[0]-1; //TODO: error checking with this npy_intp N = %(_indptr)s->dimensions[0]-1; //TODO: error checking with this
...@@ -1971,6 +1970,9 @@ class StructuredDotGradCSR(gof.Op): ...@@ -1971,6 +1970,9 @@ class StructuredDotGradCSR(gof.Op):
g_a_data[j_idx] = dot_val g_a_data[j_idx] = dot_val
out[0] = g_a_data out[0] = g_a_data
def c_code_cache_version(self):
return (1,)
def c_code(self, node, name, (_indices, _indptr, _d, _g), (_zout, ), sub): def c_code(self, node, name, (_indices, _indptr, _d, _g), (_zout, ), sub):
if node.inputs[2].type.dtype in ('complex64', 'complex128'): if node.inputs[2].type.dtype in ('complex64', 'complex128'):
...@@ -1994,17 +1996,13 @@ class StructuredDotGradCSR(gof.Op): ...@@ -1994,17 +1996,13 @@ class StructuredDotGradCSR(gof.Op):
if( %(_d)s->dimensions[1] != %(_g)s->dimensions[1]) if( %(_d)s->dimensions[1] != %(_g)s->dimensions[1])
{PyErr_SetString(PyExc_NotImplementedError, "d and g have different numbers of columns"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "d and g have different numbers of columns"); %(fail)s;}
if (!%(_zout)s) if (!%(_zout)s
|| (%(_zout)s->dimensions[0] != %(_indices)s->dimensions[0]))
{ {
Py_XDECREF(%(_zout)s);
%(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1, %(_indices)s->dimensions, %(_g)s->descr->type_num); %(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1, %(_indices)s->dimensions, %(_g)s->descr->type_num);
} }
if (%(_zout)s->dimensions[0] != %(_indices)s->dimensions[0])
{
PyErr_SetString(PyExc_NotImplementedError, "somehow _zout got the wrong size.. and I don't know how to resize it.");
%(fail)s;
}
{ //makes it compile even though labels jump over variable definitions. { //makes it compile even though labels jump over variable definitions.
npy_intp nnz = %(_indices)s->dimensions[0]; npy_intp nnz = %(_indices)s->dimensions[0];
// extract number of rows // extract number of rows
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论