提交 883e9905 作者: Yann N. Dauphin

initial commit for fast DotModulo

上级 9bbf919e
...@@ -47,6 +47,79 @@ def multMatVect(v, A, m1, B, m2): ...@@ -47,6 +47,79 @@ def multMatVect(v, A, m1, B, m2):
return r return r
class DotModulo(Op):
    """Matrix-vector product reduced modulo a scalar.

    Takes a 2-d array ``A``, a 1-d array ``s`` and a 0-d modulus ``m`` and
    produces a 1-d array with one entry per row of ``A``.  The Python
    implementation delegates to ``matVecModM``; the C implementation
    accumulates ``A[i, j] * s[j]`` over ``j``, reducing modulo ``m`` after
    every term and keeping the running value non-negative.
    """

    def __eq__(self, other):
        # The Op has no parameters: two instances of the same class are
        # interchangeable.
        return type(self) == type(other)

    def __hash__(self):
        # Must agree with __eq__, so hash on the class only.
        return hash(type(self))

    def make_node(self, A, s, m):
        """Build the Apply node; the output variable has the same type as ``s``."""
        return Apply(self, [A, s, m], [s.type()])

    def perform(self, node, inputs, outputs):
        """Python implementation: store ``matVecModM(A, s, m)`` in the output.

        ``inputs`` is the sequence ``(A, s, m)`` and ``outputs`` is a
        one-element sequence holding the output storage cell.
        """
        # Unpack in the body rather than in the signature: tuple parameters
        # are a Python-2-only syntax, and the positional call is unchanged.
        (A, s, m) = inputs
        (out,) = outputs
        out[0] = matVecModM(A, s, m)

    #def c_code_cache_version(self):
    #    return (1,)

    def c_code(self, node, name, inputs, outputs, sub):
        """Return the C source computing the modular matrix-vector product.

        ``inputs`` holds the C variable names for ``(A, s, m)``, ``outputs``
        the name of the single output, and ``sub`` the substitution
        dictionary (its ``fail`` entry is used to bail out on error).
        """
        (_A, _s, _m) = inputs
        (_z,) = outputs
        # Entries of ``sub`` take precedence over the local names, as before.
        substitutions = dict(_A=_A, _s=_s, _m=_m, _z=_z)
        substitutions.update(sub)
        return """
        if (PyArray_NDIM(%(_A)s) != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(A) != 2"); %(fail)s;}
        if (PyArray_NDIM(%(_s)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(s) != 1"); %(fail)s;}
        if (PyArray_NDIM(%(_m)s) != 0) {PyErr_SetString(PyExc_NotImplementedError, "rank(m) != 0"); %(fail)s;}

        if( PyArray_DIMS(%(_A)s)[1] != PyArray_DIMS(%(_s)s)[0])
        {PyErr_SetString(PyExc_NotImplementedError, "A and s shapes don't agree."); %(fail)s;}

        if (!%(_z)s
            || (PyArray_DIMS(%(_z)s)[0] != PyArray_DIMS(%(_A)s)[0]))
        {
            {Py_XDECREF(%(_z)s);}
            npy_intp dims[] = {0,};
            dims[0] = PyArray_DIMS(%(_A)s)[0];
            %(_z)s = (PyArrayObject*) PyArray_SimpleNew(1, dims, PyArray_TYPE(%(_s)s));
            // Allocation can fail; the Python error is already set in that case.
            if (!%(_z)s) {%(fail)s;}
        }

        { //makes it compile even though labels jump over variable definitions.

            // A has size MxN, s has N, output M
            npy_intp M = PyArray_DIMS(%(_A)s)[0];
            npy_intp N = PyArray_DIMS(%(_A)s)[1];

            dtype_%(_s)s* __restrict__ Ds = (dtype_%(_s)s*)PyArray_DATA(%(_s)s);
            dtype_%(_z)s* __restrict__ Dz = (dtype_%(_z)s*)PyArray_DATA(%(_z)s);
            const dtype_%(_m)s m = ((dtype_%(_m)s*)PyArray_DATA(%(_m)s))[0];

            // Integer remainder by zero is undefined in C; report it instead.
            if (m == 0) {PyErr_SetString(PyExc_ZeroDivisionError, "modulus m is zero"); %(fail)s;}

            // Strides expressed in elements rather than bytes.
            npy_intp SA = PyArray_STRIDES(%(_A)s)[1] / PyArray_DESCR(%(_A)s)->elsize;
            npy_intp Ss = PyArray_STRIDES(%(_s)s)[0] / PyArray_DESCR(%(_s)s)->elsize;
            npy_intp Sz = PyArray_STRIDES(%(_z)s)[0] / PyArray_DESCR(%(_z)s)->elsize;

            for (npy_intp i = 0; i < M; ++i)
            {
                const dtype_%(_A)s* __restrict__ Ak = (dtype_%(_A)s*)(PyArray_BYTES(%(_A)s) + PyArray_STRIDES(%(_A)s)[0] * i);

                // Accumulate in a 64-bit local so the output needs no
                // zero-fill and may have any stride.
                npy_int64 acc = 0;

                for (npy_intp j = 0; j < N; ++j)
                {
                    // Widen both operands BEFORE multiplying so the product
                    // is formed in 64 bits instead of overflowing first.
                    acc = (acc + (npy_int64)Ds[j * Ss] * (npy_int64)Ak[j * SA]) %% m;

                    // C remainder takes the sign of the dividend; shift a
                    // negative remainder by m as the original code did.
                    if (acc < 0) {
                        acc += m;
                    }
                }

                Dz[i * Sz] = acc;
            }
        }
        """ % substitutions
# MRG31k3p
# generator constants:
M1 = numpy.int32(2147483647) # 2^31 - 1
......
...@@ -874,3 +874,21 @@ def test_gradient_scan(): ...@@ -874,3 +874,21 @@ def test_gradient_scan():
gw = theano.grad(tensor.sum(values[-1]), w) gw = theano.grad(tensor.sum(values[-1]), w)
f = theano.function([x], gw) f = theano.function([x], gw)
f(numpy.arange(1, dtype='float32')) f(numpy.arange(1, dtype='float32'))
def test_multMatVect():
    """Check that the compiled DotModulo graph matches rng_mrg.matVecModM."""
    A = tensor.imatrix('A')
    s = tensor.ivector('s')
    m = tensor.iscalar('m')

    g0 = rng_mrg.DotModulo()(A, s, m)
    f0 = theano.function([A, s, m], g0)

    # Use separate names for the numeric values so the symbolic variables
    # above are not shadowed.
    int32_max = numpy.iinfo(numpy.int32).max
    A_val = numpy.random.randint(0, int32_max, (3, 3)).astype('int32')
    s_val = numpy.random.randint(0, int32_max, 3).astype('int32')
    # Start the range at 1: a modulus of 0 is not a meaningful test input.
    m_val = numpy.random.randint(1, int32_max)

    r_a = rng_mrg.matVecModM(A_val, s_val, m_val)
    r_b = f0(A_val, s_val, m_val)

    # Integer results must match exactly; a relative tolerance would accept
    # large absolute errors on values of this magnitude.
    assert numpy.array_equal(r_a, r_b)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论