提交 adb7ffd3 authored 作者: James Bergstra's avatar James Bergstra

ENH: commenting cap on the number of streams in MRG, and raising it.

上级 a5fe2dd0
...@@ -605,7 +605,7 @@ def guess_n_streams(size, warn=True): ...@@ -605,7 +605,7 @@ def guess_n_streams(size, warn=True):
Return a guess at a good number of streams. Return a guess at a good number of streams.
:param warn: If True, warn when a guess cannot be made (in which case :param warn: If True, warn when a guess cannot be made (in which case
we return 30 * 256). we return 60 * 256).
""" """
# TODO: a smart way of choosing the number of streams, see #612. # TODO: a smart way of choosing the number of streams, see #612.
# Note that this code was moved out of `MRG_RandomStreams` so that it can # Note that this code was moved out of `MRG_RandomStreams` so that it can
...@@ -618,14 +618,22 @@ def guess_n_streams(size, warn=True): ...@@ -618,14 +618,22 @@ def guess_n_streams(size, warn=True):
r *= s r *= s
if r > 6: if r > 6:
r = r/6 # chosen as fastest for rbm_benchmark r = r/6 # chosen as fastest for rbm_benchmark
return min(r, 30 * 256)
# The purpose of sampling from many streams is to be able to use
# the GPU to its full capacity. It just wastes RAM and stream-initialization time to
# allocate more streams than necessary for the GPU.
# XXX: This number is chosen to be good for 280 and 480 architectures,
# Better would be to use pycuda to query the number of
# processors on the GPU device,
# rather than guessing 60.
return min(r, 60 * 256)
else: else:
if warn: if warn:
warnings.warn(( warnings.warn((
"MRG_RandomStreams Can't determine #streams from " "MRG_RandomStreams Can't determine #streams from "
"size (%s), guessing 30*256") % str(size), "size (%s), guessing 60*256") % str(size),
stacklevel=3) stacklevel=3)
return 30 * 256 return 60 * 256
class MRG_RandomStreams(object): class MRG_RandomStreams(object):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论