Commit c78d5c22 authored by f0k

Fix ConvOp shape inference and propagation for imshp=None

Parent 20e7ec9d
@@ -373,7 +373,7 @@ class ConvOp(OpenMPOp):
         # Expand unknown image / kernel shapes into tuples of Nones
         if imshp is None:
-            imshp = (None, None)
+            imshp = (None, None, None)
         else:
             imshp = tuple(imshp)
         if kshp is None:
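ConvOp's imshp describes one input image as (stack size, rows, cols), so an unknown image shape has to expand to three Nones, not two; the old 2-tuple silently dropped a dimension whenever later code zipped it against the 3-component shape. A minimal sketch of the normalization this hunk fixes, assuming the 2-or-3-entry convention ConvOp uses (normalize_imshp is an illustrative name, not from the source):

def normalize_imshp(imshp):
    # Unknown shape: one None per logical dimension
    # (stack size / channels, rows, columns).
    if imshp is None:
        return (None, None, None)
    imshp = tuple(imshp)
    if len(imshp) == 2:
        # ConvOp also accepts plain (rows, cols); assume a stack size of 1.
        imshp = (1,) + imshp
    if len(imshp) != 3:
        raise ValueError("imshp must have 2 or 3 entries, got %d" % len(imshp))
    return imshp

assert normalize_imshp(None) == (None, None, None)
assert normalize_imshp((32, 32)) == (1, 32, 32)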
@@ -438,7 +438,7 @@ class ConvOp(OpenMPOp):
         else:
             kshp_logical = tuple(kshp_logical)
             if len(kshp_logical) != 2:
-                raise ValueError("len(kshp_logical) must be k, got %d" % len(kshp_logical))
+                raise ValueError("len(kshp_logical) must be 2, got %d" % len(kshp_logical))
         self.kshp_logical = kshp_logical

         # a bool
@@ -642,15 +642,6 @@ class ConvOp(OpenMPOp):
         # infer output shape from what we have
         outshp = ConvOp.getOutputShape(imshp[1:], kshp, (self.dx, self.dy),
                                        self.out_mode)
-        if not self.has_all_shape(self.imshp_logical, self.kshp_logical,
-                                  self.bsize, self.nkern):
-            # FIXME: Not sure why this is needed. I think the shape is inferred
-            # correctly no matter what, but if we return a partially symbolic
-            # shape here, test_conv_cuda_ndarray:test_gemm_grads fails. (@f0k)
-            raise theano.tensor.ShapeError()
-        # FIXME: Actually, test_conv_cuda_ndarray:test_gemm_grads only passes if
-        # we completely disable shape inference. (@f0k)
-        raise theano.tensor.ShapeError()
         return [(bsize, nkern) + outshp]

     def perform(self, node, inp, out):
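With the unconditional ShapeError gone, infer_shape once again returns whatever getOutputShape can derive, keeping None for dimensions it cannot resolve. A rough sketch of that arithmetic under the usual valid/full convolution conventions (not a verbatim copy of ConvOp.getOutputShape):

def get_output_shape(imshp, kshp, stride, mode):
    # Per-dimension convolution output size; None propagates as
    # "unknown" instead of aborting shape inference.
    out = []
    for i, k, d in zip(imshp, kshp, stride):
        if i is None or k is None:
            out.append(None)
        elif mode == "valid":
            out.append((i - k) // d + 1)
        else:  # "full"
            out.append((i + k - 2) // d + 1)
    return tuple(out)

assert get_output_shape((32, 32), (5, 5), (1, 1), "valid") == (28, 28)
assert get_output_shape((None, 32), (5, 5), (1, 1), "valid") == (None, 28)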
@@ -947,8 +938,8 @@ class ConvOp(OpenMPOp):
             din = din(gz, filters)
-            assert (all(shp is None for shp in din.owner.op.outshp) or
-                    all(o == i for o, i in zip(din.owner.op.outshp, self.imshp[1:])))
+            assert all(o is None or o == i
+                       for o, i in zip(din.owner.op.outshp, self.imshp[1:]))

             # din and dw should have the same broadcasting pattern as the
             # parameters they are the gradient of (resp. inputs and kerns).
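The old assertion only accepted an outshp that was entirely None or entirely equal to the image shape, so a partially inferred shape such as (None, 28) failed both branches. The new form checks each entry independently; a small sketch of the difference (variable names hypothetical):

outshp = (None, 28)
imshp_spatial = (28, 28)

old_ok = (all(s is None for s in outshp) or
          all(o == i for o, i in zip(outshp, imshp_spatial)))
new_ok = all(o is None or o == i
             for o, i in zip(outshp, imshp_spatial))

assert not old_ok   # mixed known/unknown shape used to trip the assert
assert new_ok       # now accepted; known entries still must match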
@@ -1035,8 +1026,9 @@ using namespace std;
         d = locals()
         d.update(sub)
-        all_shape = self.has_all_shape(self.imshp, self.kshp,
-                                       self.nkern, self.bsize)
+        all_shape = (self.has_all_shape(self.imshp, self.kshp,
+                                        self.nkern, self.bsize) and
+                     self.has_all_shape(self.imshp_logical, self.kshp_logical))
         d["self_out_mode"] = self.out_mode
         d["self_dx"] = self.dx