提交 1143c9f8 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Raise meaningful Python errors in C code

Instead of just crashing the program.
上级 b2b1d7f7
...@@ -77,7 +77,11 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -77,7 +77,11 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
%(fail)s %(fail)s
} }
} }
assert (%(Z)s != %(A)s); if (%(Z)s == %(A)s)
{
PyErr_SetString(PyExc_AssertionError, "%(Z)s != %(A)s");
%(fail)s
}
if (%(Z)s->descr->type_num == PyArray_FLOAT) if (%(Z)s->descr->type_num == PyArray_FLOAT)
{ {
float * zoutdata = (float*)%(Z)s->data; float * zoutdata = (float*)%(Z)s->data;
...@@ -163,7 +167,10 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -163,7 +167,10 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
(double*)(%(y)s->data), &Sy, (double*)(%(y)s->data), &Sy,
(double*)(%(Z)s->data), &Sz1); (double*)(%(Z)s->data), &Sz1);
} }
else { assert(0); } else {
PyErr_SetString(PyExc_NotImplementedError, "not float nor double");
%(fail)s
}
} }
else if (%(Z)s->strides[1] == elemsize) else if (%(Z)s->strides[1] == elemsize)
{ {
...@@ -186,9 +193,19 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -186,9 +193,19 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
(double*)(%(x)s->data), &Sx, (double*)(%(x)s->data), &Sx,
(double*)(%(Z)s->data), &Sz0); (double*)(%(Z)s->data), &Sz0);
} }
else { assert(0); } else
{
PyErr_SetString(PyExc_NotImplementedError, "not float nor double");
%(fail)s
}
}
else
{
PyErr_SetString(PyExc_AssertionError,
"A is a double-strided matrix, and should have been copied "
"into a memory-contiguous one.");
%(fail)s
} }
else { assert(0); }
} }
} }
...@@ -204,7 +221,7 @@ class CGer(BaseBLAS, Ger): ...@@ -204,7 +221,7 @@ class CGer(BaseBLAS, Ger):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (4,) return (6,)
@local_optimizer([ger, ger_destructive]) @local_optimizer([ger, ger_destructive])
...@@ -282,7 +299,11 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -282,7 +299,11 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
%(fail)s %(fail)s
} }
} }
assert (%(zz)s != %(aa)s); if (%(zz)s == %(aa)s)
{
PyErr_SetString(PyExc_AssertionError, "%(zz)s != %(aa)s");
%(fail)s
}
if (dbeta != 0) if (dbeta != 0)
{ {
if (%(zz)s->descr->type_num == PyArray_FLOAT) if (%(zz)s->descr->type_num == PyArray_FLOAT)
...@@ -362,7 +383,8 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -362,7 +383,8 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
} }
else else
{ {
assert(0); PyErr_SetString(PyExc_AssertionError, "neither float nor double dtype");
%(fail)s
} }
} }
else if (%(xx)s->strides[1] == elemsize) else if (%(xx)s->strides[1] == elemsize)
...@@ -392,14 +414,16 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -392,14 +414,16 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
} }
else else
{ {
assert(0); PyErr_SetString(PyExc_AssertionError, "neither float nor double dtype");
%(fail)s
} }
} }
else else
{ {
// if xx is strided in both directions, then just do the gemv with a // if xx is strided in both directions, then just do the gemv with a
// pair of for loops. // pair of for loops.
assert (0); PyErr_SetString(PyExc_NotImplementedError, "double-strided matrix");
%(fail)s
} }
} }
else if (dbeta != 1.0) else if (dbeta != 1.0)
...@@ -429,7 +453,7 @@ class CGemv(BaseBLAS, Gemv): ...@@ -429,7 +453,7 @@ class CGemv(BaseBLAS, Gemv):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (3,)
@local_optimizer([gemv_inplace, gemv_no_inplace]) @local_optimizer([gemv_inplace, gemv_no_inplace])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论