提交 73db397f authored 作者: Yann N. Dauphin's avatar Yann N. Dauphin

integration in module

上级 bad43cc2
...@@ -10,7 +10,7 @@ import warnings ...@@ -10,7 +10,7 @@ import warnings
import numpy import numpy
from theano import Op, Apply, shared, config, Variable from theano import Op, Apply, shared, config, Variable
from theano import gradient from theano import gradient, function
from theano import tensor from theano import tensor
from theano.tensor import (raw_random, TensorType, as_tensor_variable, from theano.tensor import (raw_random, TensorType, as_tensor_variable,
get_vector_length, cast, opt, scal) get_vector_length, cast, opt, scal)
...@@ -34,21 +34,45 @@ def matVecModM(A, s, m): ...@@ -34,21 +34,45 @@ def matVecModM(A, s, m):
return numpy.int32(numpy.sum((A*s) % m, 1) % m) return numpy.int32(numpy.sum((A*s) % m, 1) % m)
dot_modulo = None
def multMatVect(v, A, m1, B, m2): def multMatVect(v, A, m1, B, m2):
"""
multiply the first half of v by A with a modulo of m1
and the second half by B with a modulo of m2
Note: The parameters of dot_modulo are passed implicitly because passing
them explicitly takes more time then running the function's C-code.
"""
#multiply the first half of v by A with a modulo of m1 #multiply the first half of v by A with a modulo of m1
#and the second half by B with a modulo of m2 #and the second half by B with a modulo of m2
err_orig = numpy.seterr(over='ignore') global dot_modulo
try: if dot_modulo == None:
r = numpy.zeros_like(v) A_sym = tensor.lmatrix('A')
r[:3] = matVecModM(A, v[:3], m1) s_sym = tensor.ivector('s')
r[3:] = matVecModM(B, v[3:], m2) m_sym = tensor.iscalar('m')
finally:
numpy.seterr(**err_orig) dot_modulo = function([A_sym, s_sym, m_sym],
DotModulo()(A_sym, s_sym, m_sym))
r = numpy.zeros_like(v)
dot_modulo.input_storage[0].storage[0] = A
dot_modulo.input_storage[1].storage[0] = v[:3]
dot_modulo.input_storage[2].storage[0] = m1
r[:3] = dot_modulo.fn()[0]
dot_modulo.input_storage[0].storage[0] = B
dot_modulo.input_storage[1].storage[0] = v[3:]
dot_modulo.input_storage[2].storage[0] = m2
r[3:] = dot_modulo.fn()[0]
return r return r
class DotModulo(Op): class DotModulo(Op):
"""
Efficient and numerically stable implementation of a dot product followed
by a modulo operation. This performs the same function as matVecModM.
"""
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
...@@ -61,8 +85,8 @@ class DotModulo(Op): ...@@ -61,8 +85,8 @@ class DotModulo(Op):
def perform(self, node, (A, s, m), (out, )): def perform(self, node, (A, s, m), (out, )):
out[0] = matVecModM(A, s, m) out[0] = matVecModM(A, s, m)
#def c_code_cache_version(self): def c_code_cache_version(self):
# return (2,) return (3,)
def c_code(self, node, name, (_A, _s, _m), (_z, ), sub): def c_code(self, node, name, (_A, _s, _m), (_z, ), sub):
return """ return """
...@@ -105,7 +129,7 @@ class DotModulo(Op): ...@@ -105,7 +129,7 @@ class DotModulo(Op):
for (npy_int32 j = 0; j < N; ++j) for (npy_int32 j = 0; j < N; ++j)
{ {
r += Ds[j * Ss] * (npy_int64)(Ak[j * SA]); r += (npy_int64)(Ds[j * Ss] * (npy_int64)(Ak[j * SA])) %% m;
} }
Dz[i * Sz] = r %% m; Dz[i * Sz] = r %% m;
...@@ -117,13 +141,13 @@ class DotModulo(Op): ...@@ -117,13 +141,13 @@ class DotModulo(Op):
#MRG31k3p #MRG31k3p
#generator constants : #generator constants :
M1 = numpy.int32(2147483647) #2^31 - 1 M1 = numpy.asarray(numpy.int32(2147483647)) #2^31 - 1
M2 = numpy.int32(2147462579) #2^31 - 21069 M2 = numpy.asarray(numpy.int32(2147462579)) #2^31 - 21069
MASK12 = numpy.int32(511) #2^9 - 1 MASK12 = numpy.int32(511) #2^9 - 1
MASK13 = numpy.int32(16777215) #2^24 - 1 MASK13 = numpy.int32(16777215) #2^24 - 1
MASK2 = numpy.int32(65535) #2^16 - 1 MASK2 = numpy.int32(65535) #2^16 - 1
MULT2 = numpy.int32(21069) MULT2 = numpy.int32(21069)
NORM = 4.656612873077392578125e-10; #1./2^31 NORM = 4.656612873077392578125e-10; #1./2^31
#A1p0 = numpy.asarray([[0, 4194304, 129], [1, 0, 0], [0, 1, 0]], #A1p0 = numpy.asarray([[0, 4194304, 129], [1, 0, 0], [0, 1, 0]],
# dtype='int64') # dtype='int64')
......
...@@ -886,9 +886,13 @@ def test_multMatVect(): ...@@ -886,9 +886,13 @@ def test_multMatVect():
A = numpy.random.randint(0, numpy.iinfo(numpy.int32).max, (3, 3)).astype('int32') A = numpy.random.randint(0, numpy.iinfo(numpy.int32).max, (3, 3)).astype('int32')
s = numpy.random.randint(0, numpy.iinfo(numpy.int32).max, 3).astype('int32') s = numpy.random.randint(0, numpy.iinfo(numpy.int32).max, 3).astype('int32')
m = numpy.random.randint(numpy.iinfo(numpy.int32).max) m = numpy.asarray(numpy.random.randint(numpy.iinfo(numpy.int32).max), dtype="int32")
f0.input_storage[0].storage[0] = A
f0.input_storage[1].storage[0] = s
f0.input_storage[2].storage[0] = m
r_a = rng_mrg.matVecModM(A, s, m) r_a = rng_mrg.matVecModM(A, s, m)
r_b = f0(A, s, m) r_b = f0.fn()
assert numpy.allclose(r_a, r_b) assert numpy.allclose(r_a, r_b)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论