提交 47ea0388 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

numpydoc for theano/sandbox/test_rng_mrg.py

上级 8befdc6c
...@@ -75,10 +75,11 @@ def test_deterministic(): ...@@ -75,10 +75,11 @@ def test_deterministic():
def test_consistency_randomstreams(): def test_consistency_randomstreams():
'''Verify that the random numbers generated by MRG_RandomStreams """
Verify that the random numbers generated by MRG_RandomStreams
are the same as the reference (Java) implementation by L'Ecuyer et al. are the same as the reference (Java) implementation by L'Ecuyer et al.
'''
"""
seed = 12345 seed = 12345
n_samples = 5 n_samples = 5
n_streams = 12 n_streams = 12
...@@ -108,9 +109,11 @@ def test_consistency_randomstreams(): ...@@ -108,9 +109,11 @@ def test_consistency_randomstreams():
def test_consistency_cpu_serial(): def test_consistency_cpu_serial():
'''Verify that the random numbers generated by mrg_uniform, serially, """
Verify that the random numbers generated by mrg_uniform, serially,
are the same as the reference (Java) implementation by L'Ecuyer et al. are the same as the reference (Java) implementation by L'Ecuyer et al.
'''
"""
seed = 12345 seed = 12345
n_samples = 5 n_samples = 5
n_streams = 12 n_streams = 12
...@@ -149,9 +152,11 @@ def test_consistency_cpu_serial(): ...@@ -149,9 +152,11 @@ def test_consistency_cpu_serial():
def test_consistency_cpu_parallel(): def test_consistency_cpu_parallel():
'''Verify that the random numbers generated by mrg_uniform, in parallel, """
Verify that the random numbers generated by mrg_uniform, in parallel,
are the same as the reference (Java) implementation by L'Ecuyer et al. are the same as the reference (Java) implementation by L'Ecuyer et al.
'''
"""
seed = 12345 seed = 12345
n_samples = 5 n_samples = 5
n_streams = 12 n_streams = 12
...@@ -193,9 +198,11 @@ def test_consistency_cpu_parallel(): ...@@ -193,9 +198,11 @@ def test_consistency_cpu_parallel():
def test_consistency_GPU_serial(): def test_consistency_GPU_serial():
'''Verify that the random numbers generated by GPU_mrg_uniform, serially, """
Verify that the random numbers generated by GPU_mrg_uniform, serially,
are the same as the reference (Java) implementation by L'Ecuyer et al. are the same as the reference (Java) implementation by L'Ecuyer et al.
'''
"""
if not cuda_available: if not cuda_available:
raise SkipTest('Optional package cuda not available') raise SkipTest('Optional package cuda not available')
if config.mode == 'FAST_COMPILE': if config.mode == 'FAST_COMPILE':
...@@ -250,11 +257,12 @@ def test_consistency_GPU_serial(): ...@@ -250,11 +257,12 @@ def test_consistency_GPU_serial():
def test_consistency_GPU_parallel(): def test_consistency_GPU_parallel():
'''Verify that the random numbers generated by GPU_mrg_uniform, in """
Verify that the random numbers generated by GPU_mrg_uniform, in
parallel, are the same as the reference (Java) implementation by parallel, are the same as the reference (Java) implementation by
L'Ecuyer et al. L'Ecuyer et al.
''' """
if not cuda_available: if not cuda_available:
raise SkipTest('Optional package cuda not available') raise SkipTest('Optional package cuda not available')
if config.mode == 'FAST_COMPILE': if config.mode == 'FAST_COMPILE':
...@@ -310,9 +318,11 @@ def test_consistency_GPU_parallel(): ...@@ -310,9 +318,11 @@ def test_consistency_GPU_parallel():
def test_GPU_nstreams_limit(): def test_GPU_nstreams_limit():
"""Verify that a ValueError is raised when n_streams """
Verify that a ValueError is raised when n_streams
is greater than 2**20 on GPU. This is the value of is greater than 2**20 on GPU. This is the value of
(NUM_VECTOR_OP_THREADS_PER_BLOCK * NUM_VECTOR_OP_BLOCKS). (NUM_VECTOR_OP_THREADS_PER_BLOCK * NUM_VECTOR_OP_BLOCKS).
""" """
if not cuda_available: if not cuda_available:
raise SkipTest('Optional package cuda not available') raise SkipTest('Optional package cuda not available')
...@@ -335,9 +345,11 @@ def test_GPU_nstreams_limit(): ...@@ -335,9 +345,11 @@ def test_GPU_nstreams_limit():
def test_consistency_GPUA_serial(): def test_consistency_GPUA_serial():
'''Verify that the random numbers generated by GPUA_mrg_uniform, serially, """
Verify that the random numbers generated by GPUA_mrg_uniform, serially,
are the same as the reference (Java) implementation by L'Ecuyer et al. are the same as the reference (Java) implementation by L'Ecuyer et al.
'''
"""
from theano.sandbox.gpuarray.tests.test_basic_ops import \ from theano.sandbox.gpuarray.tests.test_basic_ops import \
mode_with_gpu as mode mode_with_gpu as mode
from theano.sandbox.gpuarray.type import gpuarray_shared_constructor from theano.sandbox.gpuarray.type import gpuarray_shared_constructor
...@@ -387,11 +399,12 @@ def test_consistency_GPUA_serial(): ...@@ -387,11 +399,12 @@ def test_consistency_GPUA_serial():
def test_consistency_GPUA_parallel(): def test_consistency_GPUA_parallel():
'''Verify that the random numbers generated by GPUA_mrg_uniform, in """
Verify that the random numbers generated by GPUA_mrg_uniform, in
parallel, are the same as the reference (Java) implementation by parallel, are the same as the reference (Java) implementation by
L'Ecuyer et al. L'Ecuyer et al.
''' """
from theano.sandbox.gpuarray.tests.test_basic_ops import \ from theano.sandbox.gpuarray.tests.test_basic_ops import \
mode_with_gpu as mode mode_with_gpu as mode
from theano.sandbox.gpuarray.type import gpuarray_shared_constructor from theano.sandbox.gpuarray.type import gpuarray_shared_constructor
...@@ -855,6 +868,7 @@ def test_multiple_rng_aliasing(): ...@@ -855,6 +868,7 @@ def test_multiple_rng_aliasing():
copy the (random) state between two similar theano graphs. The test is copy the (random) state between two similar theano graphs. The test is
meant to detect a previous bug where state_updates was initialized as a meant to detect a previous bug where state_updates was initialized as a
class-attribute, instead of the __init__ function. class-attribute, instead of the __init__ function.
""" """
rng1 = MRG_RandomStreams(1234) rng1 = MRG_RandomStreams(1234)
rng2 = MRG_RandomStreams(2392) rng2 = MRG_RandomStreams(2392)
...@@ -864,6 +878,7 @@ def test_multiple_rng_aliasing(): ...@@ -864,6 +878,7 @@ def test_multiple_rng_aliasing():
def test_random_state_transfer(): def test_random_state_transfer():
""" """
Test that random state can be transferred from one theano graph to another. Test that random state can be transferred from one theano graph to another.
""" """
class Graph: class Graph:
def __init__(self, seed=123): def __init__(self, seed=123):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论