提交 bad43cc2 authored 作者: Yann N. Dauphin's avatar Yann N. Dauphin

faster dotmodulo

上级 c052575f
......@@ -62,7 +62,7 @@ class DotModulo(Op):
out[0] = matVecModM(A, s, m)
#def c_code_cache_version(self):
# return (1,)
# return (2,)
def c_code(self, node, name, (_A, _s, _m), (_z, ), sub):
return """
......@@ -97,23 +97,18 @@ class DotModulo(Op):
npy_intp Ss = PyArray_STRIDES(%(_s)s)[0] / PyArray_DESCR(%(_s)s)->elsize;
npy_intp Sz = PyArray_STRIDES(%(_z)s)[0] / PyArray_DESCR(%(_z)s)->elsize;
memset(Dz, 0, M*sizeof(dtype_%(_z)s));
for (npy_int32 i = 0; i < M; ++i)
{
const dtype_%(_A)s* __restrict__ Ak = (dtype_%(_A)s*)(PyArray_BYTES(%(_A)s) + PyArray_STRIDES(%(_A)s)[0] * i);
npy_int64 r = 0;
for (npy_int32 j = 0; j < N; ++j)
{
npy_intp r = (Dz[i * Sz] + (npy_int64)(Ds[j * Ss]) * (npy_int64)(Ak[j * SA])) %% m;
if (r >= 0) {
Dz[i * Sz] = r;
}
else {
Dz[i * Sz] = r + m;
}
r += Ds[j * Ss] * (npy_int64)(Ak[j * SA]);
}
Dz[i * Sz] = r %% m;
}
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论