提交 9c49c1cb authored 作者: Frederic's avatar Frederic

some pep8

上级 f2d8bb22
...@@ -36,29 +36,29 @@ def multMatVect(v, A, m1, B, m2): ...@@ -36,29 +36,29 @@ def multMatVect(v, A, m1, B, m2):
""" """
multiply the first half of v by A with a modulo of m1 multiply the first half of v by A with a modulo of m1
and the second half by B with a modulo of m2 and the second half by B with a modulo of m2
Note: The parameters of dot_modulo are passed implicitly because passing Note: The parameters of dot_modulo are passed implicitly because passing
them explicitly takes more time then running the function's C-code. them explicitly takes more time then running the function's C-code.
""" """
if multMatVect.dot_modulo == None: if multMatVect.dot_modulo is None:
A_sym = tensor.lmatrix('A') A_sym = tensor.lmatrix('A')
s_sym = tensor.ivector('s') s_sym = tensor.ivector('s')
m_sym = tensor.iscalar('m') m_sym = tensor.iscalar('m')
multMatVect.dot_modulo = function([A_sym, s_sym, m_sym], multMatVect.dot_modulo = function([A_sym, s_sym, m_sym],
DotModulo()(A_sym, s_sym, m_sym)) DotModulo()(A_sym, s_sym, m_sym))
r = numpy.zeros_like(v) r = numpy.zeros_like(v)
multMatVect.dot_modulo.input_storage[0].storage[0] = A multMatVect.dot_modulo.input_storage[0].storage[0] = A
multMatVect.dot_modulo.input_storage[1].storage[0] = v[:3] multMatVect.dot_modulo.input_storage[1].storage[0] = v[:3]
multMatVect.dot_modulo.input_storage[2].storage[0] = m1 multMatVect.dot_modulo.input_storage[2].storage[0] = m1
r[:3] = multMatVect.dot_modulo.fn()[0] r[:3] = multMatVect.dot_modulo.fn()[0]
multMatVect.dot_modulo.input_storage[0].storage[0] = B multMatVect.dot_modulo.input_storage[0].storage[0] = B
multMatVect.dot_modulo.input_storage[1].storage[0] = v[3:] multMatVect.dot_modulo.input_storage[1].storage[0] = v[3:]
multMatVect.dot_modulo.input_storage[2].storage[0] = m2 multMatVect.dot_modulo.input_storage[2].storage[0] = m2
r[3:] = multMatVect.dot_modulo.fn()[0] r[3:] = multMatVect.dot_modulo.fn()[0]
return r return r
multMatVect.dot_modulo = None multMatVect.dot_modulo = None
...@@ -73,13 +73,13 @@ class DotModulo(Op): ...@@ -73,13 +73,13 @@ class DotModulo(Op):
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def make_node(self, A, s, m): def make_node(self, A, s, m):
return Apply(self, [A, s, m], [s.type()]) return Apply(self, [A, s, m], [s.type()])
def perform(self, node, (A, s, m), (out, )): def perform(self, node, (A, s, m), (out, )):
out[0] = matVecModM(A, s, m) out[0] = matVecModM(A, s, m)
def c_code_cache_version(self): def c_code_cache_version(self):
return (5,) return (5,)
...@@ -185,42 +185,41 @@ def mrg_next_value(rstate, new_rstate): ...@@ -185,42 +185,41 @@ def mrg_next_value(rstate, new_rstate):
x11, x12, x13, x21, x22, x23 = rstate x11, x12, x13, x21, x22, x23 = rstate
assert type(x11) == numpy.int32 assert type(x11) == numpy.int32
#i0, i7, i9, i15, i16, i22, i24 = [numpy.int32(i) for i in (0, 7, 9, 15, 16, 22, 24)]
i0, i7, i9, i15, i16, i22, i24 = np_int32_vals i0, i7, i9, i15, i16, i22, i24 = np_int32_vals
#first component #first component
y1 = (((x12 & MASK12) << i22) + (x12 >> i9) + y1 = (((x12 & MASK12) << i22) + (x12 >> i9) +
((x13 & MASK13) << i7) + (x13 >> i24)) ((x13 & MASK13) << i7) + (x13 >> i24))
assert type(y1) == numpy.int32 assert type(y1) == numpy.int32
if (y1 < 0 or y1 >= M1): #must also check overflow if (y1 < 0 or y1 >= M1): # must also check overflow
y1 -= M1; y1 -= M1
y1 += x13; y1 += x13
if (y1 < 0 or y1 >= M1): if (y1 < 0 or y1 >= M1):
y1 -= M1; y1 -= M1
x13 = x12; x13 = x12
x12 = x11; x12 = x11
x11 = y1; x11 = y1
#second component #second component
y1 = ((x21 & MASK2) << i15) + (MULT2 * (x21 >> i16)); y1 = ((x21 & MASK2) << i15) + (MULT2 * (x21 >> i16))
assert type(y1) == numpy.int32 assert type(y1) == numpy.int32
if (y1 < 0 or y1 >= M2): if (y1 < 0 or y1 >= M2):
y1 -= M2; y1 -= M2
y2 = ((x23 & MASK2) << i15) + (MULT2 * (x23 >> i16)); y2 = ((x23 & MASK2) << i15) + (MULT2 * (x23 >> i16))
assert type(y2) == numpy.int32 assert type(y2) == numpy.int32
if (y2 < 0 or y2 >= M2): if (y2 < 0 or y2 >= M2):
y2 -= M2; y2 -= M2
y2 += x23; y2 += x23
if (y2 < 0 or y2 >= M2): if (y2 < 0 or y2 >= M2):
y2 -= M2; y2 -= M2
y2 += y1; y2 += y1
if (y2 < 0 or y2 >= M2): if (y2 < 0 or y2 >= M2):
y2 -= M2; y2 -= M2
x23 = x22; x23 = x22
x22 = x21; x22 = x21
x21 = y2; x21 = y2
# Must never return either 0 or M1+1 # Must never return either 0 or M1+1
new_rstate[...] = [x11, x12, x13, x21, x22, x23] new_rstate[...] = [x11, x12, x13, x21, x22, x23]
...@@ -235,9 +234,9 @@ class mrg_uniform_base(Op): ...@@ -235,9 +234,9 @@ class mrg_uniform_base(Op):
def __init__(self, output_type, inplace=False): def __init__(self, output_type, inplace=False):
Op.__init__(self) Op.__init__(self)
self.output_type = output_type self.output_type = output_type
self.inplace=inplace self.inplace = inplace
if inplace: if inplace:
self.destroy_map = {0:[0]} self.destroy_map = {0: [0]}
self.warned_numpy_version = False self.warned_numpy_version = False
def __eq__(self, other): def __eq__(self, other):
...@@ -289,7 +288,10 @@ class mrg_uniform(mrg_uniform_base): ...@@ -289,7 +288,10 @@ class mrg_uniform(mrg_uniform_base):
rstate, size = inp rstate, size = inp
o_rstate, o_sample = out o_rstate, o_sample = out
numpy_version = numpy.__version__.split('.') numpy_version = numpy.__version__.split('.')
if not self.warned_numpy_version and int(numpy_version[0]) <= 1 and int(numpy_version[1]) <3 : if (not self.warned_numpy_version and
int(numpy_version[0]) <= 1 and
int(numpy_version[1]) < 3):
print "Warning: you must use numpy version 1.3.0 or higher with the python version of this op. Otherwise numpy leak memory. and numpy" print "Warning: you must use numpy version 1.3.0 or higher with the python version of this op. Otherwise numpy leak memory. and numpy"
self.warned_numpy_version = True self.warned_numpy_version = True
...@@ -315,8 +317,9 @@ class mrg_uniform(mrg_uniform_base): ...@@ -315,8 +317,9 @@ class mrg_uniform(mrg_uniform_base):
finally: finally:
numpy.seterr(**err_orig) numpy.seterr(**err_orig)
o_rstate[0] = node.outputs[0].type.filter(rstate) # send to GPU if necessary # send to GPU if necessary
o_sample[0] = node.outputs[1].type.filter(rval.reshape(size)) # send to GPU if necessary o_rstate[0] = node.outputs[0].type.filter(rstate)
o_sample[0] = node.outputs[1].type.filter(rval.reshape(size))
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
rstate, size = inp rstate, size = inp
...@@ -718,7 +721,7 @@ def guess_n_streams(size, warn=True): ...@@ -718,7 +721,7 @@ def guess_n_streams(size, warn=True):
for s in size: for s in size:
r *= s r *= s
if r > 6: if r > 6:
r = r // 6 # chosen as fastest for rbm_benchmark r = r // 6 # chosen as fastest for rbm_benchmark
# The purpose of sampling from many streams is to be able to use # The purpose of sampling from many streams is to be able to use
# the GPU to its full capacity. It just wastes RAM and stream-initialization time to # the GPU to its full capacity. It just wastes RAM and stream-initialization time to
...@@ -731,8 +734,8 @@ def guess_n_streams(size, warn=True): ...@@ -731,8 +734,8 @@ def guess_n_streams(size, warn=True):
else: else:
if warn: if warn:
warnings.warn(( warnings.warn((
"MRG_RandomStreams Can't determine #streams from " "MRG_RandomStreams Can't determine #streams from "
"size (%s), guessing 60*256") % str(size), "size (%s), guessing 60*256") % str(size),
stacklevel=3) stacklevel=3)
return 60 * 256 return 60 * 256
...@@ -848,7 +851,8 @@ class MRG_RandomStreams(object): ...@@ -848,7 +851,8 @@ class MRG_RandomStreams(object):
msg = "size must be a tuple of int or a Theano variable" msg = "size must be a tuple of int or a Theano variable"
assert all([isinstance(i, (numpy.integer, int, Variable)) assert all([isinstance(i, (numpy.integer, int, Variable))
for i in size]), msg for i in size]), msg
if any([isinstance(i, (numpy.integer, int)) and i <= 0 for i in size]): if any([isinstance(i, (numpy.integer, int)) and i <= 0
for i in size]):
raise ValueError( raise ValueError(
"The specified size contains a dimension with value <= 0", "The specified size contains a dimension with value <= 0",
size) size)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论