提交 6c4bdd7d authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Type context for neighbours.py.

上级 a28251c6
......@@ -10,7 +10,8 @@ try:
except ImportError:
pass
from .basic_ops import as_gpuarray_variable, GpuKernelBase, Kernel
from .basic_ops import (as_gpuarray_variable, GpuKernelBase, Kernel,
infer_context_name)
from .opt import register_opt as register_gpu_opt, op_lifter
from .type import GpuArrayType
......@@ -25,7 +26,7 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op):
self.mode = mode
def make_node(self, ten4, neib_shape, neib_step):
ten4 = as_gpuarray_variable(ten4)
ten4 = as_gpuarray_variable(ten4, infer_context_name(ten4))
neib_shape = T.as_tensor_variable(neib_shape)
neib_step = T.as_tensor_variable(neib_step)
......@@ -37,7 +38,11 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op):
return Apply(self, [ten4, neib_shape, neib_step],
[GpuArrayType(broadcastable=(False, False),
dtype=ten4.type.dtype)()])
dtype=ten4.type.dtype,
context_name=ten4.type.context_name)()])
def get_context(self, node):
return node.inputs[0].type.context
def c_code_cache_version(self):
return (11,)
......@@ -56,7 +61,7 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op):
kname = "k_multi_warp_less"
k_var = "k_multi_warp_less_" + nodename
code = """
//a version that use less register but don't work in all case.
// a version that uses less registers but doesn't work in all cases.
KERNEL void %(kname)s(
const int nb_batch,
const int nb_stack,
......@@ -233,6 +238,8 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op):
return kernels
def c_code(self, node, name, inp, out, sub):
if node.inputs[0].type.context.kind != 'cuda':
raise NotImplementedError("cuda only")
dtype_ten4 = node.inputs[0].dtype
dtype_neib_shape = node.inputs[1].dtype
dtype_neib_step = node.inputs[2].dtype
......@@ -243,6 +250,7 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op):
ten4, neib_shape, neib_step = inp
z, = out
fail = sub['fail']
ctx = sub['context']
mode = self.mode
err_check = """
if (err != GA_NO_ERROR) {
......@@ -369,8 +377,7 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op):
dims[0] = z_dim0;
dims[1] = z_dim1;
%(z)s = pygpu_empty(2, dims, %(typecode_z)s,
GA_C_ORDER, pygpu_default_context(),
Py_None);
GA_C_ORDER, %(ctx)s, Py_None);
if (!%(z)s)
{
PyErr_SetString(PyExc_MemoryError, "GpuImages2Neibs:"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论