提交 7ed1e437 作者: Frederic 提交者: Yann N. Dauphin

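Other speed up: bind multMatVect.dot_modulo to a local variable, and call multMatVect directly instead of ff_2p134 / ff_2p72 when advancing the random state.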

上级 9fbe1e88
...@@ -49,19 +49,19 @@ def multMatVect(v, A, m1, B, m2): ...@@ -49,19 +49,19 @@ def multMatVect(v, A, m1, B, m2):
multMatVect.dot_modulo = function([A_sym, s_sym, m_sym], multMatVect.dot_modulo = function([A_sym, s_sym, m_sym],
DotModulo()(A_sym, s_sym, m_sym)) DotModulo()(A_sym, s_sym, m_sym))
r = numpy.zeros_like(v) r = numpy.zeros_like(v)
# This way of calling the Theano fct is done to bypass Theano overhead. # This way of calling the Theano fct is done to bypass Theano overhead.
multMatVect.dot_modulo.input_storage[0].storage[0] = A f = multMatVect.dot_modulo
multMatVect.dot_modulo.input_storage[1].storage[0] = v[:3] f.input_storage[0].storage[0] = A
multMatVect.dot_modulo.input_storage[2].storage[0] = m1 f.input_storage[1].storage[0] = v[:3]
r[:3] = multMatVect.dot_modulo.fn()[0] f.input_storage[2].storage[0] = m1
r[:3] = f.fn()[0]
multMatVect.dot_modulo.input_storage[0].storage[0] = B f.input_storage[0].storage[0] = B
multMatVect.dot_modulo.input_storage[1].storage[0] = v[3:] f.input_storage[1].storage[0] = v[3:]
multMatVect.dot_modulo.input_storage[2].storage[0] = m2 f.input_storage[2].storage[0] = m2
r[3:] = multMatVect.dot_modulo.fn()[0] r[3:] = f.fn()[0]
return r return r
multMatVect.dot_modulo = None multMatVect.dot_modulo = None
...@@ -1025,7 +1025,8 @@ class MRG_RandomStreams(object): ...@@ -1025,7 +1025,8 @@ class MRG_RandomStreams(object):
def inc_rstate(self): def inc_rstate(self):
"""Update self.rstate to be skipped 2^134 steps forward to the next stream start""" """Update self.rstate to be skipped 2^134 steps forward to the next stream start"""
self.rstate = ff_2p134(self.rstate) #self.rstate = ff_2p134(self.rstate)
self.rstate = multMatVect(self.rstate, A1p134, M1, A2p134, M2)
assert self.rstate.dtype == numpy.int32 assert self.rstate.dtype == numpy.int32
def get_substream_rstates(self, n_streams, inc_rstate=True): def get_substream_rstates(self, n_streams, inc_rstate=True):
...@@ -1037,7 +1038,8 @@ class MRG_RandomStreams(object): ...@@ -1037,7 +1038,8 @@ class MRG_RandomStreams(object):
rval = numpy.zeros((n_streams, 6), dtype='int32') rval = numpy.zeros((n_streams, 6), dtype='int32')
rval[0] = self.rstate rval[0] = self.rstate
for i in xrange(1, n_streams): for i in xrange(1, n_streams):
rval[i] = ff_2p72(rval[i - 1]) #rval[i] = ff_2p72(rval[i - 1])
rval[i] = multMatVect(rval[i - 1], A1p72, M1, A2p72, M2)
if inc_rstate: if inc_rstate:
self.inc_rstate() self.inc_rstate()
return rval return rval
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论