提交 3ef165c4 authored 作者: Li's avatar Li

some modification

上级 a458f3a9
...@@ -85,9 +85,6 @@ class Images2Neibs(Op): ...@@ -85,9 +85,6 @@ class Images2Neibs(Op):
assert neib_shape.ndim == 1 assert neib_shape.ndim == 1
assert neib_step.ndim == 1 assert neib_step.ndim == 1
self.neib_shape = neib_shape
self.neib_step = neib_step
return Apply(self, [ten4, neib_shape, neib_step], return Apply(self, [ten4, neib_shape, neib_step],
[T.matrix(dtype=ten4.type.dtype)]) [T.matrix(dtype=ten4.type.dtype)])
...@@ -219,8 +216,8 @@ class Images2Neibs(Op): ...@@ -219,8 +216,8 @@ class Images2Neibs(Op):
return (a // b) + 1 return (a // b) + 1
else: else:
return a // b return a // b
c, d = self.neib_shape c, d = node.inputs[1]
step_x, step_y = self.neib_step step_x, step_y = node.inputs[2]
if self.mode == 'wrap_centered': if self.mode == 'wrap_centered':
grid_c = CEIL_INTDIV(in_shape[2], step_x) grid_c = CEIL_INTDIV(in_shape[2], step_x)
grid_d = CEIL_INTDIV(in_shape[3], step_y) grid_d = CEIL_INTDIV(in_shape[3], step_y)
......
...@@ -377,9 +377,37 @@ class T_Images2Neibs(unittest_tools.InferShapeTester): ...@@ -377,9 +377,37 @@ class T_Images2Neibs(unittest_tools.InferShapeTester):
[images], [images],
Images2Neibs Images2Neibs
) )
f = self._compile_and_check([x],
[images2neibs(
x, neib_shape=(2,3),
mode='valid')],
[images],
Images2Neibs
)
shape = (100, 40, 5, 5)
images = numpy.ones(shape).astype('float32')
x = T.ftensor4()
f = self._compile_and_check([x],
[images2neibs(
x, neib_shape=(2,2),
mode='ignore_borders')],
[images],
Images2Neibs
)
shape = (100, 40, 5, 5) shape = (100, 40, 5, 5)
images = numpy.ones(shape).astype('float32') images = numpy.ones(shape).astype('float32')
x = T.ftensor4() x = T.ftensor4()
f = self._compile_and_check([x],
[images2neibs(
x, neib_shape=(2,3),
mode='ignore_borders')],
[images],
Images2Neibs
)
shape = (100, 40, 6, 6)
images = numpy.ones(shape).astype('float32')
x = T.ftensor4()
f = self._compile_and_check([x], f = self._compile_and_check([x],
[images2neibs( [images2neibs(
x, neib_shape=(2,2), x, neib_shape=(2,2),
...@@ -397,6 +425,7 @@ class T_Images2Neibs(unittest_tools.InferShapeTester): ...@@ -397,6 +425,7 @@ class T_Images2Neibs(unittest_tools.InferShapeTester):
[images], [images],
Images2Neibs Images2Neibs
) )
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论