提交 1433cacb authored 作者: notoraptor's avatar notoraptor

Op param for theano.tensor.nnet.neighbours.Images2Neibs:

- mode (enum list)
上级 8e3ffa84
...@@ -8,6 +8,7 @@ import numpy as np ...@@ -8,6 +8,7 @@ import numpy as np
import theano import theano
from theano import Op, Apply from theano import Op, Apply
from theano.gof import EnumList
import theano.tensor as T import theano.tensor as T
from theano.gradient import grad_not_implemented from theano.gradient import grad_not_implemented
from theano.gradient import grad_undefined from theano.gradient import grad_undefined
...@@ -39,13 +40,20 @@ class Images2Neibs(Op): ...@@ -39,13 +40,20 @@ class Images2Neibs(Op):
""" """
__props__ = ("mode",) __props__ = ("mode",)
params_type = EnumList(('MODE_VALID', 'valid'),
('MODE_HALF', 'half'),
('MODE_FULL', 'full'),
('MODE_WRAP_CENTERED', 'wrap_centered'),
('MODE_IGNORE_BORDERS', 'ignore_borders'))
def get_params(self, node):
return self.mode
def __init__(self, mode='valid'): def __init__(self, mode='valid'):
if mode not in ['valid', 'half', 'full', implemented_modes = self.params_type.get_aliases()
'wrap_centered', 'ignore_borders']: if mode not in implemented_modes:
raise NotImplementedError("Only the mode valid, half, full, " raise NotImplementedError("Only modes %s have been implemented for Images2Neibs"
"ignore_borders and wrap_centered have " % ', '.join(implemented_modes))
"been implemented for Images2Neibs")
self.mode = mode self.mode = mode
def __str__(self): def __str__(self):
...@@ -159,9 +167,9 @@ class Images2Neibs(Op): ...@@ -159,9 +167,9 @@ class Images2Neibs(Op):
grad_undefined(self, 2, neib_step)] grad_undefined(self, 2, neib_step)]
def c_code_cache_version(self): def c_code_cache_version(self):
return (8,) return (9,)
def perform(self, node, inp, out_): def perform(self, node, inp, out_, params):
ten4, neib_shape, neib_step = inp ten4, neib_shape, neib_step = inp
z, = out_ z, = out_
# GpuImages2Neibs should not run this perform in DebugMode # GpuImages2Neibs should not run this perform in DebugMode
...@@ -344,11 +352,6 @@ class Images2Neibs(Op): ...@@ -344,11 +352,6 @@ class Images2Neibs(Op):
return [(z_dim0, z_dim1)] return [(z_dim0, z_dim1)]
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
ten4, neib_shape, neib_step = inp
z, = out
fail = sub['fail']
mode = self.mode
return """ return """
#ifndef CEIL_INTDIV #ifndef CEIL_INTDIV
#define CEIL_INTDIV(a, b) ((a/b) + ((a %% b) ? 1: 0)) #define CEIL_INTDIV(a, b) ((a/b) + ((a %% b) ? 1: 0))
...@@ -408,7 +411,7 @@ class Images2Neibs(Op): ...@@ -408,7 +411,7 @@ class Images2Neibs(Op):
%(fail)s; %(fail)s;
} }
if ( "%(mode)s" == "wrap_centered") { if (%(mode)s == MODE_WRAP_CENTERED) {
if (c%%2!=1 || d%%2!=1){ if (c%%2!=1 || d%%2!=1){
PyErr_Format(PyExc_TypeError, PyErr_Format(PyExc_TypeError,
"Images2Neibs: in mode wrap_centered" "Images2Neibs: in mode wrap_centered"
...@@ -430,7 +433,7 @@ class Images2Neibs(Op): ...@@ -430,7 +433,7 @@ class Images2Neibs(Op):
grid_c = CEIL_INTDIV(((PyArray_DIMS(%(ten4)s))[2]),step_x); grid_c = CEIL_INTDIV(((PyArray_DIMS(%(ten4)s))[2]),step_x);
grid_d = CEIL_INTDIV(((PyArray_DIMS(%(ten4)s))[3]),step_y); grid_d = CEIL_INTDIV(((PyArray_DIMS(%(ten4)s))[3]),step_y);
}else if ( "%(mode)s" == "valid") { } else if (%(mode)s == MODE_VALID) {
if ( ((PyArray_DIMS(%(ten4)s))[2] < c) || if ( ((PyArray_DIMS(%(ten4)s))[2] < c) ||
( (((PyArray_DIMS(%(ten4)s))[2]-c) %% step_x)!=0)) ( (((PyArray_DIMS(%(ten4)s))[2]-c) %% step_x)!=0))
{ {
...@@ -455,12 +458,12 @@ class Images2Neibs(Op): ...@@ -455,12 +458,12 @@ class Images2Neibs(Op):
grid_c = 1+(((PyArray_DIMS(%(ten4)s))[2]-c)/step_x); grid_c = 1+(((PyArray_DIMS(%(ten4)s))[2]-c)/step_x);
//number of patch in width //number of patch in width
grid_d = 1+(((PyArray_DIMS(%(ten4)s))[3]-d)/step_y); grid_d = 1+(((PyArray_DIMS(%(ten4)s))[3]-d)/step_y);
}else if ( "%(mode)s" == "ignore_borders") { } else if (%(mode)s == MODE_IGNORE_BORDERS) {
//number of patch in height //number of patch in height
grid_c = 1+(((PyArray_DIMS(%(ten4)s))[2]-c)/step_x); grid_c = 1+(((PyArray_DIMS(%(ten4)s))[2]-c)/step_x);
//number of patch in width //number of patch in width
grid_d = 1+(((PyArray_DIMS(%(ten4)s))[3]-d)/step_y); grid_d = 1+(((PyArray_DIMS(%(ten4)s))[3]-d)/step_y);
}else if ( "%(mode)s" == "half") { } else if (%(mode)s == MODE_HALF) {
if ( ((PyArray_DIMS(%(ten4)s))[2] < c) || if ( ((PyArray_DIMS(%(ten4)s))[2] < c) ||
( (((PyArray_DIMS(%(ten4)s))[2]-(c%%2)) %% step_x)!=0)) ( (((PyArray_DIMS(%(ten4)s))[2]-(c%%2)) %% step_x)!=0))
{ {
...@@ -485,7 +488,7 @@ class Images2Neibs(Op): ...@@ -485,7 +488,7 @@ class Images2Neibs(Op):
grid_c = 1+(((PyArray_DIMS(%(ten4)s))[2]-(c%%2))/step_x); grid_c = 1+(((PyArray_DIMS(%(ten4)s))[2]-(c%%2))/step_x);
//number of patch in width //number of patch in width
grid_d = 1+(((PyArray_DIMS(%(ten4)s))[3]-(d%%2))/step_y); grid_d = 1+(((PyArray_DIMS(%(ten4)s))[3]-(d%%2))/step_y);
}else if ( "%(mode)s" == "full") { } else if (%(mode)s == MODE_FULL) {
if ( ((PyArray_DIMS(%(ten4)s))[2] < c) || if ( ((PyArray_DIMS(%(ten4)s))[2] < c) ||
( (((PyArray_DIMS(%(ten4)s))[2]+c-2) %% step_x)!=0)) ( (((PyArray_DIMS(%(ten4)s))[2]+c-2) %% step_x)!=0))
{ {
...@@ -510,9 +513,9 @@ class Images2Neibs(Op): ...@@ -510,9 +513,9 @@ class Images2Neibs(Op):
grid_c = 1+(((PyArray_DIMS(%(ten4)s))[2]+c-2)/step_x); grid_c = 1+(((PyArray_DIMS(%(ten4)s))[2]+c-2)/step_x);
//number of patch in width //number of patch in width
grid_d = 1+(((PyArray_DIMS(%(ten4)s))[3]+d-2)/step_y); grid_d = 1+(((PyArray_DIMS(%(ten4)s))[3]+d-2)/step_y);
}else { } else {
PyErr_Format(PyExc_TypeError, PyErr_Format(PyExc_TypeError,
"Images2Neibs: unknow mode '%(mode)s'"); "Images2Neibs: unknow mode %%d", %(mode)s);
%(fail)s; %(fail)s;
} }
...@@ -572,7 +575,7 @@ class Images2Neibs(Op): ...@@ -572,7 +575,7 @@ class Images2Neibs(Op):
for (int i = 0; i < c; i++) // loop over c for (int i = 0; i < c; i++) // loop over c
{ {
int ten4_2 = i + a * step_x; int ten4_2 = i + a * step_x;
if ( "%(mode)s" == "wrap_centered" ){ if (%(mode)s == MODE_WRAP_CENTERED) {
ten4_2 -= wrap_centered_half_idx_shift_x; ten4_2 -= wrap_centered_half_idx_shift_x;
if ( ten4_2 < 0 ) ten4_2 += height; if ( ten4_2 < 0 ) ten4_2 += height;
else if (ten4_2 >= height) ten4_2 -= height; else if (ten4_2 >= height) ten4_2 -= height;
...@@ -588,13 +591,13 @@ class Images2Neibs(Op): ...@@ -588,13 +591,13 @@ class Images2Neibs(Op):
for (int j = 0; j < d; j++) // loop over d for (int j = 0; j < d; j++) // loop over d
{ {
int ten4_3 = j + b * step_y; int ten4_3 = j + b * step_y;
if ( "%(mode)s" == "wrap_centered" ){ if (%(mode)s == MODE_WRAP_CENTERED) {
ten4_3 -= wrap_centered_half_idx_shift_y; ten4_3 -= wrap_centered_half_idx_shift_y;
if ( ten4_3 < 0 ) ten4_3 += width; if ( ten4_3 < 0 ) ten4_3 += width;
else if (ten4_3 >= width) ten4_3 -= width; else if (ten4_3 >= width) ten4_3 -= width;
} else if ( "%(mode)s" == "half" ){ } else if (%(mode)s == MODE_HALF) {
ten4_3 -= wrap_centered_half_idx_shift_y; ten4_3 -= wrap_centered_half_idx_shift_y;
} else if ( "%(mode)s" == "full" ){ } else if (%(mode)s == MODE_FULL) {
ten4_3 -= d - 1; ten4_3 -= d - 1;
} }
int z_col = j + d * i; int z_col = j + d * i;
...@@ -609,7 +612,8 @@ class Images2Neibs(Op): ...@@ -609,7 +612,8 @@ class Images2Neibs(Op):
} }
} }
} // END NESTED SCOPE } // END NESTED SCOPE
""" % locals() """ % dict(ten4=inp[0], neib_shape=inp[1], neib_step=inp[2], z=out[0],
fail=sub['fail'], mode=sub['params'])
def images2neibs(ten4, neib_shape, neib_step=None, mode='valid'): def images2neibs(ten4, neib_shape, neib_step=None, mode='valid'):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论