提交 20a0d476 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

PEP8

上级 228f9bd4
......@@ -49,14 +49,25 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
{ PyErr_SetString(PyExc_TypeError, "A vs. y"); %(fail)s; }
if (%(A)s->dimensions[0] != %(x)s->dimensions[0])
{PyErr_SetString(PyExc_ValueError, "Shape mismatch: A.shape[0] != x.shape[0]"); %(fail)s;}
{
PyErr_SetString(PyExc_ValueError,
"Shape mismatch: A.shape[0] != x.shape[0]");
%(fail)s;
}
if (%(A)s->dimensions[1] != %(y)s->dimensions[0])
{PyErr_SetString(PyExc_ValueError, "Shape mismatch: A.shape[1] != y.shape[0]"); %(fail)s;}
{
PyErr_SetString(PyExc_ValueError,
"Shape mismatch: A.shape[1] != y.shape[0]");
%(fail)s;
}
if (%(A)s->descr->type_num == PyArray_DOUBLE) { elemsize = 8; }
else if (%(A)s->descr->type_num == PyArray_FLOAT) { elemsize = 4;}
else {PyErr_SetString(PyExc_NotImplementedError, "complex CGer"); %(fail)s;}
else
{
PyErr_SetString(PyExc_NotImplementedError, "complex CGer");
%(fail)s;
}
// copy A if !self.destructive or A is fully strided
if (!%(destructive)s
......@@ -78,9 +89,11 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
&& (%(Z)s->strides[1] != elemsize)))
{
if (%(Z)s) Py_XDECREF(%(Z)s);
%(Z)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(A)s));
%(Z)s = (PyArrayObject*) PyArray_SimpleNew(2, dims,
PyArray_TYPE(%(A)s));
if(!%(Z)s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc ger output");
PyErr_SetString(PyExc_MemoryError,
"failed to alloc ger output");
%(fail)s
}
}
......@@ -123,7 +136,8 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
}
else
{
PyErr_SetString(PyExc_AssertionError, "neither float nor double dtype");
PyErr_SetString(PyExc_AssertionError,
"neither float nor double dtype");
%(fail)s
}
}
......@@ -183,7 +197,8 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
(double*)(%(Z)s->data), &Sz1);
}
else {
PyErr_SetString(PyExc_NotImplementedError, "not float nor double");
PyErr_SetString(PyExc_NotImplementedError,
"not float nor double");
%(fail)s
}
}
......@@ -210,7 +225,8 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
}
else
{
PyErr_SetString(PyExc_NotImplementedError, "not float nor double");
PyErr_SetString(PyExc_NotImplementedError,
"not float nor double");
%(fail)s
}
}
......@@ -225,6 +241,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
""" % locals()
class CGer(BaseBLAS, Ger):
def c_code(self, node, name, inp, out, sub):
A, a, x, y = inp
......@@ -250,13 +267,13 @@ def use_c_ger(node):
node.outputs[0].dtype in ['float32', 'float64']):
return [CGer(True)(*node.inputs)]
@local_optimizer([CGer(False)])
def make_c_ger_destructive(node):
if node.op == CGer(False):
return [CGer(True)(*node.inputs)]
####### ####### #######
# GEMV
####### ####### #######
......@@ -275,15 +292,30 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
double dbeta;
if (%(aa)s->nd != 1)
{PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(aa) != 1"); %(fail)s;}
{
PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(aa) != 1");
%(fail)s;
}
if (%(xx)s->nd != 2)
{PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(xx) != 2"); %(fail)s;}
{
PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(xx) != 2");
%(fail)s;
}
if (%(yy)s->nd != 1)
{PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(yy) != 1"); %(fail)s;}
{
PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(yy) != 1");
%(fail)s;
}
if (%(alpha)s->nd != 0)
{PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(alpha) != 0"); %(fail)s;}
{
PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(alpha) != 0");
%(fail)s;
}
if (%(beta)s->nd != 0)
{PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(beta) != 0"); %(fail)s;}
{
PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(beta) != 0");
%(fail)s;
}
if (%(aa)s->descr->type_num != %(xx)s->descr->type_num)
{ PyErr_SetString(PyExc_TypeError, "Gemv: aa vs. xx"); %(fail)s; }
......@@ -291,13 +323,24 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
{ PyErr_SetString(PyExc_TypeError, "Gemv: aa vs. yy"); %(fail)s; }
if (%(xx)s->dimensions[0] != %(aa)s->dimensions[0])
{PyErr_SetString(PyExc_ValueError, "Shape mismatch: A.shape[0] != x.shape[0]"); %(fail)s;}
{
PyErr_SetString(PyExc_ValueError,
"Shape mismatch: A.shape[0] != x.shape[0]");
%(fail)s;
}
if (%(xx)s->dimensions[1] != %(yy)s->dimensions[0])
{PyErr_SetString(PyExc_ValueError, "Shape mismatch: A.shape[1] != y.shape[0]"); %(fail)s;}
{
PyErr_SetString(PyExc_ValueError,
"Shape mismatch: A.shape[1] != y.shape[0]");
%(fail)s;
}
if (%(aa)s->descr->type_num == PyArray_DOUBLE) { elemsize = 8; }
else if (%(aa)s->descr->type_num == PyArray_FLOAT) { elemsize = 4;}
else {PyErr_SetString(PyExc_NotImplementedError, "complex Gemv"); %(fail)s;}
else {
PyErr_SetString(PyExc_NotImplementedError, "complex Gemv");
%(fail)s;
}
fbeta = dbeta = ((dtype_%(beta)s*)%(beta)s->data)[0];
......@@ -311,7 +354,8 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
%(zz)s = (PyArrayObject*)PyArray_SimpleNew(1,
%(aa)s->dimensions, type_num_%(aa)s);
if(!%(zz)s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc gemv output");
PyErr_SetString(PyExc_MemoryError,
"failed to alloc gemv output");
%(fail)s
}
}
......@@ -346,7 +390,8 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
}
else
{
PyErr_SetString(PyExc_AssertionError, "neither float nor double dtype");
PyErr_SetString(PyExc_AssertionError,
"neither float nor double dtype");
%(fail)s
}
fbeta = dbeta = 1.0;
......@@ -404,7 +449,8 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
dims[0] = Nx0;
dims[1] = Nx1;
PyArrayObject * xx_copy = (PyArrayObject *) PyArray_Copy(%(xx)s);
PyArrayObject * xx_copy = (PyArrayObject *) PyArray_Copy(
%(xx)s);
if (!xx_copy)
%(fail)s
Py_XDECREF(%(xx)s);
......@@ -438,7 +484,8 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
}
else
{
PyErr_SetString(PyExc_AssertionError, "neither float nor double dtype");
PyErr_SetString(PyExc_AssertionError,
"neither float nor double dtype");
%(fail)s
}
}
......@@ -446,7 +493,8 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
{
if (%(xx)s->descr->type_num == PyArray_FLOAT)
{
//fprintf(stderr, "B %%i %%i %%i %%i\\n", Nz0, Nz1, Sz0, Sz1);
//fprintf(stderr, "B %%i %%i %%i %%i\\n",
Nz0, Nz1, Sz0, Sz1);
float alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0];
//fprintf(stderr, "alpha=%%f\\n", alpha);
//fprintf(stderr, "sx sy %%i %%i\\n", Sx, Sy);
......@@ -469,15 +517,16 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
}
else
{
PyErr_SetString(PyExc_AssertionError, "neither float nor double dtype");
PyErr_SetString(PyExc_AssertionError,
"neither float nor double dtype");
%(fail)s
}
}
else
{
PyErr_SetString(PyExc_AssertionError,
"xx is a double-strided matrix, and should have been copied "
"into a memory-contiguous one.");
"xx is a double-strided matrix, and should have been "
"copied into a memory-contiguous one.");
%(fail)s
}
}
......@@ -530,7 +579,6 @@ def make_c_gemv_destructive(node):
return [CGemv(inplace=True)(*node.inputs)]
####### ####### #######
# Optimizers
####### ####### #######
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论