提交 20a0d476 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

PEP8

上级 228f9bd4
...@@ -49,14 +49,25 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -49,14 +49,25 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
{ PyErr_SetString(PyExc_TypeError, "A vs. y"); %(fail)s; } { PyErr_SetString(PyExc_TypeError, "A vs. y"); %(fail)s; }
if (%(A)s->dimensions[0] != %(x)s->dimensions[0]) if (%(A)s->dimensions[0] != %(x)s->dimensions[0])
{PyErr_SetString(PyExc_ValueError, "Shape mismatch: A.shape[0] != x.shape[0]"); %(fail)s;} {
PyErr_SetString(PyExc_ValueError,
"Shape mismatch: A.shape[0] != x.shape[0]");
%(fail)s;
}
if (%(A)s->dimensions[1] != %(y)s->dimensions[0]) if (%(A)s->dimensions[1] != %(y)s->dimensions[0])
{PyErr_SetString(PyExc_ValueError, "Shape mismatch: A.shape[1] != y.shape[0]"); %(fail)s;} {
PyErr_SetString(PyExc_ValueError,
"Shape mismatch: A.shape[1] != y.shape[0]");
%(fail)s;
}
if (%(A)s->descr->type_num == PyArray_DOUBLE) { elemsize = 8; } if (%(A)s->descr->type_num == PyArray_DOUBLE) { elemsize = 8; }
else if (%(A)s->descr->type_num == PyArray_FLOAT) { elemsize = 4;} else if (%(A)s->descr->type_num == PyArray_FLOAT) { elemsize = 4;}
else {PyErr_SetString(PyExc_NotImplementedError, "complex CGer"); %(fail)s;} else
{
PyErr_SetString(PyExc_NotImplementedError, "complex CGer");
%(fail)s;
}
// copy A if !self.destructive or A is fully strided // copy A if !self.destructive or A is fully strided
if (!%(destructive)s if (!%(destructive)s
...@@ -78,9 +89,11 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -78,9 +89,11 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
&& (%(Z)s->strides[1] != elemsize))) && (%(Z)s->strides[1] != elemsize)))
{ {
if (%(Z)s) Py_XDECREF(%(Z)s); if (%(Z)s) Py_XDECREF(%(Z)s);
%(Z)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(A)s)); %(Z)s = (PyArrayObject*) PyArray_SimpleNew(2, dims,
PyArray_TYPE(%(A)s));
if(!%(Z)s) { if(!%(Z)s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc ger output"); PyErr_SetString(PyExc_MemoryError,
"failed to alloc ger output");
%(fail)s %(fail)s
} }
} }
...@@ -123,7 +136,8 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -123,7 +136,8 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
} }
else else
{ {
PyErr_SetString(PyExc_AssertionError, "neither float nor double dtype"); PyErr_SetString(PyExc_AssertionError,
"neither float nor double dtype");
%(fail)s %(fail)s
} }
} }
...@@ -183,7 +197,8 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -183,7 +197,8 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
(double*)(%(Z)s->data), &Sz1); (double*)(%(Z)s->data), &Sz1);
} }
else { else {
PyErr_SetString(PyExc_NotImplementedError, "not float nor double"); PyErr_SetString(PyExc_NotImplementedError,
"not float nor double");
%(fail)s %(fail)s
} }
} }
...@@ -210,7 +225,8 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -210,7 +225,8 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
} }
else else
{ {
PyErr_SetString(PyExc_NotImplementedError, "not float nor double"); PyErr_SetString(PyExc_NotImplementedError,
"not float nor double");
%(fail)s %(fail)s
} }
} }
...@@ -225,6 +241,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -225,6 +241,7 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
""" % locals() """ % locals()
class CGer(BaseBLAS, Ger): class CGer(BaseBLAS, Ger):
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
A, a, x, y = inp A, a, x, y = inp
...@@ -250,13 +267,13 @@ def use_c_ger(node): ...@@ -250,13 +267,13 @@ def use_c_ger(node):
node.outputs[0].dtype in ['float32', 'float64']): node.outputs[0].dtype in ['float32', 'float64']):
return [CGer(True)(*node.inputs)] return [CGer(True)(*node.inputs)]
@local_optimizer([CGer(False)]) @local_optimizer([CGer(False)])
def make_c_ger_destructive(node): def make_c_ger_destructive(node):
if node.op == CGer(False): if node.op == CGer(False):
return [CGer(True)(*node.inputs)] return [CGer(True)(*node.inputs)]
####### ####### ####### ####### ####### #######
# GEMV # GEMV
####### ####### ####### ####### ####### #######
...@@ -275,15 +292,30 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -275,15 +292,30 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
double dbeta; double dbeta;
if (%(aa)s->nd != 1) if (%(aa)s->nd != 1)
{PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(aa) != 1"); %(fail)s;} {
PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(aa) != 1");
%(fail)s;
}
if (%(xx)s->nd != 2) if (%(xx)s->nd != 2)
{PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(xx) != 2"); %(fail)s;} {
PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(xx) != 2");
%(fail)s;
}
if (%(yy)s->nd != 1) if (%(yy)s->nd != 1)
{PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(yy) != 1"); %(fail)s;} {
PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(yy) != 1");
%(fail)s;
}
if (%(alpha)s->nd != 0) if (%(alpha)s->nd != 0)
{PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(alpha) != 0"); %(fail)s;} {
PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(alpha) != 0");
%(fail)s;
}
if (%(beta)s->nd != 0) if (%(beta)s->nd != 0)
{PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(beta) != 0"); %(fail)s;} {
PyErr_SetString(PyExc_NotImplementedError, "Gemv: rank(beta) != 0");
%(fail)s;
}
if (%(aa)s->descr->type_num != %(xx)s->descr->type_num) if (%(aa)s->descr->type_num != %(xx)s->descr->type_num)
{ PyErr_SetString(PyExc_TypeError, "Gemv: aa vs. xx"); %(fail)s; } { PyErr_SetString(PyExc_TypeError, "Gemv: aa vs. xx"); %(fail)s; }
...@@ -291,13 +323,24 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -291,13 +323,24 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
{ PyErr_SetString(PyExc_TypeError, "Gemv: aa vs. yy"); %(fail)s; } { PyErr_SetString(PyExc_TypeError, "Gemv: aa vs. yy"); %(fail)s; }
if (%(xx)s->dimensions[0] != %(aa)s->dimensions[0]) if (%(xx)s->dimensions[0] != %(aa)s->dimensions[0])
{PyErr_SetString(PyExc_ValueError, "Shape mismatch: A.shape[0] != x.shape[0]"); %(fail)s;} {
PyErr_SetString(PyExc_ValueError,
"Shape mismatch: A.shape[0] != x.shape[0]");
%(fail)s;
}
if (%(xx)s->dimensions[1] != %(yy)s->dimensions[0]) if (%(xx)s->dimensions[1] != %(yy)s->dimensions[0])
{PyErr_SetString(PyExc_ValueError, "Shape mismatch: A.shape[1] != y.shape[0]"); %(fail)s;} {
PyErr_SetString(PyExc_ValueError,
"Shape mismatch: A.shape[1] != y.shape[0]");
%(fail)s;
}
if (%(aa)s->descr->type_num == PyArray_DOUBLE) { elemsize = 8; } if (%(aa)s->descr->type_num == PyArray_DOUBLE) { elemsize = 8; }
else if (%(aa)s->descr->type_num == PyArray_FLOAT) { elemsize = 4;} else if (%(aa)s->descr->type_num == PyArray_FLOAT) { elemsize = 4;}
else {PyErr_SetString(PyExc_NotImplementedError, "complex Gemv"); %(fail)s;} else {
PyErr_SetString(PyExc_NotImplementedError, "complex Gemv");
%(fail)s;
}
fbeta = dbeta = ((dtype_%(beta)s*)%(beta)s->data)[0]; fbeta = dbeta = ((dtype_%(beta)s*)%(beta)s->data)[0];
...@@ -311,7 +354,8 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -311,7 +354,8 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
%(zz)s = (PyArrayObject*)PyArray_SimpleNew(1, %(zz)s = (PyArrayObject*)PyArray_SimpleNew(1,
%(aa)s->dimensions, type_num_%(aa)s); %(aa)s->dimensions, type_num_%(aa)s);
if(!%(zz)s) { if(!%(zz)s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc gemv output"); PyErr_SetString(PyExc_MemoryError,
"failed to alloc gemv output");
%(fail)s %(fail)s
} }
} }
...@@ -346,7 +390,8 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -346,7 +390,8 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
} }
else else
{ {
PyErr_SetString(PyExc_AssertionError, "neither float nor double dtype"); PyErr_SetString(PyExc_AssertionError,
"neither float nor double dtype");
%(fail)s %(fail)s
} }
fbeta = dbeta = 1.0; fbeta = dbeta = 1.0;
...@@ -404,7 +449,8 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -404,7 +449,8 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
dims[0] = Nx0; dims[0] = Nx0;
dims[1] = Nx1; dims[1] = Nx1;
PyArrayObject * xx_copy = (PyArrayObject *) PyArray_Copy(%(xx)s); PyArrayObject * xx_copy = (PyArrayObject *) PyArray_Copy(
%(xx)s);
if (!xx_copy) if (!xx_copy)
%(fail)s %(fail)s
Py_XDECREF(%(xx)s); Py_XDECREF(%(xx)s);
...@@ -438,7 +484,8 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -438,7 +484,8 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
} }
else else
{ {
PyErr_SetString(PyExc_AssertionError, "neither float nor double dtype"); PyErr_SetString(PyExc_AssertionError,
"neither float nor double dtype");
%(fail)s %(fail)s
} }
} }
...@@ -446,7 +493,8 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -446,7 +493,8 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
{ {
if (%(xx)s->descr->type_num == PyArray_FLOAT) if (%(xx)s->descr->type_num == PyArray_FLOAT)
{ {
//fprintf(stderr, "B %%i %%i %%i %%i\\n", Nz0, Nz1, Sz0, Sz1); //fprintf(stderr, "B %%i %%i %%i %%i\\n",
Nz0, Nz1, Sz0, Sz1);
float alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0]; float alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0];
//fprintf(stderr, "alpha=%%f\\n", alpha); //fprintf(stderr, "alpha=%%f\\n", alpha);
//fprintf(stderr, "sx sy %%i %%i\\n", Sx, Sy); //fprintf(stderr, "sx sy %%i %%i\\n", Sx, Sy);
...@@ -469,15 +517,16 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail): ...@@ -469,15 +517,16 @@ def gemv_c_code(aa, xx, yy, zz, alpha, beta, destructive, fail):
} }
else else
{ {
PyErr_SetString(PyExc_AssertionError, "neither float nor double dtype"); PyErr_SetString(PyExc_AssertionError,
"neither float nor double dtype");
%(fail)s %(fail)s
} }
} }
else else
{ {
PyErr_SetString(PyExc_AssertionError, PyErr_SetString(PyExc_AssertionError,
"xx is a double-strided matrix, and should have been copied " "xx is a double-strided matrix, and should have been "
"into a memory-contiguous one."); "copied into a memory-contiguous one.");
%(fail)s %(fail)s
} }
} }
...@@ -530,7 +579,6 @@ def make_c_gemv_destructive(node): ...@@ -530,7 +579,6 @@ def make_c_gemv_destructive(node):
return [CGemv(inplace=True)(*node.inputs)] return [CGemv(inplace=True)(*node.inputs)]
####### ####### ####### ####### ####### #######
# Optimizers # Optimizers
####### ####### ####### ####### ####### #######
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论