提交 a4192c9c authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2157 from sebastien-j/issue_2076

Set limit on n_streams (Issue 2076)
......@@ -734,6 +734,13 @@ class GPU_mrg_uniform(mrg_uniform_base, GpuOp):
unsigned int threads_per_block = std::min((unsigned int)n_streams_used_in_this_call, (unsigned int)NUM_VECTOR_OP_THREADS_PER_BLOCK);
unsigned int n_blocks = std::min(ceil_intdiv((unsigned int)n_streams_used_in_this_call, threads_per_block), (unsigned int)NUM_VECTOR_OP_BLOCKS);
if (n_streams > (unsigned int)NUM_VECTOR_OP_THREADS_PER_BLOCK * (unsigned int)NUM_VECTOR_OP_BLOCKS)
{
PyErr_Format(PyExc_ValueError, "On GPU, n_streams should be at most %%u",
(unsigned int)NUM_VECTOR_OP_THREADS_PER_BLOCK * (unsigned int)NUM_VECTOR_OP_BLOCKS);
%(fail)s;
}
if (threads_per_block * n_blocks < n_streams)
{
if (! %(nodename)s_printed_warning)
......@@ -761,7 +768,7 @@ class GPU_mrg_uniform(mrg_uniform_base, GpuOp):
""" % locals()
def c_code_cache_version(self):
return (8,)
return (9,)
class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
......
......@@ -17,6 +17,7 @@ import unittest
from theano.tests import unittest_tools as utt
from nose.plugins.skip import SkipTest
from nose.plugins.attrib import attr
from nose.tools import assert_raises
#TODO: test gpu
# Done in test_consistency_GPU_{serial,parallel}
......@@ -305,6 +306,22 @@ def test_consistency_GPU_parallel():
samples = numpy.array(samples).flatten()
assert(numpy.allclose(samples, java_samples))
def test_GPU_nstreams_limit():
"""Verify that a ValueError is raised when n_streams
is greater than 2**20 on GPU. This is the value of
(NUM_VECTOR_OP_THREADS_PER_BLOCK * NUM_VECTOR_OP_BLOCKS).
"""
if not cuda_available:
raise SkipTest('Optional package cuda not available')
seed = 12345
R = MRG_RandomStreams(seed=seed, use_cuda=True)
def eval_uniform(size, nstreams):
return R.uniform(size=size, nstreams=nstreams, dtype='float32').eval()
eval_uniform((10,), 2**20)
assert_raises(ValueError, eval_uniform, (10,), 2**20 + 1)
def test_consistency_GPUA_serial():
'''Verify that the random numbers generated by GPUA_mrg_uniform, serially,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论