提交 b37658f3 authored 作者: Frederic Bastien's avatar Frederic Bastien

replace a make_thunk by a prepare_node

上级 d1bfd2b0
...@@ -1948,10 +1948,8 @@ class GpuConv(GpuOp): ...@@ -1948,10 +1948,8 @@ class GpuConv(GpuOp):
images[2] * images[3] * 2) images[2] * images[3] * 2)
return flops return flops
def make_thunk(self, node, storage_map, compute_map, no_recycling): def prepare_node(self, node):
node_ = copy.copy(node) if node.op.max_threads_dim0 is None:
assert node.op is node_.op
if node_.op.max_threads_dim0 is None:
cuda = theano.sandbox.cuda cuda = theano.sandbox.cuda
device_id = cuda.use.device_number device_id = cuda.use.device_number
if device_id is None: if device_id is None:
...@@ -1964,9 +1962,7 @@ class GpuConv(GpuOp): ...@@ -1964,9 +1962,7 @@ class GpuConv(GpuOp):
device_id = cuda.use.device_number device_id = cuda.use.device_number
cuda_ndarray = theano.sandbox.cuda.cuda_ndarray.cuda_ndarray cuda_ndarray = theano.sandbox.cuda.cuda_ndarray.cuda_ndarray
prop = cuda_ndarray.device_properties(device_id) prop = cuda_ndarray.device_properties(device_id)
node_.op.max_threads_dim0 = prop['maxThreadsDim0'] node.op.max_threads_dim0 = prop['maxThreadsDim0']
return super(GpuConv, node_.op).make_thunk(node_, storage_map,
compute_map, no_recycling)
def c_compile_args(self): def c_compile_args(self):
nb = 0 nb = 0
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论