提交 67acce06 authored 作者: Frederic's avatar Frederic

small code simplification

上级 3044beb7
......@@ -579,9 +579,8 @@ class GpuCorrMM(GpuOp):
return ['cuda_ndarray.cuh', '<stdio.h>']
def c_code_cache_version(self):
return
# raise this whenever modifying any of the support_code_files
return (0, 21)
return (0, 22)
def c_support_code_apply(self, node, nodename):
# REMEMBER TO RAISE c_code_cache_version when changing any of
......@@ -596,14 +595,18 @@ class GpuCorrMM(GpuOp):
out, = out_
dx = self.subsample[0]
dy = self.subsample[1]
border_mode = self.border_mode
sub = sub.copy()
pad = self.pad
if self.border_mode == "valid":
bmode = 1
else:
assert self.border_mode == "full"
bmode = 0
sub.update(locals())
return """
//Mandatory args
const char *mode_str = "%(border_mode)s";
int mode = %(bmode)s;
//Optional args
int dx = %(dx)s;
......@@ -612,21 +615,6 @@ class GpuCorrMM(GpuOp):
CudaNdarray * img = %(img)s;
CudaNdarray * kern = %(kern)s;
CudaNdarray * out2 = NULL;
int mode;
if (strcmp(mode_str, "full") == 0)
{
mode = 0;
}
else if (strcmp(mode_str, "valid") == 0)
{
mode = 1;
}
else
{
PyErr_SetString(PyExc_ValueError,
"mode must be one of 'full' or 'valid'");
%(fail)s;
}
//TODO: Send self.pad, stride, etc
int out_dim[4];
......@@ -844,7 +832,7 @@ class GpuConv(GpuOp):
def c_code_cache_version(self):
# raise this whenever modifying any of the support_code_files
return (0, 21)
return (0, 22)
def c_support_code_apply(self, node, nodename):
# REMEMBER TO RAISE c_code_cache_version when changing any of
......@@ -859,11 +847,15 @@ class GpuConv(GpuOp):
out, = out_
dx = self.subsample[0]
dy = self.subsample[1]
border_mode = self.border_mode
version = self.version
verbose = self.verbose
sub = sub.copy()
max_threads_dim0 = self.max_threads_dim0
if self.border_mode == "valid":
bmode = 1
else:
assert self.border_mode == "full"
bmode = 0
if max_threads_dim0 is None:
raise NotImplementedError("GpuConv.c_code should not be called "
"directly. It should be called by "
......@@ -872,7 +864,7 @@ class GpuConv(GpuOp):
sub.update(locals())
return """
//Mandatory args
const char *mode_str = "%(border_mode)s";
int mode = %(bmode)s;
//Optional args
int version = %(version)s;
......@@ -880,21 +872,6 @@ class GpuConv(GpuOp):
int dx = %(dx)s;
int dy = %(dy)s;
int mode;
if (strcmp(mode_str, "full") == 0)
{
mode = ConvMode_FULL;
}
else if (strcmp(mode_str, "valid") == 0)
{
mode = ConvMode_VALID;
}
else
{
PyErr_SetString(PyExc_ValueError,
"mode must be one of 'full' or 'valid'");
return NULL;
}
// TODO, make out be decref before we alloc out2!
CudaNdarray * out2 = (CudaNdarray *)CudaNdarray_Conv(%(img)s, %(kern)s,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论