提交 b9c51f28 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add infer_shape to abstract convolutions

上级 6d4633be
...@@ -322,6 +322,21 @@ class AbstractConv2d(BaseAbstractConv2d): ...@@ -322,6 +322,21 @@ class AbstractConv2d(BaseAbstractConv2d):
d_weights = patternbroadcast(d_weights, weights.broadcastable) d_weights = patternbroadcast(d_weights, weights.broadcastable)
return d_bottom, d_weights return d_bottom, d_weights
def infer_shape(self, node, input_shapes):
imshp = input_shapes[0]
kshp = input_shapes[1]
# replace symbolic shapes with known constant shapes
if self.imshp is not None:
imshp = [imshp[i] if self.imshp[i] is None else self.imshp[i]
for i in range(4)]
if self.kshp is not None:
kshp = [kshp[i] if self.kshp[i] is None else self.kshp[i]
for i in range(4)]
res = get_conv_output_shape(imshp, kshp, self.border_mode,
self.subsample)
return [res]
class AbstractConv2d_gradWeights(BaseAbstractConv2d): class AbstractConv2d_gradWeights(BaseAbstractConv2d):
"""Gradient wrt. filters for `AbstractConv2d`. """Gradient wrt. filters for `AbstractConv2d`.
...@@ -387,6 +402,19 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d): ...@@ -387,6 +402,19 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d):
def connection_pattern(self, node): def connection_pattern(self, node):
return [[1], [1], [0]] # no connection to height, width return [[1], [1], [0]] # no connection to height, width
def infer_shape(self, node, input_shapes):
# We use self.kshp (that was passed when creating the Op) if possible,
# or fall back to the `shape` input of the node.
# TODO: when there is no subsampling, try to infer the kernel shape
# from the shapes of inputs.
imshp = input_shapes[0]
topshp = input_shapes[1]
kshp = self.kshp[:] if self.kshp is not None else [None] * 4
fallback_kshp = [topshp[1], imshp[1], node.inputs[2][0], node.inputs[2][1]]
kshp = [fallback_kshp[i] if kshp[i] is None else kshp[i]
for i in range(4)]
return [kshp]
class AbstractConv2d_gradInputs(BaseAbstractConv2d): class AbstractConv2d_gradInputs(BaseAbstractConv2d):
"""Gradient wrt. inputs for `AbstractConv2d`. """Gradient wrt. inputs for `AbstractConv2d`.
...@@ -448,3 +476,17 @@ class AbstractConv2d_gradInputs(BaseAbstractConv2d): ...@@ -448,3 +476,17 @@ class AbstractConv2d_gradInputs(BaseAbstractConv2d):
def connection_pattern(self, node): def connection_pattern(self, node):
return [[1], [1], [0]] # no connection to height, width return [[1], [1], [0]] # no connection to height, width
def infer_shape(self, node, input_shapes):
# We use self.imshp (that was passed when creating the Op) if possible,
# or fall back to the `shape` input of the node.
# TODO: when there is no subsampling, try to infer the image shape
# from the shapes of inputs.
kshp = input_shapes[0]
topshp = input_shapes[1]
imshp = self.imshp[:] if self.imshp is not None else [None] * 4
fallback_imshp = [topshp[0], kshp[1], node.inputs[2][0],
node.inputs[2][1]]
imshp = [fallback_imshp[i] if imshp[i] is None else imshp[i]
for i in range(4)]
return [imshp]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论