提交 ffa4e84c authored 作者: nouiz's avatar nouiz

Merge pull request #63 from dwf/curand_rng_pep8

Curand rng pep8
......@@ -2,10 +2,10 @@
Define CURAND_RandomStreams - backed by CURAND
"""
__authors__ = "James Bergstra"
__authors__ = "James Bergstra"
__copyright__ = "(c) 2011, University of Montreal"
__license__ = "3-clause BSD License"
__contact__ = "theano-dev@googlegroups.com"
__license__ = "3-clause BSD License"
__contact__ = "theano-dev@googlegroups.com"
import sys
import numpy
......@@ -13,16 +13,18 @@ import theano.gof
from theano.sandbox.cuda import CudaNdarrayType
from theano.tensor import (get_vector_length, cast, opt)
from theano.compile import optdb
from theano.gof import local_optimizer
from theano.gof import local_optimizer, Variable
config = theano.config
class CURAND_Base(theano.gof.Op):
""" Base class for a random number generator implemented in CURAND.
The random number generator itself is an opaque reference managed by CURAND.
This Op uses a generic-typed shared variable to point to a CObject that
encapsulates this opaque reference.
The random number generator itself is an opaque reference managed by
CURAND. This Op uses a generic-typed shared variable to point to a CObject
that encapsulates this opaque reference.
Each random variable is created with a generator of False.
The actual random number generator is allocated from the seed, on the first
......@@ -30,8 +32,8 @@ class CURAND_Base(theano.gof.Op):
:note:
One caveat is that the random number state is simply not serializable.
Consequently, attempts to serialize functions compiled with these random
numbers will fail.
Consequently, attempts to serialize functions compiled with these
random numbers will fail.
"""
def __init__(self, output_type, seed, destructive):
......@@ -60,13 +62,14 @@ class CURAND_Base(theano.gof.Op):
)
def __eq__(self, other):
return type(self)==type(other) and self._config()==other._config()
return type(self) == type(other) and self._config() == other._config()
def __hash__(self):
return hash((type(self), self._config()))
def __str__(self):
return self.__class__.__name__+"{inplace=%s, out_dtype=%s}"%(
self.destructive, self.output_type)
return (self.__class__.__name__ + "{inplace=%s, out_dtype=%s}" %
(self.destructive, self.output_type))
def make_node(self, generator, size):
return theano.gof.Apply(self, [generator, size],
......@@ -90,9 +93,9 @@ class CURAND_Base(theano.gof.Op):
o_gen, sample = self(generator, cast(v_size, 'int32'))
sample.generator = generator # for user
sample.update = (generator, o_gen) # for CURAND_RandomStreams
generator.default_update = o_gen # for pfunc uses this attribute
sample.generator = generator # for user
sample.update = (generator, o_gen) # for CURAND_RandomStreams
generator.default_update = o_gen # for pfunc uses this attribute
return sample
def c_headers(self):
......@@ -216,7 +219,7 @@ class CURAND_Base(theano.gof.Op):
cudaThreadSynchronize();
}
//////// </ code generated by CURAND_Base>
""" %locals()
""" % locals()
def c_code_cache_version(self):
return (2,)
......@@ -230,7 +233,7 @@ class CURAND_Normal(CURAND_Base):
CudaNdarray_DEV_DATA(%(o_sample)s),
n_elements,
0.0, 1.0);
"""%kwargs
""" % kwargs
class CURAND_Uniform(CURAND_Base):
......@@ -240,7 +243,7 @@ class CURAND_Uniform(CURAND_Base):
return """ curandGenerateUniform(*gen,
CudaNdarray_DEV_DATA(%(o_sample)s),
n_elements);
"""%kwargs
""" % kwargs
class CURAND_RandomStreams(object):
......@@ -254,7 +257,7 @@ class CURAND_RandomStreams(object):
"""
self._start_seed = seed
self._cur_seed = seed
self._has_lost_states = False #True if self.state_updates is incomplete
self._has_lost_states = False # True if self.state_updates incomplete
self.state_updates = []
def updates(self):
......@@ -267,7 +270,7 @@ class CURAND_RandomStreams(object):
"""Return a unique seed for initializing a random variable.
"""
self._cur_seed += 1
return self._cur_seed -1
return self._cur_seed - 1
def __getstate__(self):
rval = dict(self.__dict__)
......@@ -283,19 +286,22 @@ class CURAND_RandomStreams(object):
"""
if isinstance(size, tuple):
msg = "size must be a tuple of int or a Theano variable"
assert all([isinstance(i,int) or isinstance(i,Variable)
assert all([isinstance(i, int) or isinstance(i, Variable)
for i in size]), msg
else:
msg = "size must be a tuple of int or a Theano variable"
assert isinstance(size, Variable) and size.ndim==1, msg
generator = theano.shared(False) #makes a generic
assert isinstance(size, Variable) and size.ndim == 1, msg
generator = theano.shared(False) # makes a generic
s_size = theano.tensor.as_tensor_variable(size)
u = CURAND_Uniform.new_auto_update(generator, ndim, dtype, s_size,
self.next_seed())
self.state_updates.append(u.update)
rval = u * (high-low) + low
rval = u * (high - low) + low
if u.type.broadcastable != rval.type.broadcastable:
raise NotImplementedError( 'Increase the size to match the broadcasting pattern of `low` and `high` arguments')
raise NotImplementedError(
'Increase the size to match the broadcasting pattern of '
'low and `high` arguments'
)
return rval
def normal(self, size=None, avg=0.0, std=1.0, ndim=None,
......@@ -303,23 +309,27 @@ class CURAND_RandomStreams(object):
"""
Return symbolic tensor of normally-distributed numbers.
:param: size: Can be a list of integer or Theano variable(ex: the shape of other Theano Variable)
:param: size: Can be a list of integer or Theano variable(ex: the shape
of other Theano Variable)
"""
if isinstance(size, tuple):
msg = "size must be a tuple of int or a Theano variable"
assert all([isinstance(i,int) or isinstance(i,Variable)
assert all([isinstance(i, int) or isinstance(i, Variable)
for i in size]), msg
else:
msg = "size must be a tuple of int or a Theano variable"
assert isinstance(size, Variable) and size.ndim==1, msg
generator = theano.shared(False) #makes a generic
assert isinstance(size, Variable) and size.ndim == 1, msg
generator = theano.shared(False) # makes a generic
s_size = theano.tensor.as_tensor_variable(size)
u = CURAND_Normal.new_auto_update(generator, ndim, dtype, s_size,
self.next_seed())
self.state_updates.append(u.update)
rval = u * std + avg
if u.type.broadcastable != rval.type.broadcastable:
raise NotImplementedError( 'Increase the size to match the broadcasting pattern of `low` and `high` arguments')
raise NotImplementedError(
'Increase the size to match the broadcasting pattern of `low`'
'and `high` arguments'
)
return rval
......@@ -332,4 +342,5 @@ def local_destructive(node):
return new_op.make_node(*node.inputs).outputs
return False
optdb.register('CURAND_destructive',
opt.in2out(local_destructive, ignore_newtrees=True), 99, 'fast_run', 'inplace')
opt.in2out(local_destructive, ignore_newtrees=True), 99, 'fast_run',
'inplace')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论