提交 1433cacb authored 作者: notoraptor's avatar notoraptor

Op param for theano.tensor.nnet.neighbours.Images2Neibs:

- mode (enum list)
上级 8e3ffa84
......@@ -8,6 +8,7 @@ import numpy as np
import theano
from theano import Op, Apply
from theano.gof import EnumList
import theano.tensor as T
from theano.gradient import grad_not_implemented
from theano.gradient import grad_undefined
......@@ -39,13 +40,20 @@ class Images2Neibs(Op):
"""
__props__ = ("mode",)
params_type = EnumList(('MODE_VALID', 'valid'),
('MODE_HALF', 'half'),
('MODE_FULL', 'full'),
('MODE_WRAP_CENTERED', 'wrap_centered'),
('MODE_IGNORE_BORDERS', 'ignore_borders'))
def get_params(self, node):
return self.mode
def __init__(self, mode='valid'):
if mode not in ['valid', 'half', 'full',
'wrap_centered', 'ignore_borders']:
raise NotImplementedError("Only the mode valid, half, full, "
"ignore_borders and wrap_centered have "
"been implemented for Images2Neibs")
implemented_modes = self.params_type.get_aliases()
if mode not in implemented_modes:
raise NotImplementedError("Only modes %s have been implemented for Images2Neibs"
% ', '.join(implemented_modes))
self.mode = mode
def __str__(self):
......@@ -159,9 +167,9 @@ class Images2Neibs(Op):
grad_undefined(self, 2, neib_step)]
def c_code_cache_version(self):
return (8,)
return (9,)
def perform(self, node, inp, out_):
def perform(self, node, inp, out_, params):
ten4, neib_shape, neib_step = inp
z, = out_
# GpuImages2Neibs should not run this perform in DebugMode
......@@ -344,11 +352,6 @@ class Images2Neibs(Op):
return [(z_dim0, z_dim1)]
def c_code(self, node, name, inp, out, sub):
ten4, neib_shape, neib_step = inp
z, = out
fail = sub['fail']
mode = self.mode
return """
#ifndef CEIL_INTDIV
#define CEIL_INTDIV(a, b) ((a/b) + ((a %% b) ? 1: 0))
......@@ -408,7 +411,7 @@ class Images2Neibs(Op):
%(fail)s;
}
if ( "%(mode)s" == "wrap_centered") {
if (%(mode)s == MODE_WRAP_CENTERED) {
if (c%%2!=1 || d%%2!=1){
PyErr_Format(PyExc_TypeError,
"Images2Neibs: in mode wrap_centered"
......@@ -430,7 +433,7 @@ class Images2Neibs(Op):
grid_c = CEIL_INTDIV(((PyArray_DIMS(%(ten4)s))[2]),step_x);
grid_d = CEIL_INTDIV(((PyArray_DIMS(%(ten4)s))[3]),step_y);
}else if ( "%(mode)s" == "valid") {
} else if (%(mode)s == MODE_VALID) {
if ( ((PyArray_DIMS(%(ten4)s))[2] < c) ||
( (((PyArray_DIMS(%(ten4)s))[2]-c) %% step_x)!=0))
{
......@@ -455,12 +458,12 @@ class Images2Neibs(Op):
grid_c = 1+(((PyArray_DIMS(%(ten4)s))[2]-c)/step_x);
//number of patch in width
grid_d = 1+(((PyArray_DIMS(%(ten4)s))[3]-d)/step_y);
}else if ( "%(mode)s" == "ignore_borders") {
} else if (%(mode)s == MODE_IGNORE_BORDERS) {
//number of patch in height
grid_c = 1+(((PyArray_DIMS(%(ten4)s))[2]-c)/step_x);
//number of patch in width
grid_d = 1+(((PyArray_DIMS(%(ten4)s))[3]-d)/step_y);
}else if ( "%(mode)s" == "half") {
} else if (%(mode)s == MODE_HALF) {
if ( ((PyArray_DIMS(%(ten4)s))[2] < c) ||
( (((PyArray_DIMS(%(ten4)s))[2]-(c%%2)) %% step_x)!=0))
{
......@@ -485,7 +488,7 @@ class Images2Neibs(Op):
grid_c = 1+(((PyArray_DIMS(%(ten4)s))[2]-(c%%2))/step_x);
//number of patch in width
grid_d = 1+(((PyArray_DIMS(%(ten4)s))[3]-(d%%2))/step_y);
}else if ( "%(mode)s" == "full") {
} else if (%(mode)s == MODE_FULL) {
if ( ((PyArray_DIMS(%(ten4)s))[2] < c) ||
( (((PyArray_DIMS(%(ten4)s))[2]+c-2) %% step_x)!=0))
{
......@@ -510,9 +513,9 @@ class Images2Neibs(Op):
grid_c = 1+(((PyArray_DIMS(%(ten4)s))[2]+c-2)/step_x);
//number of patch in width
grid_d = 1+(((PyArray_DIMS(%(ten4)s))[3]+d-2)/step_y);
}else {
} else {
PyErr_Format(PyExc_TypeError,
"Images2Neibs: unknow mode '%(mode)s'");
"Images2Neibs: unknow mode %%d", %(mode)s);
%(fail)s;
}
......@@ -572,7 +575,7 @@ class Images2Neibs(Op):
for (int i = 0; i < c; i++) // loop over c
{
int ten4_2 = i + a * step_x;
if ( "%(mode)s" == "wrap_centered" ){
if (%(mode)s == MODE_WRAP_CENTERED) {
ten4_2 -= wrap_centered_half_idx_shift_x;
if ( ten4_2 < 0 ) ten4_2 += height;
else if (ten4_2 >= height) ten4_2 -= height;
......@@ -588,13 +591,13 @@ class Images2Neibs(Op):
for (int j = 0; j < d; j++) // loop over d
{
int ten4_3 = j + b * step_y;
if ( "%(mode)s" == "wrap_centered" ){
if (%(mode)s == MODE_WRAP_CENTERED) {
ten4_3 -= wrap_centered_half_idx_shift_y;
if ( ten4_3 < 0 ) ten4_3 += width;
else if (ten4_3 >= width) ten4_3 -= width;
} else if ( "%(mode)s" == "half" ){
} else if (%(mode)s == MODE_HALF) {
ten4_3 -= wrap_centered_half_idx_shift_y;
} else if ( "%(mode)s" == "full" ){
} else if (%(mode)s == MODE_FULL) {
ten4_3 -= d - 1;
}
int z_col = j + d * i;
......@@ -609,7 +612,8 @@ class Images2Neibs(Op):
}
}
} // END NESTED SCOPE
""" % locals()
""" % dict(ten4=inp[0], neib_shape=inp[1], neib_step=inp[2], z=out[0],
fail=sub['fail'], mode=sub['params'])
def images2neibs(ten4, neib_shape, neib_step=None, mode='valid'):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论