提交 6c4bdd7d authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Type context for neighbours.py.

上级 a28251c6
...@@ -10,7 +10,8 @@ try: ...@@ -10,7 +10,8 @@ try:
except ImportError: except ImportError:
pass pass
from .basic_ops import as_gpuarray_variable, GpuKernelBase, Kernel from .basic_ops import (as_gpuarray_variable, GpuKernelBase, Kernel,
infer_context_name)
from .opt import register_opt as register_gpu_opt, op_lifter from .opt import register_opt as register_gpu_opt, op_lifter
from .type import GpuArrayType from .type import GpuArrayType
...@@ -25,7 +26,7 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op): ...@@ -25,7 +26,7 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op):
self.mode = mode self.mode = mode
def make_node(self, ten4, neib_shape, neib_step): def make_node(self, ten4, neib_shape, neib_step):
ten4 = as_gpuarray_variable(ten4) ten4 = as_gpuarray_variable(ten4, infer_context_name(ten4))
neib_shape = T.as_tensor_variable(neib_shape) neib_shape = T.as_tensor_variable(neib_shape)
neib_step = T.as_tensor_variable(neib_step) neib_step = T.as_tensor_variable(neib_step)
...@@ -37,7 +38,11 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op): ...@@ -37,7 +38,11 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op):
return Apply(self, [ten4, neib_shape, neib_step], return Apply(self, [ten4, neib_shape, neib_step],
[GpuArrayType(broadcastable=(False, False), [GpuArrayType(broadcastable=(False, False),
dtype=ten4.type.dtype)()]) dtype=ten4.type.dtype,
context_name=ten4.type.context_name)()])
def get_context(self, node):
return node.inputs[0].type.context
def c_code_cache_version(self): def c_code_cache_version(self):
return (11,) return (11,)
...@@ -56,7 +61,7 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op): ...@@ -56,7 +61,7 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op):
kname = "k_multi_warp_less" kname = "k_multi_warp_less"
k_var = "k_multi_warp_less_" + nodename k_var = "k_multi_warp_less_" + nodename
code = """ code = """
//a version that use less register but don't work in all case. // a version that uses less registers but doesn't work in all cases.
KERNEL void %(kname)s( KERNEL void %(kname)s(
const int nb_batch, const int nb_batch,
const int nb_stack, const int nb_stack,
...@@ -233,6 +238,8 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op): ...@@ -233,6 +238,8 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op):
return kernels return kernels
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
if node.inputs[0].type.context.kind != 'cuda':
raise NotImplementedError("cuda only")
dtype_ten4 = node.inputs[0].dtype dtype_ten4 = node.inputs[0].dtype
dtype_neib_shape = node.inputs[1].dtype dtype_neib_shape = node.inputs[1].dtype
dtype_neib_step = node.inputs[2].dtype dtype_neib_step = node.inputs[2].dtype
...@@ -243,6 +250,7 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op): ...@@ -243,6 +250,7 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op):
ten4, neib_shape, neib_step = inp ten4, neib_shape, neib_step = inp
z, = out z, = out
fail = sub['fail'] fail = sub['fail']
ctx = sub['context']
mode = self.mode mode = self.mode
err_check = """ err_check = """
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
...@@ -369,8 +377,7 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op): ...@@ -369,8 +377,7 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op):
dims[0] = z_dim0; dims[0] = z_dim0;
dims[1] = z_dim1; dims[1] = z_dim1;
%(z)s = pygpu_empty(2, dims, %(typecode_z)s, %(z)s = pygpu_empty(2, dims, %(typecode_z)s,
GA_C_ORDER, pygpu_default_context(), GA_C_ORDER, %(ctx)s, Py_None);
Py_None);
if (!%(z)s) if (!%(z)s)
{ {
PyErr_SetString(PyExc_MemoryError, "GpuImages2Neibs:" PyErr_SetString(PyExc_MemoryError, "GpuImages2Neibs:"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论