提交 4fd12c42 authored 作者: Chiheb Trabelsi's avatar Chiheb Trabelsi

rng_curand.py has been modified in order to respect the flake8 style.

上级 54fe4a7f
"""
Define CURAND_RandomStreams - backed by CURAND.
"""
from __future__ import absolute_import, print_function, division
__authors__ = "James Bergstra"
__copyright__ = "(c) 2011, University of Montreal"
__license__ = "3-clause BSD License"
__contact__ = "theano-dev@googlegroups.com"
import numpy
import theano.gof
from theano.compat import PY3
......@@ -17,12 +7,21 @@ from theano.tensor import (get_vector_length, cast, opt)
from theano.compile import optdb
from theano.gof import local_optimizer, Variable
__authors__ = "James Bergstra"
__copyright__ = "(c) 2011, University of Montreal"
__license__ = "3-clause BSD License"
__contact__ = "theano-dev@googlegroups.com"
"""
Define CURAND_RandomStreams - backed by CURAND.
"""
config = theano.config
class CURAND_Base(GpuOp):
"""
"""
Base class for a random number generator implemented in CURAND.
The random number generator itself is an opaque reference managed by
......@@ -70,8 +69,7 @@ class CURAND_Base(GpuOp):
Return a tuple of attributes that define the Op.
"""
return (
self.destructive,
return (self.destructive,
self.output_type,
self.seed,
)
......@@ -88,7 +86,7 @@ class CURAND_Base(GpuOp):
def make_node(self, generator, size):
return theano.gof.Apply(self, [generator, size],
[generator.type(), self.output_type()])
[generator.type(), self.output_type()])
@classmethod
def new_auto_update(cls, generator, ndim, dtype, size, seed):
......@@ -101,10 +99,9 @@ class CURAND_Base(GpuOp):
v_size = theano.tensor.as_tensor_variable(size)
if ndim is None:
ndim = get_vector_length(v_size)
self = cls(
output_type=CudaNdarrayType((False,) * ndim),
seed=seed,
destructive=False)
self = cls(output_type=CudaNdarrayType((False,) * ndim),
seed=seed,
destructive=False)
o_gen, sample = self(generator, cast(v_size, 'int32'))
......@@ -282,7 +279,7 @@ class CURAND_RandomStreams(object):
RandomStreams instance that creates CURAND-based random variables.
One caveat is that generators are not serializable.
Parameters
----------
seed : int
......@@ -319,7 +316,7 @@ class CURAND_RandomStreams(object):
return rval
def uniform(self, size, low=0.0, high=1.0, ndim=None,
dtype=config.floatX):
dtype=config.floatX):
"""
Return symbolic tensor of uniform numbers.
......@@ -327,14 +324,14 @@ class CURAND_RandomStreams(object):
if isinstance(size, tuple):
msg = "size must be a tuple of int or a Theano variable"
assert all([isinstance(i, int) or isinstance(i, Variable)
for i in size]), msg
for i in size]), msg
else:
msg = "size must be a tuple of int or a Theano variable"
assert isinstance(size, Variable) and size.ndim == 1, msg
generator = theano.shared(False) # makes a generic
s_size = theano.tensor.as_tensor_variable(size)
u = CURAND_Uniform.new_auto_update(generator, ndim, dtype, s_size,
self.next_seed())
self.next_seed())
self.state_updates.append(u.update)
rval = u * (high - low) + low
if u.type.broadcastable != rval.type.broadcastable:
......@@ -342,10 +339,10 @@ class CURAND_RandomStreams(object):
'Increase the size to match the broadcasting pattern of '
'low and `high` arguments'
)
return rval
return rval
def normal(self, size=None, avg=0.0, std=1.0, ndim=None,
dtype=config.floatX):
dtype=config.floatX):
"""
Return symbolic tensor of normally-distributed numbers.
......@@ -359,14 +356,14 @@ class CURAND_RandomStreams(object):
if isinstance(size, tuple):
msg = "size must be a tuple of int or a Theano variable"
assert all([isinstance(i, int) or isinstance(i, Variable)
for i in size]), msg
for i in size]), msg
else:
msg = "size must be a tuple of int or a Theano variable"
assert isinstance(size, Variable) and size.ndim == 1, msg
generator = theano.shared(False) # makes a generic
s_size = theano.tensor.as_tensor_variable(size)
u = CURAND_Normal.new_auto_update(generator, ndim, dtype, s_size,
self.next_seed())
self.next_seed())
self.state_updates.append(u.update)
rval = u * std + avg
if u.type.broadcastable != rval.type.broadcastable:
......@@ -374,7 +371,7 @@ class CURAND_RandomStreams(object):
'Increase the size to match the broadcasting pattern of `low`'
'and `high` arguments'
)
return rval
return rval
@local_optimizer([CURAND_Base])
......@@ -386,5 +383,5 @@ def local_destructive(node):
return new_op.make_node(*node.inputs).outputs
return False
optdb.register('CURAND_destructive',
opt.in2out(local_destructive, ignore_newtrees=True), 99, 'fast_run',
'inplace')
opt.in2out(local_destructive, ignore_newtrees=True),
99, 'fast_run', 'inplace')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论