提交 711f0835 作者: Frederic Bastien

Make the neighbours GPU optimizer work correctly and fix the GPU make_node of this op.

上级 8159833f
...@@ -3,11 +3,12 @@ from theano import Op, Apply ...@@ -3,11 +3,12 @@ from theano import Op, Apply
import theano.tensor as T import theano.tensor as T
from theano.tensor.opt import register_specialize from theano.tensor.opt import register_specialize
from theano.gof import local_optimizer from theano.gof import local_optimizer
from theano.sandbox.cuda import cuda_available from theano.sandbox.cuda import cuda_available
if cuda_available: if cuda_available:
from theano.sandbox.cuda import CudaNdarrayType from theano.sandbox.cuda import CudaNdarrayType
from theano.sandbox.cuda.basic_ops import host_from_gpu, gpu_from_host from theano.sandbox.cuda.basic_ops import host_from_gpu, gpu_from_host
from theano.sandbox.cuda.opt import register_opt as register_gpu_opt
class Images2Neibs(Op): class Images2Neibs(Op):
def __eq__(self, other): def __eq__(self, other):
...@@ -163,7 +164,7 @@ class GpuImages2Neibs(Images2Neibs): ...@@ -163,7 +164,7 @@ class GpuImages2Neibs(Images2Neibs):
# raise TypeError('unis must be cudandarray', neib_shape) # raise TypeError('unis must be cudandarray', neib_shape)
#print 'neib_shape type and dtype', type(neib_shape), neib_shape.dtype #print 'neib_shape type and dtype', type(neib_shape), neib_shape.dtype
return Apply(self, [ten4, neib_shape], [CudaNdarrayType(broadcastable=(False,)*2)()]) return Apply(self, [ten4, neib_shape], [ten4.type()])
def c_code_cache_version(self): def c_code_cache_version(self):
return () return ()
...@@ -360,6 +361,7 @@ gpu_images2neibs = GpuImages2Neibs() ...@@ -360,6 +361,7 @@ gpu_images2neibs = GpuImages2Neibs()
def use_gpu_images2neibs(node): def use_gpu_images2neibs(node):
if node.op == images2neibs: if node.op == images2neibs:
return [host_from_gpu(gpu_images2neibs(*[gpu_from_host(node.inputs[0]),node.inputs[1]]))] return [host_from_gpu(gpu_images2neibs(*[gpu_from_host(node.inputs[0]),node.inputs[1]]))]
if theano.config.device.startswith('gpu'):
register_specialize(use_gpu_images2neibs) if cuda_available:
register_gpu_opt()(use_gpu_images2neibs)
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论