提交 4c76812b authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2430 from yaoli/image2neibs_infer_shape

Image2neibs infer shape
...@@ -208,7 +208,24 @@ class Images2Neibs(Op): ...@@ -208,7 +208,24 @@ class Images2Neibs(Op):
z_col = j + d * i z_col = j + d * i
z[0][z_row, z_col] = ten4[n, s, ten4_2, ten4_3] z[0][z_row, z_col] = ten4[n, s, ten4_2, ten4_3]
def infer_shape(self, node, input_shape):
in_shape = input_shape[0]
c, d = node.inputs[1]
step_x, step_y = node.inputs[2]
if self.mode == 'wrap_centered':
grid_c = T.ceil_intdiv(in_shape[2], step_x)
grid_d = T.ceil_intdiv(in_shape[3], step_y)
elif self.mode == 'valid':
grid_c = 1 + ((in_shape[2] - c) // step_x)
grid_d = 1 + ((in_shape[3] - d) // step_y)
elif self.mode == 'ignore_borders':
grid_c = 1 + ((in_shape[2] - c) // step_x)
grid_d = 1 + ((in_shape[3] - d) // step_y)
z_dim0 = grid_c * grid_d * in_shape[1] * in_shape[0]
z_dim1 = c * d
return [(z_dim0, z_dim1)]
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
ten4, neib_shape, neib_step = inp ten4, neib_shape, neib_step = inp
z, = out z, = out
......
...@@ -366,6 +366,66 @@ class T_Images2Neibs(unittest_tools.InferShapeTester): ...@@ -366,6 +366,66 @@ class T_Images2Neibs(unittest_tools.InferShapeTester):
for i in range(1000): for i in range(1000):
f() f()
def test_infer_shape(self):
shape = (100, 40, 6, 3)
images = numpy.ones(shape).astype('float32')
x = T.ftensor4()
f = self._compile_and_check([x],
[images2neibs(
x, neib_shape=(2,1),
mode='valid')],
[images],
Images2Neibs
)
f = self._compile_and_check([x],
[images2neibs(
x, neib_shape=(2,3),
mode='valid')],
[images],
Images2Neibs
)
shape = (100, 40, 5, 4)
images = numpy.ones(shape).astype('float32')
x = T.ftensor4()
f = self._compile_and_check([x],
[images2neibs(
x, neib_shape=(2,1),
mode='ignore_borders')],
[images],
Images2Neibs
)
shape = (100, 40, 5, 3)
images = numpy.ones(shape).astype('float32')
x = T.ftensor4()
f = self._compile_and_check([x],
[images2neibs(
x, neib_shape=(2,3),
mode='ignore_borders')],
[images],
Images2Neibs
)
shape = (100, 40, 6, 7)
images = numpy.ones(shape).astype('float32')
x = T.ftensor4()
f = self._compile_and_check([x],
[images2neibs(
x, neib_shape=(2,2),
mode='ignore_borders')],
[images],
Images2Neibs
)
shape = (100, 40, 5, 10)
images = numpy.ones(shape).astype('float32')
x = T.ftensor4()
f = self._compile_and_check([x],
[images2neibs(
x, neib_shape=(3,3),
mode='wrap_centered')],
[images],
Images2Neibs
)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论