提交 4fd12c42 authored 作者: Chiheb Trabelsi's avatar Chiheb Trabelsi

rng_curand.py has been modified in order to respect the flake8 style.

上级 54fe4a7f
"""
Define CURAND_RandomStreams - backed by CURAND.
"""
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
__authors__ = "James Bergstra"
__copyright__ = "(c) 2011, University of Montreal"
__license__ = "3-clause BSD License"
__contact__ = "theano-dev@googlegroups.com"
import numpy import numpy
import theano.gof import theano.gof
from theano.compat import PY3 from theano.compat import PY3
...@@ -17,6 +7,15 @@ from theano.tensor import (get_vector_length, cast, opt) ...@@ -17,6 +7,15 @@ from theano.tensor import (get_vector_length, cast, opt)
from theano.compile import optdb from theano.compile import optdb
from theano.gof import local_optimizer, Variable from theano.gof import local_optimizer, Variable
__authors__ = "James Bergstra"
__copyright__ = "(c) 2011, University of Montreal"
__license__ = "3-clause BSD License"
__contact__ = "theano-dev@googlegroups.com"
"""
Define CURAND_RandomStreams - backed by CURAND.
"""
config = theano.config config = theano.config
...@@ -70,8 +69,7 @@ class CURAND_Base(GpuOp): ...@@ -70,8 +69,7 @@ class CURAND_Base(GpuOp):
Return a tuple of attributes that define the Op. Return a tuple of attributes that define the Op.
""" """
return ( return (self.destructive,
self.destructive,
self.output_type, self.output_type,
self.seed, self.seed,
) )
...@@ -101,8 +99,7 @@ class CURAND_Base(GpuOp): ...@@ -101,8 +99,7 @@ class CURAND_Base(GpuOp):
v_size = theano.tensor.as_tensor_variable(size) v_size = theano.tensor.as_tensor_variable(size)
if ndim is None: if ndim is None:
ndim = get_vector_length(v_size) ndim = get_vector_length(v_size)
self = cls( self = cls(output_type=CudaNdarrayType((False,) * ndim),
output_type=CudaNdarrayType((False,) * ndim),
seed=seed, seed=seed,
destructive=False) destructive=False)
...@@ -386,5 +383,5 @@ def local_destructive(node): ...@@ -386,5 +383,5 @@ def local_destructive(node):
return new_op.make_node(*node.inputs).outputs return new_op.make_node(*node.inputs).outputs
return False return False
optdb.register('CURAND_destructive', optdb.register('CURAND_destructive',
opt.in2out(local_destructive, ignore_newtrees=True), 99, 'fast_run', opt.in2out(local_destructive, ignore_newtrees=True),
'inplace') 99, 'fast_run', 'inplace')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论