提交 85375e23 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron 提交者: Frederic Bastien

Move the preallocation to after creating the cudnn handle.

Also reserve 50M of device memory for other such handles.
上级 109b378c
......@@ -47,6 +47,7 @@ def init_dev(dev, name=None):
raise RuntimeError("The new gpu-backend need a c++ compiler.")
if (pygpu.version.major, pygpu.version.minor) < (0, 6):
raise ValueError("Your installed version of pygpu is too old, please upgrade to 0.6 or later")
need_preallocate = False
if dev not in init_dev.devmap:
ctx = pygpu.init(dev,
disable_alloc_cache=config.gpuarray.preallocate < 0,
......@@ -56,18 +57,7 @@ def init_dev(dev, name=None):
if config.gpuarray.preallocate < 0:
print("Disabling allocation cache on %s" % (dev,))
elif config.gpuarray.preallocate > 0:
MB = (1024 * 1024)
if config.gpuarray.preallocate <= 1:
gmem = min(config.gpuarray.preallocate, 0.95) * ctx.total_gmem
else:
gmem = config.gpuarray.preallocate * MB
# This will allocate and immediatly free an object of size gmem
# which will reserve that amount of memory on the GPU.
pygpu.empty((gmem,), dtype='int8', context=ctx)
if config.print_active_device:
print("Preallocating %d/%d Mb (%f) on %s" %
(gmem//MB, ctx.total_gmem//MB, gmem/ctx.total_gmem, dev),
file=sys.stderr)
need_preallocate = True
context = init_dev.devmap[dev]
# This will map the context name to the real context object.
reg_context(name, context)
......@@ -103,6 +93,21 @@ def init_dev(dev, name=None):
ctx_props['cudnn_handle'] = dnn._make_handle(context)
except Exception:
pass
if need_preallocate:
MB = (1024 * 1024)
if config.gpuarray.preallocate <= 1:
gmem = min(config.gpuarray.preallocate, 0.95) * ctx.total_gmem
else:
gmem = config.gpuarray.preallocate * MB
gmem = min(ctx.free_gmem - 50 * MB, gmem)
# This will allocate and immediatly free an object of size gmem
# which will reserve that amount of memory on the GPU.
pygpu.empty((gmem,), dtype='int8', context=ctx)
if config.print_active_device:
print("Preallocating %d/%d Mb (%f) on %s" %
(gmem//MB, ctx.total_gmem//MB, gmem/ctx.total_gmem, dev),
file=sys.stderr)
# This maps things like 'cuda0' to the context object on that device.
init_dev.devmap = {}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论