提交 84655e53 authored 作者: Frederic's avatar Frederic

implement the missing infer_shape

上级 f8ec21d6
......@@ -455,41 +455,11 @@ class GpuDnnConvGradW(DnnBase, COp):
[CudaNdarrayType(broadcastable)()])
def infer_shape(self, node, shape):
h = shape[0][2] # Height of input feature maps
w = shape[0][3] # Width of input feature maps
kh = shape[1][2] # Height of each filter
kw = shape[1][3] # Width of each filter
out3 = kh
out4 = kw
desc = node.inputs[2].owner.op
sh, sw = desc.subsample
# We don't have the information necessary, namely the weight size so
# we cannot infer the shape
if sh != 1 or sw != 1:
raise ShapeError(
'Unable to infer shape for stride (%d, %d)' % (sh, sw)
)
if desc.border_mode == 'full':
out3 = 2 - h + (kh - 1) * sh
out4 = 2 - w + (kw - 1) * sw
elif desc.border_mode == 'valid':
out3 = h - (kh - 1) * sh
out4 = w - (kw - 1) * sw
else:
assert isinstance(desc.border_mode, tuple)
assert len(desc.border_mode) == 2
assert isinstance(desc.border_mode[0], int)
assert isinstance(desc.border_mode[1], int)
raise ShapeError('Not implemented')
return [(
shape[1][1],
shape[0][1],
out3,
out4
node.inputs[3],
node.inputs[4]
)]
......@@ -547,41 +517,11 @@ class GpuDnnConvGradI(DnnBase, COp):
[CudaNdarrayType(broadcastable)()])
def infer_shape(self, node, shape):
padh = 0
padw = 0
desc = node.inputs[2].owner.op
sh, sw = desc.subsample
# We don't have the information necessary, namely the image size so
# we cannot infer the shape
if sh != 1 or sw != 1:
raise ShapeError(
'Unable to infer shape for stride (%d, %d)' % (sh, sw)
)
if desc.border_mode == 'full':
padh = shape[0][2] - 1
padw = shape[0][3] - 1
elif isinstance(desc.border_mode, tuple):
padh, padw = desc.border_mode
elif desc.border_mode == 'valid':
pass
else:
assert isinstance(desc.border_mode, tuple)
assert len(desc.border_mode) == 2
assert isinstance(desc.border_mode[0], int)
assert isinstance(desc.border_mode[1], int)
raise ShapeError('Not implemented')
out2 = (shape[1][2] - 1) * sh + shape[0][2] - 2*padh
out3 = (shape[1][3] - 1) * sw + shape[0][3] - 2*padw
return [(
shape[1][0],
shape[0][1],
out2,
out3
node.inputs[3],
node.inputs[4]
)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论