提交 605da818 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Change the way the cudnn version is printed and also show how much memory is preallocated.

上级 57f5c7c7
......@@ -55,37 +55,39 @@ def init_dev(dev, name=None):
disable_alloc_cache=config.gpuarray.preallocate < 0)
init_dev.devmap[dev] = ctx
if config.gpuarray.preallocate > 0:
MB = (1024 * 1024)
if config.gpuarray.preallocate < 1:
gmem = min(config.gpuarray.preallocate, 0.95) * ctx.total_gmem
else:
gmem = config.gpuarray.preallocate * (1024 * 1024)
gmem = config.gpuarray.preallocate * MB
# This will allocate and immediatly free an object of size gmem
# which will reserve that amount of memory on the GPU.
pygpu.empty((gmem,), dtype='int8', context=ctx)
if config.print_active_device:
print("Preallocating %d/%d Mb (%f) on %s" %
(gmem//MB, ctx.total_gmem//MB, gmem/ctx.total_gmem, dev),
file=sys.stderr)
context = init_dev.devmap[dev]
# This will map the context name to the real context object.
reg_context(name, context)
pygpu_activated = True
if config.print_active_device:
warn = None
cudnn_version = ""
if dev.startswith('cuda'):
cudnn_version = " (cuDNN not available)"
try:
cudnn_version = dnn.version()
# 5100 should not print warning with cudnn 5 final.
if cudnn_version > 5100:
warn = ("Your cuDNN version is more recent than Theano."
" If you see problems, try updating Theano or"
" downgrading cuDNN to version 5.")
cudnn_version = " (cuDNN version %s)" % cudnn_version
except Exception:
cudnn_version = dnn.dnn_present.msg
print("Mapped name %s to device %s: %s%s" % (
name, dev, context.devname, cudnn_version),
print("Mapped name %s to device %s: %s" %
(name, dev, context.devname),
file=sys.stderr)
if warn:
warnings.warn(warn)
pygpu_activated = True
if dev.startswith('cuda'):
try:
cudnn_version = dnn.version()
# 5100 should not print warning with cudnn 5 final.
if cudnn_version > 5100:
warnings.warn("Your cuDNN version is more recent than Theano."
" If you see problems, try updating Theano or"
" downgrading cuDNN to version 5.")
if config.print_active_device:
print("Using cuDNN version %d on context %s" %
(cudnn_version, name), file=sys.stderr)
except Exception:
pass
# This maps things like 'cuda0' to the context object on that device.
init_dev.devmap = {}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论