提交 b44afc95 authored 作者: Frederic's avatar Frederic

Make DotModulo perform its computation twice per call (once for each input triple) to lower the number of function calls.

Speed up from 0.085s to 0.055s in mrg init.
上级 8a3734fe
......@@ -44,24 +44,25 @@ def multMatVect(v, A, m1, B, m2):
A_sym = tensor.lmatrix('A')
s_sym = tensor.ivector('s')
m_sym = tensor.iscalar('m')
A2_sym = tensor.lmatrix('A2')
s2_sym = tensor.ivector('s2')
m2_sym = tensor.iscalar('m2')
# We borrow the output as we will copy the answer elsewhere
o = Out(DotModulo()(A_sym, s_sym, m_sym), borrow=True)
multMatVect.dot_modulo = function([A_sym, s_sym, m_sym], o)
r = numpy.zeros_like(v)
o = Out(DotModulo()(A_sym, s_sym, m_sym, A2_sym, s2_sym, m2_sym),
borrow=True)
multMatVect.dot_modulo = function(
[A_sym, s_sym, m_sym, A2_sym, s2_sym, m2_sym], o)
# This way of calling the Theano fct is done to bypass Theano overhead.
f = multMatVect.dot_modulo
f.input_storage[0].storage[0] = A
f.input_storage[1].storage[0] = v[:3]
f.input_storage[2].storage[0] = m1
f.input_storage[3].storage[0] = B
f.input_storage[4].storage[0] = v[3:]
f.input_storage[5].storage[0] = m2
f.fn()
r[:3] = f.output_storage[0].storage[0]
f.input_storage[0].storage[0] = B
f.input_storage[1].storage[0] = v[3:]
f.input_storage[2].storage[0] = m2
f.fn()
r[3:] = f.output_storage[0].storage[0]
r = f.output_storage[0].storage[0]
return r
multMatVect.dot_modulo = None
......@@ -71,6 +72,8 @@ class DotModulo(Op):
"""
Efficient and numerically stable implementation of a dot product followed
by a modulo operation. This performs the same function as matVecModM.
We do this twice, once for each of the 2 input triples, and concatenate the outputs.
"""
def __eq__(self, other):
return type(self) == type(other)
......@@ -78,30 +81,40 @@ class DotModulo(Op):
def __hash__(self):
return hash(type(self))
def make_node(self, A, s, m):
return Apply(self, [A, s, m], [s.type()])
def make_node(self, A, s, m, A2, s2, m2):
return Apply(self, [A, s, m, A2, s2, m2], [s.type()])
def perform(self, node, (A, s, m), (out, )):
out[0] = matVecModM(A, s, m)
def perform(self, node, (A, s, m, A2, s2, m2), (out, )):
o1 = matVecModM(A, s, m)
o2 = matVecModM(A2, s2, m2)
out[0] = numpy.concatenate((o1, o2))
def c_code_cache_version(self):
return
return (5,)
def c_code(self, node, name, (_A, _s, _m), (_z, ), sub):
def c_code(self, node, name, (_A, _s, _m, _A2, _s2, _m2), (_z, ), sub):
return """
int osize = -1;
if (PyArray_NDIM(%(_A)s) != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(A) != 2"); %(fail)s;}
if (PyArray_NDIM(%(_s)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(v) != 1"); %(fail)s;}
if (PyArray_NDIM(%(_m)s) != 0) {PyErr_SetString(PyExc_NotImplementedError, "rank(m) != 0"); %(fail)s;}
if (PyArray_NDIM(%(_A2)s) != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(A2) != 2"); %(fail)s;}
if (PyArray_NDIM(%(_s2)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(v2) != 1"); %(fail)s;}
if (PyArray_NDIM(%(_m2)s) != 0) {PyErr_SetString(PyExc_NotImplementedError, "rank(m2) != 0"); %(fail)s;}
if( PyArray_DIMS(%(_A)s)[1] != PyArray_DIMS(%(_s)s)[0])
{PyErr_SetString(PyExc_NotImplementedError, "A and s shapes don't agree."); %(fail)s;}
if( PyArray_DIMS(%(_A2)s)[1] != PyArray_DIMS(%(_s2)s)[0])
{PyErr_SetString(PyExc_NotImplementedError, "A2 and s2 shapes don't agree."); %(fail)s;}
osize = PyArray_DIMS(%(_A)s)[0] + PyArray_DIMS(%(_A2)s)[0];
if (!%(_z)s
|| (PyArray_DIMS(%(_z)s)[0] != PyArray_DIMS(%(_A)s)[0]))
|| (PyArray_DIMS(%(_z)s)[0] != osize))
{
{Py_XDECREF(%(_z)s);}
npy_intp dims[] = {0,};
dims[0] = PyArray_DIMS(%(_A)s)[0];
dims[0] = osize;
%(_z)s = (PyArrayObject*) PyArray_SimpleNew(1, dims, PyArray_TYPE(%(_s)s));
}
......@@ -137,6 +150,38 @@ class DotModulo(Op):
}
}
//redo it with the second triple of inputs
{
// A has size MxN, s has N, output M
npy_intp M = PyArray_DIMS(%(_A2)s)[0];
npy_intp N = PyArray_DIMS(%(_A2)s)[1];
const dtype_%(_A2)s* __restrict__ DA = (dtype_%(_A2)s*)PyArray_DATA(%(_A2)s);
dtype_%(_s2)s* __restrict__ Ds = (dtype_%(_s2)s*)PyArray_DATA(%(_s2)s);
const dtype_%(_m2)s m = ((dtype_%(_m2)s*)PyArray_DATA(%(_m2)s))[0];
npy_intp SA = PyArray_STRIDES(%(_A2)s)[1] / PyArray_DESCR(%(_A2)s)->elsize;
npy_intp Ss = PyArray_STRIDES(%(_s2)s)[0] / PyArray_DESCR(%(_s2)s)->elsize;
npy_intp Sz = PyArray_STRIDES(%(_z)s)[0] / PyArray_DESCR(%(_z)s)->elsize;
dtype_%(_z)s* __restrict__ Dz = (dtype_%(_z)s*)PyArray_DATA(%(_z)s) + PyArray_DIMS(%(_A)s)[0] * Sz;
for (npy_int32 i = 0; i < M; ++i)
{
const dtype_%(_A2)s* __restrict__ Ak = (dtype_%(_A2)s*)(PyArray_BYTES(%(_A2)s) + PyArray_STRIDES(%(_A2)s)[0] * i);
npy_int64 r = 0;
for (npy_int32 j = 0; j < N; ++j)
{
r += (npy_int64)(Ds[j * Ss] * (npy_int64)(Ak[j * SA])) %% m;
}
Dz[i * Sz] = r %% m;
}
}
""" % dict(locals(), **sub)
......@@ -808,24 +853,20 @@ class MRG_RandomStreams(object):
if multMatVect.dot_modulo is None:
multMatVect(rval[0], A1p72, M1, A2p72, M2)
# This way of calling the Theano fct is done to bypass Theano overhead.
f = multMatVect.dot_modulo
f.input_storage[0].storage[0] = A1p72
f.input_storage[2].storage[0] = M1
f.input_storage[3].storage[0] = A2p72
f.input_storage[5].storage[0] = M2
for i in xrange(1, n_streams):
# Inline the following call to bypass Python overhead
#rval[i] = ff_2p72(rval[i - 1])
r = rval[i]
v = rval[i - 1]
# This way of calling the Theano fct is done to bypass Theano overhead.
f.input_storage[0].storage[0] = A1p72
f.input_storage[1].storage[0] = v[:3]
f.input_storage[2].storage[0] = M1
f.fn()
r[:3] = f.output_storage[0].storage[0]
f.input_storage[0].storage[0] = A2p72
f.input_storage[1].storage[0] = v[3:]
f.input_storage[2].storage[0] = M2
f.input_storage[4].storage[0] = v[3:]
f.fn()
r[3:] = f.output_storage[0].storage[0]
rval[i] = f.output_storage[0].storage[0]
if inc_rstate:
self.inc_rstate()
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论