提交 8faa9d9b authored 作者: James Bergstra's avatar James Bergstra

added global var to rng_mrg to suppress repeated warnings.

上级 6dd3ebf9
...@@ -376,6 +376,7 @@ class GPU_mrg_uniform(mrg_uniform_base): ...@@ -376,6 +376,7 @@ class GPU_mrg_uniform(mrg_uniform_base):
otype = 'double' otype = 'double'
NORM = '4.656612873077392578125e-10' NORM = '4.656612873077392578125e-10'
return """ return """
static int %(nodename)s_printed_warning = 0;
static __global__ void %(nodename)s_mrg_uniform( static __global__ void %(nodename)s_mrg_uniform(
%(otype)s*sample_data, %(otype)s*sample_data,
...@@ -543,7 +544,9 @@ class GPU_mrg_uniform(mrg_uniform_base): ...@@ -543,7 +544,9 @@ class GPU_mrg_uniform(mrg_uniform_base):
if (threads_per_block * n_blocks < n_streams) if (threads_per_block * n_blocks < n_streams)
{ {
fprintf(stderr, "WARNING: unused streams above %%i (Tune GPU_mrg get_n_streams)\\n", threads_per_block * n_blocks ); if (! %(nodename)s_printed_warning)
fprintf(stderr, "WARNING: unused streams above %%i (Tune GPU_mrg get_n_streams)\\n", threads_per_block * n_blocks );
%(nodename)s_printed_warning = 1;
} }
%(nodename)s_mrg_uniform<<<n_blocks,threads_per_block>>>( %(nodename)s_mrg_uniform<<<n_blocks,threads_per_block>>>(
CudaNdarray_DEV_DATA(%(o_sample)s), CudaNdarray_DEV_DATA(%(o_sample)s),
...@@ -565,7 +568,7 @@ class GPU_mrg_uniform(mrg_uniform_base): ...@@ -565,7 +568,7 @@ class GPU_mrg_uniform(mrg_uniform_base):
//////// </ code generated by mrg_uniform> //////// </ code generated by mrg_uniform>
""" %locals() """ %locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (3,) return (4,)
class MRG_RandomStreams(object): class MRG_RandomStreams(object):
"""Module component with similar interface to numpy.random (numpy.random.RandomState)""" """Module component with similar interface to numpy.random (numpy.random.RandomState)"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论