提交 ece53783 authored 作者: Philippe Hamel's avatar Philippe Hamel

Added an ignore_border mode for the Image2Neibs Op

上级 927877c8
...@@ -12,8 +12,8 @@ if cuda_available: ...@@ -12,8 +12,8 @@ if cuda_available:
class Images2Neibs(Op): class Images2Neibs(Op):
def __init__(self, mode='valid'): def __init__(self, mode='valid'):
if mode not in ['valid','wrap_centered']: if mode not in ['valid','wrap_centered','ignore_borders']:
raise NotImplementedError("Only the mode valid and wrap_centered have been implemented for the op Images2Neibs") raise NotImplementedError("Only the mode valid, ignore_borders and wrap_centered have been implemented for the op Images2Neibs")
self.mode = mode self.mode = mode
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.mode==other.mode return type(self) == type(other) and self.mode==other.mode
...@@ -49,13 +49,13 @@ class Images2Neibs(Op): ...@@ -49,13 +49,13 @@ class Images2Neibs(Op):
def grad(self, inp, grads): def grad(self, inp, grads):
x, neib_shape, neib_step = inp x, neib_shape, neib_step = inp
gz, = grads gz, = grads
if self.mode=='valid': if self.mode in ['valid','ignore_borders']:
return [neibs2images(gz, neib_shape, x.shape), None, None] return [neibs2images(gz, neib_shape, x.shape, mode=self.mode), None, None]
else: else:
raise NotImplementedError() raise NotImplementedError()
def c_code_cache_version(self): def c_code_cache_version(self):
return (3,) return (4,)
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
ten4, neib_shape, neib_step = inp ten4, neib_shape, neib_step = inp
...@@ -130,6 +130,9 @@ class Images2Neibs(Op): ...@@ -130,6 +130,9 @@ class Images2Neibs(Op):
} }
grid_c = 1+(((%(ten4)s->dimensions)[2]-c)/step_x); //number of patch in height grid_c = 1+(((%(ten4)s->dimensions)[2]-c)/step_x); //number of patch in height
grid_d = 1+(((%(ten4)s->dimensions)[3]-d)/step_y); //number of patch in width grid_d = 1+(((%(ten4)s->dimensions)[3]-d)/step_y); //number of patch in width
}else if ( "%(mode)s" == "ignore_borders") {
grid_c = 1+(((%(ten4)s->dimensions)[2]-c)/step_x); //number of patch in height
grid_d = 1+(((%(ten4)s->dimensions)[3]-d)/step_y); //number of patch in width
}else{ }else{
PyErr_Format(PyExc_TypeError, "Images2Neibs: unknow mode '%(mode)s'"); PyErr_Format(PyExc_TypeError, "Images2Neibs: unknow mode '%(mode)s'");
%(fail)s; %(fail)s;
...@@ -221,7 +224,7 @@ class Images2Neibs(Op): ...@@ -221,7 +224,7 @@ class Images2Neibs(Op):
def images2neibs(ten4, neib_shape, neib_step=None, mode='valid'): def images2neibs(ten4, neib_shape, neib_step=None, mode='valid'):
return Images2Neibs(mode)(ten4, neib_shape, neib_step) return Images2Neibs(mode)(ten4, neib_shape, neib_step)
def neibs2images(neibs, neib_shape, original_shape): def neibs2images(neibs, neib_shape, original_shape, mode='valid'):
""" """
Inverse of images2neib. Inverse of images2neib.
...@@ -236,8 +239,22 @@ def neibs2images(neibs, neib_shape, original_shape): ...@@ -236,8 +239,22 @@ def neibs2images(neibs, neib_shape, original_shape):
original_shape = T.as_tensor_variable(original_shape) original_shape = T.as_tensor_variable(original_shape)
new_neib_shape = T.stack( original_shape[-1]/neib_shape[1], neib_shape[1] ) new_neib_shape = T.stack( original_shape[-1]/neib_shape[1], neib_shape[1] )
return images2neibs(neibs.dimshuffle('x','x',0,1), new_neib_shape).reshape(original_shape) output_2d = images2neibs(neibs.dimshuffle('x','x',0,1), new_neib_shape, mode=mode)
#return images2neibs(neibs.reshape((1,1,neibs.shape[0],neibs.shape[1])), new_neib_shape).reshape(original_shape)
if mode == 'ignore_borders':
valid_shape = list(original_shape)
valid_shape[2] = valid_shape[2] / neib_shape[0] * neib_shape[0]
valid_shape[3] = valid_shape[3] / neib_shape[1] * neib_shape[1]
output_4d = output_2d.reshape(valid_shape)
#padding the borders with zeros
for d in [2,3]:
pad_shape = list(output_4d.shape)
pad_shape[d] = original_shape[d] - valid_shape[d]
output_4d = T.concatenate([output_4d,T.zeros(pad_shape)],axis=d)
else:
output_4d = output_2d.reshape(original_shape)
return output_4d
# This is work in progress # This is work in progress
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论