提交 4fd12c42 authored 作者: Chiheb Trabelsi's avatar Chiheb Trabelsi

rng_curand.py has been modified in order to respect the flake8 style.

上级 54fe4a7f
"""
Define CURAND_RandomStreams - backed by CURAND.
"""
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
__authors__ = "James Bergstra"
__copyright__ = "(c) 2011, University of Montreal"
__license__ = "3-clause BSD License"
__contact__ = "theano-dev@googlegroups.com"
import numpy import numpy
import theano.gof import theano.gof
from theano.compat import PY3 from theano.compat import PY3
...@@ -17,12 +7,21 @@ from theano.tensor import (get_vector_length, cast, opt) ...@@ -17,12 +7,21 @@ from theano.tensor import (get_vector_length, cast, opt)
from theano.compile import optdb from theano.compile import optdb
from theano.gof import local_optimizer, Variable from theano.gof import local_optimizer, Variable
__authors__ = "James Bergstra"
__copyright__ = "(c) 2011, University of Montreal"
__license__ = "3-clause BSD License"
__contact__ = "theano-dev@googlegroups.com"
"""
Define CURAND_RandomStreams - backed by CURAND.
"""
config = theano.config config = theano.config
class CURAND_Base(GpuOp): class CURAND_Base(GpuOp):
""" """
Base class for a random number generator implemented in CURAND. Base class for a random number generator implemented in CURAND.
The random number generator itself is an opaque reference managed by The random number generator itself is an opaque reference managed by
...@@ -70,8 +69,7 @@ class CURAND_Base(GpuOp): ...@@ -70,8 +69,7 @@ class CURAND_Base(GpuOp):
Return a tuple of attributes that define the Op. Return a tuple of attributes that define the Op.
""" """
return ( return (self.destructive,
self.destructive,
self.output_type, self.output_type,
self.seed, self.seed,
) )
...@@ -88,7 +86,7 @@ class CURAND_Base(GpuOp): ...@@ -88,7 +86,7 @@ class CURAND_Base(GpuOp):
def make_node(self, generator, size): def make_node(self, generator, size):
return theano.gof.Apply(self, [generator, size], return theano.gof.Apply(self, [generator, size],
[generator.type(), self.output_type()]) [generator.type(), self.output_type()])
@classmethod @classmethod
def new_auto_update(cls, generator, ndim, dtype, size, seed): def new_auto_update(cls, generator, ndim, dtype, size, seed):
...@@ -101,10 +99,9 @@ class CURAND_Base(GpuOp): ...@@ -101,10 +99,9 @@ class CURAND_Base(GpuOp):
v_size = theano.tensor.as_tensor_variable(size) v_size = theano.tensor.as_tensor_variable(size)
if ndim is None: if ndim is None:
ndim = get_vector_length(v_size) ndim = get_vector_length(v_size)
self = cls( self = cls(output_type=CudaNdarrayType((False,) * ndim),
output_type=CudaNdarrayType((False,) * ndim), seed=seed,
seed=seed, destructive=False)
destructive=False)
o_gen, sample = self(generator, cast(v_size, 'int32')) o_gen, sample = self(generator, cast(v_size, 'int32'))
...@@ -282,7 +279,7 @@ class CURAND_RandomStreams(object): ...@@ -282,7 +279,7 @@ class CURAND_RandomStreams(object):
RandomStreams instance that creates CURAND-based random variables. RandomStreams instance that creates CURAND-based random variables.
One caveat is that generators are not serializable. One caveat is that generators are not serializable.
Parameters Parameters
---------- ----------
seed : int seed : int
...@@ -319,7 +316,7 @@ class CURAND_RandomStreams(object): ...@@ -319,7 +316,7 @@ class CURAND_RandomStreams(object):
return rval return rval
def uniform(self, size, low=0.0, high=1.0, ndim=None, def uniform(self, size, low=0.0, high=1.0, ndim=None,
dtype=config.floatX): dtype=config.floatX):
""" """
Return symbolic tensor of uniform numbers. Return symbolic tensor of uniform numbers.
...@@ -327,14 +324,14 @@ class CURAND_RandomStreams(object): ...@@ -327,14 +324,14 @@ class CURAND_RandomStreams(object):
if isinstance(size, tuple): if isinstance(size, tuple):
msg = "size must be a tuple of int or a Theano variable" msg = "size must be a tuple of int or a Theano variable"
assert all([isinstance(i, int) or isinstance(i, Variable) assert all([isinstance(i, int) or isinstance(i, Variable)
for i in size]), msg for i in size]), msg
else: else:
msg = "size must be a tuple of int or a Theano variable" msg = "size must be a tuple of int or a Theano variable"
assert isinstance(size, Variable) and size.ndim == 1, msg assert isinstance(size, Variable) and size.ndim == 1, msg
generator = theano.shared(False) # makes a generic generator = theano.shared(False) # makes a generic
s_size = theano.tensor.as_tensor_variable(size) s_size = theano.tensor.as_tensor_variable(size)
u = CURAND_Uniform.new_auto_update(generator, ndim, dtype, s_size, u = CURAND_Uniform.new_auto_update(generator, ndim, dtype, s_size,
self.next_seed()) self.next_seed())
self.state_updates.append(u.update) self.state_updates.append(u.update)
rval = u * (high - low) + low rval = u * (high - low) + low
if u.type.broadcastable != rval.type.broadcastable: if u.type.broadcastable != rval.type.broadcastable:
...@@ -342,10 +339,10 @@ class CURAND_RandomStreams(object): ...@@ -342,10 +339,10 @@ class CURAND_RandomStreams(object):
'Increase the size to match the broadcasting pattern of ' 'Increase the size to match the broadcasting pattern of '
'low and `high` arguments' 'low and `high` arguments'
) )
return rval return rval
def normal(self, size=None, avg=0.0, std=1.0, ndim=None, def normal(self, size=None, avg=0.0, std=1.0, ndim=None,
dtype=config.floatX): dtype=config.floatX):
""" """
Return symbolic tensor of normally-distributed numbers. Return symbolic tensor of normally-distributed numbers.
...@@ -359,14 +356,14 @@ class CURAND_RandomStreams(object): ...@@ -359,14 +356,14 @@ class CURAND_RandomStreams(object):
if isinstance(size, tuple): if isinstance(size, tuple):
msg = "size must be a tuple of int or a Theano variable" msg = "size must be a tuple of int or a Theano variable"
assert all([isinstance(i, int) or isinstance(i, Variable) assert all([isinstance(i, int) or isinstance(i, Variable)
for i in size]), msg for i in size]), msg
else: else:
msg = "size must be a tuple of int or a Theano variable" msg = "size must be a tuple of int or a Theano variable"
assert isinstance(size, Variable) and size.ndim == 1, msg assert isinstance(size, Variable) and size.ndim == 1, msg
generator = theano.shared(False) # makes a generic generator = theano.shared(False) # makes a generic
s_size = theano.tensor.as_tensor_variable(size) s_size = theano.tensor.as_tensor_variable(size)
u = CURAND_Normal.new_auto_update(generator, ndim, dtype, s_size, u = CURAND_Normal.new_auto_update(generator, ndim, dtype, s_size,
self.next_seed()) self.next_seed())
self.state_updates.append(u.update) self.state_updates.append(u.update)
rval = u * std + avg rval = u * std + avg
if u.type.broadcastable != rval.type.broadcastable: if u.type.broadcastable != rval.type.broadcastable:
...@@ -374,7 +371,7 @@ class CURAND_RandomStreams(object): ...@@ -374,7 +371,7 @@ class CURAND_RandomStreams(object):
'Increase the size to match the broadcasting pattern of `low`' 'Increase the size to match the broadcasting pattern of `low`'
'and `high` arguments' 'and `high` arguments'
) )
return rval return rval
@local_optimizer([CURAND_Base]) @local_optimizer([CURAND_Base])
...@@ -386,5 +383,5 @@ def local_destructive(node): ...@@ -386,5 +383,5 @@ def local_destructive(node):
return new_op.make_node(*node.inputs).outputs return new_op.make_node(*node.inputs).outputs
return False return False
optdb.register('CURAND_destructive', optdb.register('CURAND_destructive',
opt.in2out(local_destructive, ignore_newtrees=True), 99, 'fast_run', opt.in2out(local_destructive, ignore_newtrees=True),
'inplace') 99, 'fast_run', 'inplace')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论