提交 fa2207fb authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix when mode is ignore_border

上级 cdfa4964
...@@ -633,9 +633,16 @@ def neibs2images(neibs, neib_shape, original_shape, mode='valid'): ...@@ -633,9 +633,16 @@ def neibs2images(neibs, neib_shape, original_shape, mode='valid'):
new_neib_shape, mode=mode) new_neib_shape, mode=mode)
if mode == 'ignore_borders': if mode == 'ignore_borders':
valid_shape = list(original_shape) # We use set_subtensor to accept original_shape we can't infer
valid_shape[2] = (valid_shape[2] // neib_shape[0]) * neib_shape[0] # the shape and still raise error when it don't have the right
valid_shape[3] = (valid_shape[3] // neib_shape[1]) * neib_shape[1] # shape.
valid_shape = original_shape
valid_shape = T.set_subtensor(
valid_shape[2],
(valid_shape[2] // neib_shape[0]) * neib_shape[0])
valid_shape = T.set_subtensor(
valid_shape[3],
(valid_shape[3] // neib_shape[1]) * neib_shape[1])
output_4d = output_2d.reshape(valid_shape, ndim=4) output_4d = output_2d.reshape(valid_shape, ndim=4)
# padding the borders with zeros # padding the borders with zeros
for d in [2, 3]: for d in [2, 3]:
......
...@@ -352,7 +352,9 @@ class T_Images2Neibs(unittest_tools.InferShapeTester): ...@@ -352,7 +352,9 @@ class T_Images2Neibs(unittest_tools.InferShapeTester):
patsRecovery = T.matrix('patsRecovery') patsRecovery = T.matrix('patsRecovery')
original_size = T.ivector('original_size') original_size = T.ivector('original_size')
out = neibs2images(patsRecovery, (16, 16), original_size) for mode in ['valid', 'ignore_borders']:
out = neibs2images(patsRecovery, (16, 16),
original_size, mode=mode)
f = theano.function([patsRecovery, original_size], out) f = theano.function([patsRecovery, original_size], out)
im_val = numpy.ones((1, 3, 320, 320), dtype=numpy.float32) im_val = numpy.ones((1, 3, 320, 320), dtype=numpy.float32)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论