提交 28e0c795 authored 作者: Frederic's avatar Frederic

Always put an int for the gpu device number initialized.

上级 32ae3995
......@@ -309,8 +309,15 @@ def use(device,
device = 0
try:
if device != 'gpu':
assert isinstance(device, int)
gpu_init(device)
use.device_number = device
use.device_number = device
else:
# This mean we let the driver select the GPU.
# But default it is always number 0.
# If the driver is in exclusive mode, it will always show
# device 0 event if it use something else.
use.device_number = 0
if test_driver:
import theano.sandbox.cuda.tests.test_driver
theano.sandbox.cuda.tests.test_driver.test_nvidia_driver1()
......
......@@ -718,10 +718,11 @@ class GpuConv(GpuOp):
node_ = copy.copy(node)
assert node.op is node_.op
if node_.op.max_threads_dim0 is None:
op = copy.copy(node_.op)
device_id = theano.sandbox.cuda.use.device_number[3:]
if device_id == '':
device_id = 0
cuda = theano.sandbox.cuda
device_id = cuda.use.device_number
if device_id is None:
cuda.use("gpu", False, False, False, False, True)
device_id = cuda.use.device_number
cuda_ndarray = theano.sandbox.cuda.cuda_ndarray.cuda_ndarray
prop = cuda_ndarray.device_properties(device_id)
node_.op.max_threads_dim0 = prop['maxThreadsDim0']
......
......@@ -35,9 +35,9 @@ device_id = theano.sandbox.cuda.use.device_number
if device_id is None:
cuda_ndarray.shared_constructor(numpy.zeros(2, dtype='float32'))
device_id = theano.sandbox.cuda.use.device_number
device_id = device_id[3:]
if device_id == '':
device_id = 0
if device_id is None:
cuda.use("gpu", False, False, False, False, True)
device_id = theano.sandbox.cuda.use.device_number
cuda_ndarray = theano.sandbox.cuda.cuda_ndarray.cuda_ndarray
device_prop = cuda_ndarray.device_properties(device_id)
......@@ -55,7 +55,7 @@ def py_conv_valid_numpy(img, kern):
#rr, cc is the upper-left corner of img patches
imgpatch = img[b, :, rr:rr + kern.shape[2],
cc:cc + kern.shape[3]]
#print img.shape, kern.shape, imgpatch.shape, rr+kern.shape[2]-1, rr-1, -1
innerprod = (imgpatch[:, ::-1, ::-1] *
kern[k, :, :, :]).sum()
out[b, k, rr, cc] = innerprod
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论