提交 317389c3 authored 作者: James Bergstra's avatar James Bergstra

many changes to sandbox/downsample

上级 63fb1356
......@@ -22,6 +22,9 @@ class DownsampleFactorMaxGrad(Op):
def __hash__(self):
return hash(type(self)) ^ hash(self.ds) ^ hash(self.ignore_border)
def __str__(self):
return '%s{%s,%s}' % (self.__class__.__name__, self.ds, self.ignore_border)
def make_node(self, x, maxout, gz):
# make_node should only be called by the grad function of DownsampleFactorMax,
# so these asserts should not fail.
......@@ -48,7 +51,7 @@ class DownsampleFactorMaxGrad(Op):
def c_code(self, node, name, (x, z, gz), (gx,), sub):
fail = sub['fail']
self_ignore_border = int(self.ignore_border)
ignore_border = int(self.ignore_border)
ds0, ds1 = self.ds
return """
int x_typenum = PyArray_ObjectType((PyObject*)%(x)s, 0);
......@@ -79,7 +82,7 @@ class DownsampleFactorMaxGrad(Op):
}
z_shp0 = %(z)s->dimensions[2];
z_shp1 = %(z)s->dimensions[3];
if (%(self_ignore_border)s)
if (%(ignore_border)s)
{
x_shp0_usable = z_shp0 * %(ds0)s;
x_shp1_usable = z_shp1 * %(ds1)s;
......@@ -136,6 +139,9 @@ class DownsampleFactorMaxGrad(Op):
}//for b
""" %locals()
def c_code_cache_version(self):
return ()
class DownsampleFactorMax(Op):
"""
......@@ -153,17 +159,20 @@ class DownsampleFactorMax(Op):
rval[2] += 1
if d % ds[1]:
rval[3] += 1
return rval;
return rval
def __init__(self, ds, ignore_border=False):
self.ds = tuple(ds)
self.ignore_border = ignore_border
def __eq__(self, other):
return type(self) == type(other) and self.ds == other.ds
return type(self) == type(other) and self.ds == other.ds and self.ignore_border == other.ignore_border
def __hash__(self):
return hash(type(self)) ^ hash(self.ds)
return hash(type(self)) ^ hash(self.ds) ^ hash(self.ignore_border)
def __str__(self):
return '%s{%s,%s}' % (self.__class__.__name__, self.ds, self.ignore_border)
def make_node(self, x):
dmatrix4 = tensor.TensorType(x.type.dtype, (False, False, False, False))
......@@ -178,6 +187,7 @@ class DownsampleFactorMax(Op):
raise NotImplementedError('DownsampleFactorMax requires 4D input for now')
if z[0] is None:
z[0] = numpy.zeros(self.out_shape(x.shape, self.ds, self.ignore_border)) -float('inf')
z[0] = numpy.asarray(z[0], dtype=x.dtype)
zz=z[0]
ds0, ds1 = self.ds
......@@ -197,7 +207,8 @@ class DownsampleFactorMax(Op):
def c_code(self, node, name, (x,), (z, ), sub):
fail=sub['fail']
self_ignore_border = int(self.ignore_border)
ignore_border = int(self.ignore_border)
print "IGNORE_BORDER", ignore_border
ds0, ds1 = self.ds
return """
int typenum = PyArray_ObjectType((PyObject*)%(x)s, 0);
......@@ -211,7 +222,7 @@ class DownsampleFactorMax(Op):
}
z_shp0 = %(x)s->dimensions[2] / %(ds0)s;
z_shp1 = %(x)s->dimensions[3] / %(ds1)s;
if (%(self_ignore_border)s)
if (%(ignore_border)s)
{
x_shp0_usable = z_shp0 * %(ds0)s;
x_shp1_usable = z_shp1 * %(ds1)s;
......@@ -240,23 +251,29 @@ class DownsampleFactorMax(Op):
%(z)s = (PyArrayObject*) PyArray_ZEROS(4, dims, typenum,0); //TODO: zeros not necessary
}
for(int b=0;b<%(x)s->dimensions[0];b++){
for(int k=0;k<%(x)s->dimensions[1];k++){
int mini_i = 0;
int zi = 0;
for(int i=0;i< x_shp0_usable; i++){
int mini_j = 0;
int zj = 0;
for(int j=0; j<x_shp1_usable; j++){
dtype_%(x)s a = ((dtype_%(x)s*)(PyArray_GETPTR4(%(x)s,b,k,i,j)))[0];
dtype_%(z)s * __restrict__ z = ((dtype_%(z)s*)(PyArray_GETPTR4(%(z)s,b,k,zi,zj)));
z[0] = (((mini_j|mini_i) == 0) || z[0] < a) ? a : z[0];
mini_j = ((mini_j + 1) == %(ds1)s) ? 0 : mini_j+1;
zj += (mini_j == 0);
if (z_shp0 && z_shp1)
{
for(int b=0;b<%(x)s->dimensions[0];b++){
for(int k=0;k<%(x)s->dimensions[1];k++){
int mini_i = 0;
int zi = 0;
for(int i=0;i< x_shp0_usable; i++){
int mini_j = 0;
int zj = 0;
for(int j=0; j<x_shp1_usable; j++){
dtype_%(x)s a = ((dtype_%(x)s*)(PyArray_GETPTR4(%(x)s,b,k,i,j)))[0];
dtype_%(z)s * __restrict__ z = ((dtype_%(z)s*)(PyArray_GETPTR4(%(z)s,b,k,zi,zj)));
z[0] = (((mini_j|mini_i) == 0) || z[0] < a) ? a : z[0];
mini_j = ((mini_j + 1) == %(ds1)s) ? 0 : mini_j+1;
zj += (mini_j == 0);
}
mini_i = ((mini_i + 1) == %(ds0)s) ? 0 : mini_i+1;
zi += (mini_i == 0);
}
}
mini_i = ((mini_i + 1) == %(ds0)s) ? 0 : mini_i+1;
zi += (mini_i == 0);
}
}
}
""" % locals()
def c_code_cache_version(self):
return ()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论