提交 1143c9f8 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Raise meaningful Python errors in C code

Instead of just crashing the program.
上级 b2b1d7f7
......@@ -77,7 +77,11 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
%(fail)s
}
}
assert (%(Z)s != %(A)s);
if (%(Z)s == %(A)s)
{
PyErr_SetString(PyExc_AssertionError, "%(Z)s != %(A)s");
%(fail)s
}
if (%(Z)s->descr->type_num == PyArray_FLOAT)
{
float * zoutdata = (float*)%(Z)s->data;
......@@ -163,7 +167,10 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
(double*)(%(y)s->data), &Sy,
(double*)(%(Z)s->data), &Sz1);
}
else { assert(0); }
else {
PyErr_SetString(PyExc_NotImplementedError, "not float nor double");
%(fail)s
}
}
else if (%(Z)s->strides[1] == elemsize)
{
......@@ -186,9 +193,19 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
(double*)(%(x)s->data), &Sx,
(double*)(%(Z)s->data), &Sz0);
}
else { assert(0); }
else
{
PyErr_SetString(PyExc_NotImplementedError, "not float nor double");
%(fail)s
}
}
else
{
PyErr_SetString(PyExc_AssertionError,
"A is a double-strided matrix, and should have been copied "
"into a memory-contiguous one.");
%(fail)s
}
else { assert(0); }
}
}
......@@ -204,7 +221,7 @@ class CGer(BaseBLAS, Ger):
return code
def c_code_cache_version(self):
return (4,)
return (6,)
@local_optimizer([ger, ger_destructive])
......@@ -282,7 +299,11 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
%(fail)s
}
}
assert (%(zz)s != %(aa)s);
if (%(zz)s == %(aa)s)
{
PyErr_SetString(PyExc_AssertionError, "%(zz)s != %(aa)s");
%(fail)s
}
if (dbeta != 0)
{
if (%(zz)s->descr->type_num == PyArray_FLOAT)
......@@ -362,7 +383,8 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
}
else
{
assert(0);
PyErr_SetString(PyExc_AssertionError, "neither float nor double dtype");
%(fail)s
}
}
else if (%(xx)s->strides[1] == elemsize)
......@@ -392,14 +414,16 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
}
else
{
assert(0);
PyErr_SetString(PyExc_AssertionError, "neither float nor double dtype");
%(fail)s
}
}
else
{
// if xx is strided in both directions, then just do the gemv with a
// pair of for loops.
assert (0);
PyErr_SetString(PyExc_NotImplementedError, "double-strided matrix");
%(fail)s
}
}
else if (dbeta != 1.0)
......@@ -429,7 +453,7 @@ class CGemv(BaseBLAS, Gemv):
return code
def c_code_cache_version(self):
return (2,)
return (3,)
@local_optimizer([gemv_inplace, gemv_no_inplace])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论