提交 ffa4e84c authored 作者: nouiz's avatar nouiz

Merge pull request #63 from dwf/curand_rng_pep8

Curand rng pep8
...@@ -2,10 +2,10 @@ ...@@ -2,10 +2,10 @@
Define CURAND_RandomStreams - backed by CURAND Define CURAND_RandomStreams - backed by CURAND
""" """
__authors__ = "James Bergstra" __authors__ = "James Bergstra"
__copyright__ = "(c) 2011, University of Montreal" __copyright__ = "(c) 2011, University of Montreal"
__license__ = "3-clause BSD License" __license__ = "3-clause BSD License"
__contact__ = "theano-dev@googlegroups.com" __contact__ = "theano-dev@googlegroups.com"
import sys import sys
import numpy import numpy
...@@ -13,16 +13,18 @@ import theano.gof ...@@ -13,16 +13,18 @@ import theano.gof
from theano.sandbox.cuda import CudaNdarrayType from theano.sandbox.cuda import CudaNdarrayType
from theano.tensor import (get_vector_length, cast, opt) from theano.tensor import (get_vector_length, cast, opt)
from theano.compile import optdb from theano.compile import optdb
from theano.gof import local_optimizer from theano.gof import local_optimizer, Variable
config = theano.config config = theano.config
class CURAND_Base(theano.gof.Op): class CURAND_Base(theano.gof.Op):
""" Base class for a random number generator implemented in CURAND. """ Base class for a random number generator implemented in CURAND.
The random number generator itself is an opaque reference managed by CURAND. The random number generator itself is an opaque reference managed by
This Op uses a generic-typed shared variable to point to a CObject that CURAND. This Op uses a generic-typed shared variable to point to a CObject
encapsulates this opaque reference. that encapsulates this opaque reference.
Each random variable is created with a generator of False. Each random variable is created with a generator of False.
The actual random number generator is allocated from the seed, on the first The actual random number generator is allocated from the seed, on the first
...@@ -30,8 +32,8 @@ class CURAND_Base(theano.gof.Op): ...@@ -30,8 +32,8 @@ class CURAND_Base(theano.gof.Op):
:note: :note:
One caveat is that the random number state is simply not serializable. One caveat is that the random number state is simply not serializable.
Consequently, attempts to serialize functions compiled with these random Consequently, attempts to serialize functions compiled with these
numbers will fail. random numbers will fail.
""" """
def __init__(self, output_type, seed, destructive): def __init__(self, output_type, seed, destructive):
...@@ -60,13 +62,14 @@ class CURAND_Base(theano.gof.Op): ...@@ -60,13 +62,14 @@ class CURAND_Base(theano.gof.Op):
) )
def __eq__(self, other): def __eq__(self, other):
return type(self)==type(other) and self._config()==other._config() return type(self) == type(other) and self._config() == other._config()
def __hash__(self): def __hash__(self):
return hash((type(self), self._config())) return hash((type(self), self._config()))
def __str__(self): def __str__(self):
return self.__class__.__name__+"{inplace=%s, out_dtype=%s}"%( return (self.__class__.__name__ + "{inplace=%s, out_dtype=%s}" %
self.destructive, self.output_type) (self.destructive, self.output_type))
def make_node(self, generator, size): def make_node(self, generator, size):
return theano.gof.Apply(self, [generator, size], return theano.gof.Apply(self, [generator, size],
...@@ -90,9 +93,9 @@ class CURAND_Base(theano.gof.Op): ...@@ -90,9 +93,9 @@ class CURAND_Base(theano.gof.Op):
o_gen, sample = self(generator, cast(v_size, 'int32')) o_gen, sample = self(generator, cast(v_size, 'int32'))
sample.generator = generator # for user sample.generator = generator # for user
sample.update = (generator, o_gen) # for CURAND_RandomStreams sample.update = (generator, o_gen) # for CURAND_RandomStreams
generator.default_update = o_gen # for pfunc uses this attribute generator.default_update = o_gen # for pfunc uses this attribute
return sample return sample
def c_headers(self): def c_headers(self):
...@@ -216,7 +219,7 @@ class CURAND_Base(theano.gof.Op): ...@@ -216,7 +219,7 @@ class CURAND_Base(theano.gof.Op):
cudaThreadSynchronize(); cudaThreadSynchronize();
} }
//////// </ code generated by CURAND_Base> //////// </ code generated by CURAND_Base>
""" %locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (2,)
...@@ -230,7 +233,7 @@ class CURAND_Normal(CURAND_Base): ...@@ -230,7 +233,7 @@ class CURAND_Normal(CURAND_Base):
CudaNdarray_DEV_DATA(%(o_sample)s), CudaNdarray_DEV_DATA(%(o_sample)s),
n_elements, n_elements,
0.0, 1.0); 0.0, 1.0);
"""%kwargs """ % kwargs
class CURAND_Uniform(CURAND_Base): class CURAND_Uniform(CURAND_Base):
...@@ -240,7 +243,7 @@ class CURAND_Uniform(CURAND_Base): ...@@ -240,7 +243,7 @@ class CURAND_Uniform(CURAND_Base):
return """ curandGenerateUniform(*gen, return """ curandGenerateUniform(*gen,
CudaNdarray_DEV_DATA(%(o_sample)s), CudaNdarray_DEV_DATA(%(o_sample)s),
n_elements); n_elements);
"""%kwargs """ % kwargs
class CURAND_RandomStreams(object): class CURAND_RandomStreams(object):
...@@ -254,7 +257,7 @@ class CURAND_RandomStreams(object): ...@@ -254,7 +257,7 @@ class CURAND_RandomStreams(object):
""" """
self._start_seed = seed self._start_seed = seed
self._cur_seed = seed self._cur_seed = seed
self._has_lost_states = False #True if self.state_updates is incomplete self._has_lost_states = False # True if self.state_updates incomplete
self.state_updates = [] self.state_updates = []
def updates(self): def updates(self):
...@@ -267,7 +270,7 @@ class CURAND_RandomStreams(object): ...@@ -267,7 +270,7 @@ class CURAND_RandomStreams(object):
"""Return a unique seed for initializing a random variable. """Return a unique seed for initializing a random variable.
""" """
self._cur_seed += 1 self._cur_seed += 1
return self._cur_seed -1 return self._cur_seed - 1
def __getstate__(self): def __getstate__(self):
rval = dict(self.__dict__) rval = dict(self.__dict__)
...@@ -283,19 +286,22 @@ class CURAND_RandomStreams(object): ...@@ -283,19 +286,22 @@ class CURAND_RandomStreams(object):
""" """
if isinstance(size, tuple): if isinstance(size, tuple):
msg = "size must be a tuple of int or a Theano variable" msg = "size must be a tuple of int or a Theano variable"
assert all([isinstance(i,int) or isinstance(i,Variable) assert all([isinstance(i, int) or isinstance(i, Variable)
for i in size]), msg for i in size]), msg
else: else:
msg = "size must be a tuple of int or a Theano variable" msg = "size must be a tuple of int or a Theano variable"
assert isinstance(size, Variable) and size.ndim==1, msg assert isinstance(size, Variable) and size.ndim == 1, msg
generator = theano.shared(False) #makes a generic generator = theano.shared(False) # makes a generic
s_size = theano.tensor.as_tensor_variable(size) s_size = theano.tensor.as_tensor_variable(size)
u = CURAND_Uniform.new_auto_update(generator, ndim, dtype, s_size, u = CURAND_Uniform.new_auto_update(generator, ndim, dtype, s_size,
self.next_seed()) self.next_seed())
self.state_updates.append(u.update) self.state_updates.append(u.update)
rval = u * (high-low) + low rval = u * (high - low) + low
if u.type.broadcastable != rval.type.broadcastable: if u.type.broadcastable != rval.type.broadcastable:
raise NotImplementedError( 'Increase the size to match the broadcasting pattern of `low` and `high` arguments') raise NotImplementedError(
'Increase the size to match the broadcasting pattern of '
'low and `high` arguments'
)
return rval return rval
def normal(self, size=None, avg=0.0, std=1.0, ndim=None, def normal(self, size=None, avg=0.0, std=1.0, ndim=None,
...@@ -303,23 +309,27 @@ class CURAND_RandomStreams(object): ...@@ -303,23 +309,27 @@ class CURAND_RandomStreams(object):
""" """
Return symbolic tensor of normally-distributed numbers. Return symbolic tensor of normally-distributed numbers.
:param: size: Can be a list of integer or Theano variable(ex: the shape of other Theano Variable) :param: size: Can be a list of integer or Theano variable(ex: the shape
of other Theano Variable)
""" """
if isinstance(size, tuple): if isinstance(size, tuple):
msg = "size must be a tuple of int or a Theano variable" msg = "size must be a tuple of int or a Theano variable"
assert all([isinstance(i,int) or isinstance(i,Variable) assert all([isinstance(i, int) or isinstance(i, Variable)
for i in size]), msg for i in size]), msg
else: else:
msg = "size must be a tuple of int or a Theano variable" msg = "size must be a tuple of int or a Theano variable"
assert isinstance(size, Variable) and size.ndim==1, msg assert isinstance(size, Variable) and size.ndim == 1, msg
generator = theano.shared(False) #makes a generic generator = theano.shared(False) # makes a generic
s_size = theano.tensor.as_tensor_variable(size) s_size = theano.tensor.as_tensor_variable(size)
u = CURAND_Normal.new_auto_update(generator, ndim, dtype, s_size, u = CURAND_Normal.new_auto_update(generator, ndim, dtype, s_size,
self.next_seed()) self.next_seed())
self.state_updates.append(u.update) self.state_updates.append(u.update)
rval = u * std + avg rval = u * std + avg
if u.type.broadcastable != rval.type.broadcastable: if u.type.broadcastable != rval.type.broadcastable:
raise NotImplementedError( 'Increase the size to match the broadcasting pattern of `low` and `high` arguments') raise NotImplementedError(
'Increase the size to match the broadcasting pattern of `low`'
'and `high` arguments'
)
return rval return rval
...@@ -332,4 +342,5 @@ def local_destructive(node): ...@@ -332,4 +342,5 @@ def local_destructive(node):
return new_op.make_node(*node.inputs).outputs return new_op.make_node(*node.inputs).outputs
return False return False
optdb.register('CURAND_destructive', optdb.register('CURAND_destructive',
opt.in2out(local_destructive, ignore_newtrees=True), 99, 'fast_run', 'inplace') opt.in2out(local_destructive, ignore_newtrees=True), 99, 'fast_run',
'inplace')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论