提交 fb403768 authored 作者: Frederic's avatar Frederic

Do as the doc say, don't crash when there is None in the shape info elements.

上级 97d3b7a9
...@@ -241,6 +241,16 @@ class ConvOp(OpenMPOp): ...@@ -241,6 +241,16 @@ class ConvOp(OpenMPOp):
#valid time, full time #valid time, full time
speed_unroll_patch_shape = [1.2967290878295898, 5.5283889770507812] speed_unroll_patch_shape = [1.2967290878295898, 5.5283889770507812]
@staticmethod
def has_all_shape(imshp, kshp, nkern, bsize):
all_shape = (imshp is not None and kshp is not None and
nkern is not None and bsize is not None)
if (all_shape and
(any([True for sh in imshp if sh is None]) or
any([True for sh in kshp if sh is None]))):
all_shape = False
return all_shape
@staticmethod @staticmethod
def getOutputShape(inshp, kshp, stride=(1, 1), mode='valid'): def getOutputShape(inshp, kshp, stride=(1, 1), mode='valid'):
""" """
...@@ -371,8 +381,7 @@ class ConvOp(OpenMPOp): ...@@ -371,8 +381,7 @@ class ConvOp(OpenMPOp):
raise TypeError('ConvOp.__init__ param dy must be an int', dy) raise TypeError('ConvOp.__init__ param dy must be an int', dy)
dy = int(dy) dy = int(dy)
all_shape = imshp is not None and kshp is not None and \ all_shape = self.has_all_shape(imshp, kshp, nkern, bsize)
nkern is not None and bsize is not None
if (unroll_batch or unroll_kern) and not all_shape: if (unroll_batch or unroll_kern) and not all_shape:
raise Exception("In ConvOp, when using unroll_batch and" raise Exception("In ConvOp, when using unroll_batch and"
...@@ -484,7 +493,7 @@ class ConvOp(OpenMPOp): ...@@ -484,7 +493,7 @@ class ConvOp(OpenMPOp):
if not self.out_mode in ["valid", "full"]: if not self.out_mode in ["valid", "full"]:
raise Exception("Mode %s not implemented" % self.out_mode) raise Exception("Mode %s not implemented" % self.out_mode)
if all_shape and not (self.outshp > 0).all(): if self.outshp is not None and not (self.outshp > 0).all():
raise Exception("Bad size for the output shape. Verify that [post-" raise Exception("Bad size for the output shape. Verify that [post-"
"supersampling] input shape (%s) and kern" "supersampling] input shape (%s) and kern"
" shape(%s) are ok. (Hint: kerns must fit inside" " shape(%s) are ok. (Hint: kerns must fit inside"
...@@ -807,8 +816,8 @@ class ConvOp(OpenMPOp): ...@@ -807,8 +816,8 @@ class ConvOp(OpenMPOp):
"ERROR: We disable ConvOp.grad now when dx or " "ERROR: We disable ConvOp.grad now when dx or "
"dy are different from 1 and 2, as there is a bug in it.") "dy are different from 1 and 2, as there is a bug in it.")
all_shape = (self.imshp is not None and self.kshp is not None and all_shape = self.has_all_shape(self.imshp, self.kshp,
self.nkern is not None and self.bsize is not None) self.nkern, self.bsize)
if not all_shape and (self.dx != 1 or self.dy != 1): if not all_shape and (self.dx != 1 or self.dy != 1):
raise Exception("ConvOp.grad when dx!=1 or dy!=1 we must have all " raise Exception("ConvOp.grad when dx!=1 or dy!=1 we must have all "
...@@ -1031,8 +1040,8 @@ using namespace std; ...@@ -1031,8 +1040,8 @@ using namespace std;
d = locals() d = locals()
d.update(sub) d.update(sub)
all_shape = (self.imshp is not None and self.kshp is not None and all_shape = self.has_all_shape(self.imshp, self.kshp,
self.nkern is not None and self.bsize is not None) self.nkern, self.bsize)
d["self_out_mode"] = self.out_mode d["self_out_mode"] = self.out_mode
d["self_dx"] = self.dx d["self_dx"] = self.dx
......
...@@ -359,6 +359,12 @@ class TestConv2D(utt.InferShapeTester): ...@@ -359,6 +359,12 @@ class TestConv2D(utt.InferShapeTester):
self.validate((None, 2, None, None), (None, 2, 5, 5), self.validate((None, 2, None, None), (None, 2, 5, 5),
N_image_shape=(3, 2, 8, 8), N_image_shape=(3, 2, 8, 8),
N_filter_shape=(4, 2, 5, 5)) N_filter_shape=(4, 2, 5, 5))
self.validate((3, 2, 8, 8), (4, 2, None, 5),
N_image_shape=(3, 2, 8, 8),
N_filter_shape=(4, 2, 5, 5))
self.validate((3, 2, 8, 8), (4, 2, 5, None),
N_image_shape=(3, 2, 8, 8),
N_filter_shape=(4, 2, 5, 5))
def test_wrong_info(self): def test_wrong_info(self):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论