Commit 61cb8c41, authored by notoraptor

Fix forgotten variable in C code string.

Improve the usage of op params for the neighbours ops (an attempt to fix a Jenkins error). Bump AdvancedIncSubtensor1.c_code_cache_version, since the type of one of its op params has changed.
Parent commit: 5f08e4d7
...@@ -21,7 +21,7 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op): ...@@ -21,7 +21,7 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op):
Images2Neibs for the GPU. Images2Neibs for the GPU.
""" """
params_type = ParamsType(mode=Images2Neibs.params_type, context=gpu_context_type) params_type = ParamsType(mode=Images2Neibs.BORDER_MODE, context=gpu_context_type)
def get_params(self, node): def get_params(self, node):
return self.params_type.get_params(self, context=node.inputs[0].type.context) return self.params_type.get_params(self, context=node.inputs[0].type.context)
...@@ -57,9 +57,8 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op): ...@@ -57,9 +57,8 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op):
flags = Kernel.get_flags(dtype_ten4, dtype_z) flags = Kernel.get_flags(dtype_ten4, dtype_z)
type_ten4 = gpuarray.dtype_to_ctype(dtype_ten4) type_ten4 = gpuarray.dtype_to_ctype(dtype_ten4)
type_z = gpuarray.dtype_to_ctype(dtype_z) type_z = gpuarray.dtype_to_ctype(dtype_z)
# Params type 'mode' is an enum list which c_support_code() # `BORDER_MODE`'s c_support_code() contains C constants definitions that are useful here.
# contains C constants definitions that are useful here. mode_constants = self.BORDER_MODE.c_support_code()
mode_constants = self.params_type.get_type('mode').c_support_code()
kernels = [] kernels = []
kname = "k_multi_warp_less" kname = "k_multi_warp_less"
k_var = "k_multi_warp_less_" + nodename k_var = "k_multi_warp_less_" + nodename
...@@ -110,29 +109,29 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op): ...@@ -110,29 +109,29 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op):
ga_int i = LID_1; // loop over c ga_int i = LID_1; // loop over c
{ {
ga_int ten4_2 = i + a * step_x; ga_int ten4_2 = i + a * step_x;
if(%(mode)s == MODE_WRAP_CENTERED) { if(mode == MODE_WRAP_CENTERED) {
ten4_2 -= wrap_centered_half_idx_shift_x; ten4_2 -= wrap_centered_half_idx_shift_x;
if ( ten4_2 < 0 ) if ( ten4_2 < 0 )
ten4_2 += height; ten4_2 += height;
else if (ten4_2 >= height) else if (ten4_2 >= height)
ten4_2 -= height; ten4_2 -= height;
} else if (%(mode)s == MODE_HALF) { } else if (mode == MODE_HALF) {
ten4_2 -= wrap_centered_half_idx_shift_x; ten4_2 -= wrap_centered_half_idx_shift_x;
} else if (%(mode)s == MODE_FULL) { } else if (mode == MODE_FULL) {
ten4_2 -= c - 1; ten4_2 -= c - 1;
} }
ga_int j = LID_0; // loop over d ga_int j = LID_0; // loop over d
{ {
ga_int ten4_3 = j + b * step_y; ga_int ten4_3 = j + b * step_y;
if(%(mode)s == MODE_WRAP_CENTERED){ if(mode == MODE_WRAP_CENTERED){
ten4_3 -= wrap_centered_half_idx_shift_y; ten4_3 -= wrap_centered_half_idx_shift_y;
if ( ten4_3 < 0 ) if ( ten4_3 < 0 )
ten4_3 += width; ten4_3 += width;
else if (ten4_3 >= width) else if (ten4_3 >= width)
ten4_3 -= width; ten4_3 -= width;
} else if (%(mode)s == MODE_HALF) { } else if (mode == MODE_HALF) {
ten4_3 -= wrap_centered_half_idx_shift_y; ten4_3 -= wrap_centered_half_idx_shift_y;
} else if (%(mode)s == MODE_FULL) { } else if (mode == MODE_FULL) {
ten4_3 -= d - 1; ten4_3 -= d - 1;
} }
...@@ -212,30 +211,30 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op): ...@@ -212,30 +211,30 @@ class GpuImages2Neibs(GpuKernelBase, Images2Neibs, Op):
for (ga_int i = LID_1; i < c; i+=LDIM_1) for (ga_int i = LID_1; i < c; i+=LDIM_1)
{ {
ga_int ten4_2 = i + a * step_x; ga_int ten4_2 = i + a * step_x;
if(%(mode)s == MODE_WRAP_CENTERED) { if(mode == MODE_WRAP_CENTERED) {
ten4_2 -= wrap_centered_half_idx_shift_x; ten4_2 -= wrap_centered_half_idx_shift_x;
if ( ten4_2 < 0 ) if ( ten4_2 < 0 )
ten4_2 += height; ten4_2 += height;
else if (ten4_2 >= height) else if (ten4_2 >= height)
ten4_2 -= height; ten4_2 -= height;
} else if (%(mode)s == MODE_HALF) { } else if (mode == MODE_HALF) {
ten4_2 -= wrap_centered_half_idx_shift_x; ten4_2 -= wrap_centered_half_idx_shift_x;
} else if (%(mode)s == MODE_FULL) { } else if (mode == MODE_FULL) {
ten4_2 -= c - 1; ten4_2 -= c - 1;
} }
// loop over d // loop over d
for (ga_int j = LID_0; j < d; j+=LDIM_0) for (ga_int j = LID_0; j < d; j+=LDIM_0)
{ {
ga_int ten4_3 = j + b * step_y; ga_int ten4_3 = j + b * step_y;
if(%(mode)s == MODE_WRAP_CENTERED) { if(mode == MODE_WRAP_CENTERED) {
ten4_3 -= wrap_centered_half_idx_shift_y; ten4_3 -= wrap_centered_half_idx_shift_y;
if ( ten4_3 < 0 ) if ( ten4_3 < 0 )
ten4_3 += width; ten4_3 += width;
else if (ten4_3 >= width) else if (ten4_3 >= width)
ten4_3 -= width; ten4_3 -= width;
} else if (%(mode)s == MODE_HALF) { } else if (mode == MODE_HALF) {
ten4_3 -= wrap_centered_half_idx_shift_y; ten4_3 -= wrap_centered_half_idx_shift_y;
} else if (%(mode)s == MODE_FULL) { } else if (mode == MODE_FULL) {
ten4_3 -= d - 1; ten4_3 -= d - 1;
} }
......
...@@ -40,17 +40,18 @@ class Images2Neibs(Op): ...@@ -40,17 +40,18 @@ class Images2Neibs(Op):
""" """
__props__ = ("mode",) __props__ = ("mode",)
params_type = EnumList(('MODE_VALID', 'valid'), BORDER_MODE = EnumList(('MODE_VALID', 'valid'),
('MODE_HALF', 'half'), ('MODE_HALF', 'half'),
('MODE_FULL', 'full'), ('MODE_FULL', 'full'),
('MODE_WRAP_CENTERED', 'wrap_centered'), ('MODE_WRAP_CENTERED', 'wrap_centered'),
('MODE_IGNORE_BORDERS', 'ignore_borders')) ('MODE_IGNORE_BORDERS', 'ignore_borders'))
params_type = BORDER_MODE
def get_params(self, node): def get_params(self, node):
return self.mode return self.mode
def __init__(self, mode='valid'): def __init__(self, mode='valid'):
implemented_modes = self.params_type.get_aliases() implemented_modes = self.BORDER_MODE.get_aliases()
if mode not in implemented_modes: if mode not in implemented_modes:
raise NotImplementedError("Only modes %s have been implemented for %s" raise NotImplementedError("Only modes %s have been implemented for %s"
% (', '.join(implemented_modes), type(self).__name__)) % (', '.join(implemented_modes), type(self).__name__))
......
...@@ -1989,7 +1989,7 @@ class AdvancedIncSubtensor1(Op): ...@@ -1989,7 +1989,7 @@ class AdvancedIncSubtensor1(Op):
params=sub['params'], fail=sub['fail']) params=sub['params'], fail=sub['fail'])
def c_code_cache_version(self): def c_code_cache_version(self):
return (7,) return (8,)
def perform(self, node, inp, out_, params): def perform(self, node, inp, out_, params):
# TODO opt to make this inplace # TODO opt to make this inplace
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论