提交 039f7b5c authored 作者: abergeron's avatar abergeron

Merge pull request #2007 from nouiz/conv2d_none

[CRASH] Conv2d none
...@@ -241,6 +241,16 @@ class ConvOp(OpenMPOp): ...@@ -241,6 +241,16 @@ class ConvOp(OpenMPOp):
#valid time, full time #valid time, full time
speed_unroll_patch_shape = [1.2967290878295898, 5.5283889770507812] speed_unroll_patch_shape = [1.2967290878295898, 5.5283889770507812]
@staticmethod
def has_all_shape(imshp, kshp, nkern, bsize):
all_shape = (imshp is not None and kshp is not None and
nkern is not None and bsize is not None)
if (all_shape and
(any([True for sh in imshp if sh is None]) or
any([True for sh in kshp if sh is None]))):
all_shape = False
return all_shape
@staticmethod @staticmethod
def getOutputShape(inshp, kshp, stride=(1, 1), mode='valid'): def getOutputShape(inshp, kshp, stride=(1, 1), mode='valid'):
""" """
...@@ -371,8 +381,7 @@ class ConvOp(OpenMPOp): ...@@ -371,8 +381,7 @@ class ConvOp(OpenMPOp):
raise TypeError('ConvOp.__init__ param dy must be an int', dy) raise TypeError('ConvOp.__init__ param dy must be an int', dy)
dy = int(dy) dy = int(dy)
all_shape = imshp is not None and kshp is not None and \ all_shape = self.has_all_shape(imshp, kshp, nkern, bsize)
nkern is not None and bsize is not None
if (unroll_batch or unroll_kern) and not all_shape: if (unroll_batch or unroll_kern) and not all_shape:
raise Exception("In ConvOp, when using unroll_batch and" raise Exception("In ConvOp, when using unroll_batch and"
...@@ -484,7 +493,7 @@ class ConvOp(OpenMPOp): ...@@ -484,7 +493,7 @@ class ConvOp(OpenMPOp):
if not self.out_mode in ["valid", "full"]: if not self.out_mode in ["valid", "full"]:
raise Exception("Mode %s not implemented" % self.out_mode) raise Exception("Mode %s not implemented" % self.out_mode)
if all_shape and not (self.outshp > 0).all(): if self.outshp is not None and not (self.outshp > 0).all():
raise Exception("Bad size for the output shape. Verify that [post-" raise Exception("Bad size for the output shape. Verify that [post-"
"supersampling] input shape (%s) and kern" "supersampling] input shape (%s) and kern"
" shape(%s) are ok. (Hint: kerns must fit inside" " shape(%s) are ok. (Hint: kerns must fit inside"
...@@ -807,8 +816,8 @@ class ConvOp(OpenMPOp): ...@@ -807,8 +816,8 @@ class ConvOp(OpenMPOp):
"ERROR: We disable ConvOp.grad now when dx or " "ERROR: We disable ConvOp.grad now when dx or "
"dy are different from 1 and 2, as there is a bug in it.") "dy are different from 1 and 2, as there is a bug in it.")
all_shape = (self.imshp is not None and self.kshp is not None and all_shape = self.has_all_shape(self.imshp, self.kshp,
self.nkern is not None and self.bsize is not None) self.nkern, self.bsize)
if not all_shape and (self.dx != 1 or self.dy != 1): if not all_shape and (self.dx != 1 or self.dy != 1):
raise Exception("ConvOp.grad when dx!=1 or dy!=1 we must have all " raise Exception("ConvOp.grad when dx!=1 or dy!=1 we must have all "
...@@ -956,7 +965,7 @@ class ConvOp(OpenMPOp): ...@@ -956,7 +965,7 @@ class ConvOp(OpenMPOp):
return ['<numpy/noprefix.h>', '<iostream>', '<sstream>'] return ['<numpy/noprefix.h>', '<iostream>', '<sstream>']
def c_code_cache_version(self): def c_code_cache_version(self):
return (10, self.openmp, blas.blas_header_version()) return (11, self.openmp, blas.blas_header_version())
def c_support_code(self): def c_support_code(self):
return """ return """
...@@ -1031,24 +1040,48 @@ using namespace std; ...@@ -1031,24 +1040,48 @@ using namespace std;
d = locals() d = locals()
d.update(sub) d.update(sub)
all_shape = (self.imshp is not None and self.kshp is not None and all_shape = self.has_all_shape(self.imshp, self.kshp,
self.nkern is not None and self.bsize is not None) self.nkern, self.bsize)
d["self_out_mode"] = self.out_mode d["self_out_mode"] = self.out_mode
d["self_dx"] = self.dx d["self_dx"] = self.dx
d["self_dy"] = self.dy d["self_dy"] = self.dy
d["mode"] = self.out_mode.upper() d["mode"] = self.out_mode.upper()
d["affectation"] = "=" d["affectation"] = "="
if all_shape:
d["self_bsize"] = self.bsize # Default values, will be overrided if the shape info is provided
d["self_nkern"] = self.nkern d["self_bsize"] = "PyArray_DIMS(%(img2d)s)[0]" % d
d["self_nkern"] = "PyArray_DIMS(%(filtersflipped)s)[0]" % d
d["self_outshp0"] = "-1"
d["self_outshp1"] = "-1"
d["self_imshp0"] = "PyArray_DIMS(%(img2d)s)[1]" % d
d["self_imshp1"] = "PyArray_DIMS(%(img2d)s)[2]" % d
d["self_imshp2"] = "PyArray_DIMS(%(img2d)s)[3]" % d
d["self_kshp0"] = "PyArray_DIMS(%(filtersflipped)s)[2]" % d
d["self_kshp1"] = "PyArray_DIMS(%(filtersflipped)s)[3]" % d
# Override the default value if we have it
if self.kshp is not None and self.kshp[0]:
d["self_kshp0"] = self.kshp[0]
if self.kshp is not None and self.kshp[1]:
d["self_kshp1"] = self.kshp[1]
if self.outshp is not None and self.outshp[0]:
d["self_outshp0"] = self.outshp[0] d["self_outshp0"] = self.outshp[0]
if self.outshp is not None and self.outshp[1]:
d["self_outshp1"] = self.outshp[1] d["self_outshp1"] = self.outshp[1]
if self.imshp is not None and self.imshp[0]:
d["self_imshp0"] = self.imshp[0] d["self_imshp0"] = self.imshp[0]
if self.imshp is not None and self.imshp[1]:
d["self_imshp1"] = self.imshp[1] d["self_imshp1"] = self.imshp[1]
if self.imshp is not None and self.imshp[2]:
d["self_imshp2"] = self.imshp[2] d["self_imshp2"] = self.imshp[2]
d["self_kshp0"] = self.kshp[0] if self.bsize:
d["self_kshp1"] = self.kshp[1] d["self_bsize"] = self.bsize
if self.nkern:
d["self_nkern"] = self.nkern
# Other hard coded stuff only if we have all shapes
if all_shape:
d["self_kshp_logical_r"] = self.kshp_logical[0] d["self_kshp_logical_r"] = self.kshp_logical[0]
d["self_kshp_logical_c"] = self.kshp_logical[1] d["self_kshp_logical_c"] = self.kshp_logical[1]
d["self_kshp_logical_stride_r"] = int(numpy.ceil( d["self_kshp_logical_stride_r"] = int(numpy.ceil(
...@@ -1149,15 +1182,6 @@ if(kerns_dim[3] %% %(self_kshp1)s!=0){ ...@@ -1149,15 +1182,6 @@ if(kerns_dim[3] %% %(self_kshp1)s!=0){
""" % (locals()) """ % (locals())
else: else:
d["self_bsize"] = "PyArray_DIMS(%(img2d)s)[0]" % d
d["self_nkern"] = "PyArray_DIMS(%(filtersflipped)s)[0]" % d
d["self_outshp0"] = "-1"
d["self_outshp1"] = "-1"
d["self_imshp0"] = "PyArray_DIMS(%(img2d)s)[1]" % d
d["self_imshp1"] = "PyArray_DIMS(%(img2d)s)[2]" % d
d["self_imshp2"] = "PyArray_DIMS(%(img2d)s)[3]" % d
d["self_kshp0"] = "PyArray_DIMS(%(filtersflipped)s)[2]" % d
d["self_kshp1"] = "PyArray_DIMS(%(filtersflipped)s)[3]" % d
d["affectation"] = "+=" d["affectation"] = "+="
d["all_shape"] = "0" d["all_shape"] = "0"
d["dim_zz_const"] = "" d["dim_zz_const"] = ""
......
...@@ -359,6 +359,12 @@ class TestConv2D(utt.InferShapeTester): ...@@ -359,6 +359,12 @@ class TestConv2D(utt.InferShapeTester):
self.validate((None, 2, None, None), (None, 2, 5, 5), self.validate((None, 2, None, None), (None, 2, 5, 5),
N_image_shape=(3, 2, 8, 8), N_image_shape=(3, 2, 8, 8),
N_filter_shape=(4, 2, 5, 5)) N_filter_shape=(4, 2, 5, 5))
self.validate((3, 2, 8, 8), (4, 2, None, 5),
N_image_shape=(3, 2, 8, 8),
N_filter_shape=(4, 2, 5, 5))
self.validate((3, 2, 8, 8), (4, 2, 5, None),
N_image_shape=(3, 2, 8, 8),
N_filter_shape=(4, 2, 5, 5))
def test_wrong_info(self): def test_wrong_info(self):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论