提交 a458f3a9 authored 作者: Li's avatar Li

added infer_shape for Image2Neibs op

上级 0994034c
......@@ -80,11 +80,14 @@ class Images2Neibs(Op):
neib_step = neib_shape
else:
neib_step = T.as_tensor_variable(neib_step)
assert ten4.ndim == 4
assert neib_shape.ndim == 1
assert neib_step.ndim == 1
self.neib_shape = neib_shape
self.neib_step = neib_step
return Apply(self, [ten4, neib_shape, neib_step],
[T.matrix(dtype=ten4.type.dtype)])
......@@ -208,7 +211,29 @@ class Images2Neibs(Op):
z_col = j + d * i
z[0][z_row, z_col] = ten4[n, s, ten4_2, ten4_3]
def infer_shape(self, node, input_shape):
in_shape = input_shape[0]
def CEIL_INTDIV(a, b):
if a % b:
return (a // b) + 1
else:
return a // b
c, d = self.neib_shape
step_x, step_y = self.neib_step
if self.mode == 'wrap_centered':
grid_c = CEIL_INTDIV(in_shape[2], step_x)
grid_d = CEIL_INTDIV(in_shape[3], step_y)
elif self.mode == 'valid':
grid_c = 1 + ((in_shape[2] - c) // step_x)
grid_d = 1 + ((in_shape[3] - d) // step_y)
elif self.mode == 'ignore_borders':
grid_c = 1 + ((in_shape[2] - c) // step_x)
grid_d = 1 + ((in_shape[3] - d) // step_y)
z_dim0 = grid_c * grid_d * in_shape[1] * in_shape[0]
z_dim1 = c * d
return [(z_dim0, z_dim1)]
def c_code(self, node, name, inp, out, sub):
ten4, neib_shape, neib_step = inp
z, = out
......
......@@ -366,6 +366,37 @@ class T_Images2Neibs(unittest_tools.InferShapeTester):
for i in range(1000):
f()
def test_infer_shape(self):
shape = (100, 40, 6, 6)
images = numpy.ones(shape).astype('float32')
x = T.ftensor4()
f = self._compile_and_check([x],
[images2neibs(
x, neib_shape=(2,2),
mode='valid')],
[images],
Images2Neibs
)
shape = (100, 40, 5, 5)
images = numpy.ones(shape).astype('float32')
x = T.ftensor4()
f = self._compile_and_check([x],
[images2neibs(
x, neib_shape=(2,2),
mode='ignore_borders')],
[images],
Images2Neibs
)
shape = (100, 40, 5, 5)
images = numpy.ones(shape).astype('float32')
x = T.ftensor4()
f = self._compile_and_check([x],
[images2neibs(
x, neib_shape=(3,3),
mode='wrap_centered')],
[images],
Images2Neibs
)
if __name__ == '__main__':
unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论