提交 4a98bd66 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix the device='cuda' case.

上级 a08bba5b
......@@ -42,7 +42,10 @@ class NVCC_compiler(NVCC_base):
if dev.startswith("opencl"):
raise Exception, "Trying to call nvcc with an OpenCL context"
assert dev.startswith('cuda')
n = int(dev[4:])
if dev == 'cuda':
n = theano.sandbox.cuda.use.device_number
else:
n = int(dev[4:])
p = theano.sandbox.cuda.device_properties(n)
flags.append('-arch=sm_' + str(p['major']) + str(p['minor']))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论