提交 ed40cf13 authored 作者: Frederic's avatar Frederic

more pep8

上级 f24372bf
...@@ -106,16 +106,16 @@ class DotModulo(Op): ...@@ -106,16 +106,16 @@ class DotModulo(Op):
if(!%(_z)s){%(fail)s;} if(!%(_z)s){%(fail)s;}
{ //makes it compile even though labels jump over variable definitions. { //makes it compile even though labels jump over variable definitions.
// A has size MxN, s has N, output M // A has size MxN, s has N, output M
npy_intp M = PyArray_DIMS(%(_A)s)[0]; npy_intp M = PyArray_DIMS(%(_A)s)[0];
npy_intp N = PyArray_DIMS(%(_A)s)[1]; npy_intp N = PyArray_DIMS(%(_A)s)[1];
const dtype_%(_A)s* __restrict__ DA = (dtype_%(_A)s*)PyArray_DATA(%(_A)s); const dtype_%(_A)s* __restrict__ DA = (dtype_%(_A)s*)PyArray_DATA(%(_A)s);
dtype_%(_s)s* __restrict__ Ds = (dtype_%(_s)s*)PyArray_DATA(%(_s)s); dtype_%(_s)s* __restrict__ Ds = (dtype_%(_s)s*)PyArray_DATA(%(_s)s);
dtype_%(_z)s* __restrict__ Dz = (dtype_%(_z)s*)PyArray_DATA(%(_z)s); dtype_%(_z)s* __restrict__ Dz = (dtype_%(_z)s*)PyArray_DATA(%(_z)s);
const dtype_%(_m)s m = ((dtype_%(_m)s*)PyArray_DATA(%(_m)s))[0]; const dtype_%(_m)s m = ((dtype_%(_m)s*)PyArray_DATA(%(_m)s))[0];
npy_intp SA = PyArray_STRIDES(%(_A)s)[1] / PyArray_DESCR(%(_A)s)->elsize; npy_intp SA = PyArray_STRIDES(%(_A)s)[1] / PyArray_DESCR(%(_A)s)->elsize;
npy_intp Ss = PyArray_STRIDES(%(_s)s)[0] / PyArray_DESCR(%(_s)s)->elsize; npy_intp Ss = PyArray_STRIDES(%(_s)s)[0] / PyArray_DESCR(%(_s)s)->elsize;
npy_intp Sz = PyArray_STRIDES(%(_z)s)[0] / PyArray_DESCR(%(_z)s)->elsize; npy_intp Sz = PyArray_STRIDES(%(_z)s)[0] / PyArray_DESCR(%(_z)s)->elsize;
...@@ -123,14 +123,14 @@ class DotModulo(Op): ...@@ -123,14 +123,14 @@ class DotModulo(Op):
for (npy_int32 i = 0; i < M; ++i) for (npy_int32 i = 0; i < M; ++i)
{ {
const dtype_%(_A)s* __restrict__ Ak = (dtype_%(_A)s*)(PyArray_BYTES(%(_A)s) + PyArray_STRIDES(%(_A)s)[0] * i); const dtype_%(_A)s* __restrict__ Ak = (dtype_%(_A)s*)(PyArray_BYTES(%(_A)s) + PyArray_STRIDES(%(_A)s)[0] * i);
npy_int64 r = 0; npy_int64 r = 0;
for (npy_int32 j = 0; j < N; ++j) for (npy_int32 j = 0; j < N; ++j)
{ {
r += (npy_int64)(Ds[j * Ss] * (npy_int64)(Ak[j * SA])) %% m; r += (npy_int64)(Ds[j * Ss] * (npy_int64)(Ak[j * SA])) %% m;
} }
Dz[i * Sz] = r %% m; Dz[i * Sz] = r %% m;
} }
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论