提交 b21945df authored 作者: Frederic's avatar Frederic

Implement GpuImages2Neibs mode ignore_border.

上级 4e2127a4
...@@ -13,8 +13,9 @@ if cuda_available: ...@@ -13,8 +13,9 @@ if cuda_available:
class GpuImages2Neibs(Images2Neibs, GpuOp): class GpuImages2Neibs(Images2Neibs, GpuOp):
def __init__(self, mode='valid'): def __init__(self, mode='valid'):
if mode not in ['valid', 'wrap_centered']: if mode not in ['valid', 'ignore_borders', 'wrap_centered']:
raise NotImplementedError("Only the mode valid and wrap_centered" raise NotImplementedError("Only the mode valid, ignore_borders"
" and wrap_centered"
" have been implemented for the op" " have been implemented for the op"
" GpuImages2Neibs") " GpuImages2Neibs")
self.mode = mode self.mode = mode
...@@ -277,6 +278,11 @@ class GpuImages2Neibs(Images2Neibs, GpuOp): ...@@ -277,6 +278,11 @@ class GpuImages2Neibs(Images2Neibs, GpuOp):
grid_c = 1+(((CudaNdarray_HOST_DIMS(%(ten4)s))[2]-c)/step_x); grid_c = 1+(((CudaNdarray_HOST_DIMS(%(ten4)s))[2]-c)/step_x);
//number of patch in width //number of patch in width
grid_d = 1+(((CudaNdarray_HOST_DIMS(%(ten4)s))[3]-d)/step_y); grid_d = 1+(((CudaNdarray_HOST_DIMS(%(ten4)s))[3]-d)/step_y);
}else if ( "%(mode)s" == "ignore_borders") {
//number of patch in height
grid_c = 1+(((CudaNdarray_HOST_DIMS(%(ten4)s))[2]-c)/step_x);
//number of patch in width
grid_d = 1+(((CudaNdarray_HOST_DIMS(%(ten4)s))[3]-d)/step_y);
}else{ }else{
PyErr_Format(PyExc_TypeError, PyErr_Format(PyExc_TypeError,
"Images2Neibs: unknow mode '%(mode)s'"); "Images2Neibs: unknow mode '%(mode)s'");
...@@ -403,7 +409,8 @@ def gpu_images2neibs(ten4, neib_shape, neib_step=None, mode='valid'): ...@@ -403,7 +409,8 @@ def gpu_images2neibs(ten4, neib_shape, neib_step=None, mode='valid'):
def use_gpu_images2neibs(node): def use_gpu_images2neibs(node):
if (type(node.op) is Images2Neibs and if (type(node.op) is Images2Neibs and
node.inputs[0].dtype == 'float32' and node.inputs[0].dtype == 'float32' and
node.op.mode in ['valid', 'wrap_centered']): node.op.mode in ['valid', 'ignore_borders',
'wrap_centered']):
return [host_from_gpu(gpu_images2neibs(gpu_from_host(node.inputs[0]), return [host_from_gpu(gpu_images2neibs(gpu_from_host(node.inputs[0]),
node.inputs[1], node.inputs[2], node.inputs[1], node.inputs[2],
mode=node.op.mode))] mode=node.op.mode))]
......
...@@ -41,9 +41,8 @@ class T_Images2Neibs(unittest_tools.InferShapeTester): ...@@ -41,9 +41,8 @@ class T_Images2Neibs(unittest_tools.InferShapeTester):
g = function([], g = function([],
neibs2images(neibs, neib_shape, images.shape), neibs2images(neibs, neib_shape, images.shape),
mode=self.mode) mode=self.mode)
if border in ['valid']: assert any([isinstance(node.op, self.op)
assert any([isinstance(node.op, self.op) for node in f.maker.fgraph.toposort()])
for node in f.maker.fgraph.toposort()])
#print g() #print g()
assert numpy.allclose(images.get_value(borrow=True), g()) assert numpy.allclose(images.get_value(borrow=True), g())
...@@ -59,6 +58,8 @@ class T_Images2Neibs(unittest_tools.InferShapeTester): ...@@ -59,6 +58,8 @@ class T_Images2Neibs(unittest_tools.InferShapeTester):
for border in ['valid', 'ignore_borders']: for border in ['valid', 'ignore_borders']:
f = function([], images2neibs(images, neib_shape, mode=border), f = function([], images2neibs(images, neib_shape, mode=border),
mode=self.mode) mode=self.mode)
assert any([isinstance(node.op, self.op)
for node in f.maker.fgraph.toposort()])
#print images.get_value(borrow=True) #print images.get_value(borrow=True)
neibs = f() neibs = f()
...@@ -107,9 +108,8 @@ class T_Images2Neibs(unittest_tools.InferShapeTester): ...@@ -107,9 +108,8 @@ class T_Images2Neibs(unittest_tools.InferShapeTester):
mode=self.mode) mode=self.mode)
neibs = f() neibs = f()
if border in ['valid']: assert self.op in [type(node.op)
assert self.op in [type(node.op) for node in f.maker.fgraph.toposort()]
for node in f.maker.fgraph.toposort()]
assert numpy.allclose(neibs, assert numpy.allclose(neibs,
[[ 0, 1, 2, 5, 6, 7, 10, 11, 12], [[ 0, 1, 2, 5, 6, 7, 10, 11, 12],
...@@ -162,6 +162,8 @@ class T_Images2Neibs(unittest_tools.InferShapeTester): ...@@ -162,6 +162,8 @@ class T_Images2Neibs(unittest_tools.InferShapeTester):
images2neibs(images, neib_shape, images2neibs(images, neib_shape,
mode='ignore_borders'), mode='ignore_borders'),
mode=self.mode) mode=self.mode)
assert self.op in [type(node.op)
for node in f.maker.fgraph.toposort()]
f() f()
def test_neibs_wrap_centered_step_manual(self): def test_neibs_wrap_centered_step_manual(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论