提交 1563ea38 作者: Frédéric Bastien

Merge pull request #2103 from f0k/convop-grad-shapes

Propagate shape information to ConvOp gradients
...@@ -1699,6 +1699,13 @@ def local_gpualloc_memset_0(node): ...@@ -1699,6 +1699,13 @@ def local_gpualloc_memset_0(node):
inp.data.size == 1 and inp.data.size == 1 and
(numpy.asarray(inp.data) == 0).all()): (numpy.asarray(inp.data) == 0).all()):
new_out = GpuAlloc(memset_0=True)(*node.inputs) new_out = GpuAlloc(memset_0=True)(*node.inputs)
old_bcast = node.outputs[0].type.broadcastable
if new_out.type.broadcastable != old_bcast:
# check that we did not try discarding a broadcastable dimension
assert not any(b_old and not b_new for b_old, b_new in zip(
old_bcast, new_out.type.broadcastable))
# force old broadcasting pattern; we must not change it here
new_out = tensor.patternbroadcast(new_out, old_bcast)
return [new_out] return [new_out]
......
...@@ -242,34 +242,30 @@ class ConvOp(OpenMPOp): ...@@ -242,34 +242,30 @@ class ConvOp(OpenMPOp):
speed_unroll_patch_shape = [1.2967290878295898, 5.5283889770507812] speed_unroll_patch_shape = [1.2967290878295898, 5.5283889770507812]
@staticmethod @staticmethod
def has_all_shape(imshp, kshp, nkern, bsize): def has_all_shape(imshp, kshp, nkern=1, bsize=1):
all_shape = (imshp is not None and kshp is not None and return (nkern is not None and bsize is not None and
nkern is not None and bsize is not None) all(shp is not None for shp in imshp) and
if (all_shape and all(shp is not None for shp in kshp))
(any([True for sh in imshp if sh is None]) or
any([True for sh in kshp if sh is None]))):
all_shape = False
return all_shape
@staticmethod @staticmethod
def getOutputShape(inshp, kshp, stride=(1, 1), mode='valid'): def getOutputShape(inshp, kshp, stride=(1, 1), mode='valid'):
""" """
Computes the output dimensions of convolving an image of shape "inshp" Computes the output dimensions of convolving an image of shape "inshp"
with kernels of shape "kshp". with kernels of shape "kshp". Accepts symbolic or integer shapes.
Propagates `None`s (for unknown shapes).
:param inshp: (rows,cols) of input image :param inshp: (rows,cols) of input image
:param kshp: (rows,cols) of filters :param kshp: (rows,cols) of filters
:param mode: 'valid' or 'full' (see 'border_mode' in conv2d's doc) :param mode: 'valid' or 'full' (see 'border_mode' in conv2d's doc)
:return: (rows,cols) of output image :return: (rows,cols) of output image
""" """
dx, dy = stride # The formula would be ceil((i + s * k - s * 1) / float(d)),
if mode == 'valid': # with s=1 for mode=='full' and s=-1 for mode=='valid'.
s = -1 # To support symbolic shapes, we express this with integer arithmetics.
else: return tuple(None if i is None or k is None
s = 1 else ((i - k) // d + 1) if mode == 'valid'
inshp, kshp = numpy.array(inshp), numpy.array(kshp) else ((i + k + d - 2) // d)
return numpy.int64(numpy.ceil((inshp + s * kshp - s * 1) / for i, k, d in zip(inshp, kshp, stride))
numpy.array([dx, dy], dtype='float')))
def __init__(self, imshp=None, kshp=None, nkern=None, bsize=None, def __init__(self, imshp=None, kshp=None, nkern=None, bsize=None,
dx=1, dy=1, dx=1, dy=1,
...@@ -287,7 +283,7 @@ class ConvOp(OpenMPOp): ...@@ -287,7 +283,7 @@ class ConvOp(OpenMPOp):
""" """
Initializes a ConvOp with given output_mode (full/valid). All other Initializes a ConvOp with given output_mode (full/valid). All other
parameters are optional and are only used to generate more optimized c parameters are optional and are only used to generate more optimized c
code. code, or to enable graph optimizers to optimally replace the ConvOp.
NOTES ON OPTIMIZATION: NOTES ON OPTIMIZATION:
Their is two type of optimization. The first is the selection of the Their is two type of optimization. The first is the selection of the
...@@ -368,13 +364,31 @@ class ConvOp(OpenMPOp): ...@@ -368,13 +364,31 @@ class ConvOp(OpenMPOp):
Set to False in the grad again the weight when the Set to False in the grad again the weight when the
output_mode is full. output_mode is full.
""" """
# Desactivate fft_optimization at the op level if specified # Deactivate fft_optimization at the op level if specified
if version == "no_fft": if version == "no_fft":
self.fft_opt = False self.fft_opt = False
version = -1 version = -1
else: else:
self.fft_opt = True self.fft_opt = True
# Expand unknown image / kernel shapes into tuples of Nones
if imshp is None:
imshp = (None, None, None)
else:
imshp = tuple(imshp)
if kshp is None:
kshp = (None, None)
else:
kshp = tuple(kshp)
# Check imshp and kshp dimensionality
if len(imshp) == 2:
imshp = (1,) + imshp
elif len(imshp) != 3:
raise ValueError("len(imshp) must be 2 or 3, got %d" % len(imshp))
if len(kshp) != 2:
raise ValueError("len(kshp) must be 2, got %d" % len(kshp))
# We must continue to consider None as 1 for backward compatibility. # We must continue to consider None as 1 for backward compatibility.
if dx is None: if dx is None:
dx = 1 dx = 1
...@@ -390,32 +404,17 @@ class ConvOp(OpenMPOp): ...@@ -390,32 +404,17 @@ class ConvOp(OpenMPOp):
dy = int(dy) dy = int(dy)
all_shape = self.has_all_shape(imshp, kshp, nkern, bsize) all_shape = self.has_all_shape(imshp, kshp, nkern, bsize)
if (unroll_batch or unroll_kern) and not all_shape: if (unroll_batch or unroll_kern) and not all_shape:
raise Exception("In ConvOp, when using unroll_batch and" raise Exception("In ConvOp, when using unroll_batch and"
" unroll_nkern, all shape are needed") " unroll_nkern, all shape are needed")
#Init the openmp attribute #Init the openmp attribute
super(ConvOp, self).__init__(openmp=openmp) super(ConvOp, self).__init__(openmp=openmp)
if not all_shape or self.openmp: if not all_shape or self.openmp:
# Only this version is parallelized # Only this version is parallelized
unroll_patch = True unroll_patch = True
if imshp is not None:
imshp = tuple(imshp)
if len(imshp) == 2:
imshp = (1,) + imshp
elif len(imshp) == 3:
imshp = imshp
else:
raise Exception("bad len for imshp")
self.imshp = imshp self.imshp = imshp
if kshp is not None:
kshp = tuple(kshp)
self.kshp = kshp self.kshp = kshp
self.nkern = nkern self.nkern = nkern
self.bsize = bsize self.bsize = bsize
...@@ -425,16 +424,24 @@ class ConvOp(OpenMPOp): ...@@ -425,16 +424,24 @@ class ConvOp(OpenMPOp):
self.version = version self.version = version
# a triple # a triple
self.imshp_logical = self.imshp if imshp_logical is None:
if imshp_logical is not None: self.imshp_logical = self.imshp
self.imshp_logical = tuple(imshp_logical) else:
assert ((self.imshp is None and self.imshp_logical is None) or imshp_logical = tuple(imshp_logical)
(len(self.imshp) == len(self.imshp_logical))) if len(imshp_logical) != 3:
raise ValueError("len(imshp_logical) must be 3, got %d" % len(imshp_logical))
self.imshp_logical = imshp_logical
# a pair # a pair
self.kshp_logical = self.kshp if kshp_logical is None:
if kshp_logical is not None: self.kshp_logical = self.kshp
self.kshp_logical = tuple(kshp_logical) else:
kshp_logical = tuple(kshp_logical)
if len(kshp_logical) != 2:
raise ValueError("len(kshp_logical) must be 2, got %d" % len(kshp_logical))
self.kshp_logical = kshp_logical
# a bool
self.kshp_logical_top_aligned = kshp_logical_top_aligned self.kshp_logical_top_aligned = kshp_logical_top_aligned
self.unroll_batch = unroll_batch self.unroll_batch = unroll_batch
...@@ -485,23 +492,19 @@ class ConvOp(OpenMPOp): ...@@ -485,23 +492,19 @@ class ConvOp(OpenMPOp):
_logger.warn(warnstr, self.unroll_kern, self.nkern, new) _logger.warn(warnstr, self.unroll_kern, self.nkern, new)
self.unroll_kern = new self.unroll_kern = new
if all_shape: self.outshp = ConvOp.getOutputShape(self.imshp_logical[1:],
self.outshp = ConvOp.getOutputShape(self.imshp_logical[1:], self.kshp_logical, (dx, dy),
self.kshp_logical, (dx, dy), output_mode)
self.fulloutshp = ConvOp.getOutputShape(self.imshp_logical[1:],
self.kshp_logical, (1, 1),
output_mode) output_mode)
self.fulloutshp = ConvOp.getOutputShape(self.imshp_logical[1:],
self.kshp_logical, (1, 1),
output_mode)
else:
self.outshp = None
self.fulloutshp = None
self.out_mode = output_mode self.out_mode = output_mode
if not self.out_mode in ["valid", "full"]: if not self.out_mode in ["valid", "full"]:
raise Exception("Mode %s not implemented" % self.out_mode) raise Exception("Mode %s not implemented" % self.out_mode)
if self.outshp is not None and not (self.outshp > 0).all(): if any((shp is not None) and (shp <= 0) for shp in self.outshp):
raise Exception("Bad size for the output shape. Verify that [post-" raise Exception("Bad size for the output shape. Verify that [post-"
"supersampling] input shape (%s) and kern" "supersampling] input shape (%s) and kern"
" shape(%s) are ok. (Hint: kerns must fit inside" " shape(%s) are ok. (Hint: kerns must fit inside"
...@@ -518,14 +521,10 @@ class ConvOp(OpenMPOp): ...@@ -518,14 +521,10 @@ class ConvOp(OpenMPOp):
elif self.bsize is not None and self.nkern is not None: elif self.bsize is not None and self.nkern is not None:
bsize = self.bsize bsize = self.bsize
nkern = self.nkern nkern = self.nkern
if bsize is None:
bsize = 1
if nkern is None:
nkern = 1
mode_idx = 0 mode_idx = 0
if self.out_mode != "valid": if self.out_mode != "valid":
mode_idx = 1 mode_idx = 1
if all_shape: if self.has_all_shape(self.imshp, self.kshp):
time_unroll_patch = self.speed_unroll_patch_shape[mode_idx] time_unroll_patch = self.speed_unroll_patch_shape[mode_idx]
else: else:
time_unroll_patch = self.speed_unroll_patch_noshape[ time_unroll_patch = self.speed_unroll_patch_noshape[
...@@ -619,10 +618,7 @@ class ConvOp(OpenMPOp): ...@@ -619,10 +618,7 @@ class ConvOp(OpenMPOp):
raise NotImplementedError( raise NotImplementedError(
"The image and the kernel must have the same type." "The image and the kernel must have the same type."
"inputs(%s), kerns(%s)" % (_inputs.dtype, _kerns.dtype)) "inputs(%s), kerns(%s)" % (_inputs.dtype, _kerns.dtype))
if self.outshp is not None: bcastable23 = [self.outshp[0] == 1, self.outshp[1] == 1]
bcastable23 = [self.outshp[0] == 1, self.outshp[1] == 1]
else:
bcastable23 = [False, False]
output = theano.tensor.tensor(dtype=_inputs.type.dtype, output = theano.tensor.tensor(dtype=_inputs.type.dtype,
broadcastable=[_inputs.broadcastable[0], broadcastable=[_inputs.broadcastable[0],
_kerns.broadcastable[0]] + _kerns.broadcastable[0]] +
...@@ -631,32 +627,25 @@ class ConvOp(OpenMPOp): ...@@ -631,32 +627,25 @@ class ConvOp(OpenMPOp):
return Apply(self, [_inputs, _kerns], [output]) return Apply(self, [_inputs, _kerns], [output])
def infer_shape(self, node, input_shapes): def infer_shape(self, node, input_shapes):
imshp = input_shapes[0] imshp = input_shapes[0] # 4D image shape
kshp = input_shapes[1] kshp = input_shapes[1] # 4D filter shape
bsize, imshp = imshp[0], list(imshp[1:])
batch_size = imshp[0] nkern, kshp = kshp[0], list(kshp[2:])
fmo = kshp[0] # replace symbolic shapes with known shapes
if self.bsize is not None:
if self.imshp is not None and self.kshp is not None: bsize = self.bsize
imshp = self.imshp for i in [0, 1, 2]:
kshp = self.kshp if self.imshp_logical[i] is not None:
if self.imshp_logical: imshp[i] = self.imshp_logical[i]
imshp = self.imshp_logical if self.nkern is not None:
if self.kshp_logical: nkern = self.nkern
kshp = self.kshp_logical for i in [0, 1]:
try: if self.kshp_logical[i] is not None:
fmshp = ConvOp.getOutputShape(imshp[1:], kshp[i] = self.kshp_logical[i]
kshp, (self.dx, self.dy), # infer output shape from what we have
self.out_mode) outshp = ConvOp.getOutputShape(imshp[1:], kshp, (self.dx, self.dy),
except TypeError: self.out_mode)
raise theano.tensor.ShapeError() return [(bsize, nkern) + outshp]
outshp = (batch_size, fmo) + tuple(fmshp)
return [outshp]
else:
# Haven't implemented this case. imshp and kshp may be symbollic
# and ConvOp.getOutputShape doesn't handle this. In this case
# we simply let the default function do its work.
raise theano.tensor.ShapeError()
def perform(self, node, inp, out): def perform(self, node, inp, out):
""" """
...@@ -674,10 +663,10 @@ class ConvOp(OpenMPOp): ...@@ -674,10 +663,10 @@ class ConvOp(OpenMPOp):
# TODO: move these back out to global scope when they no longer # TODO: move these back out to global scope when they no longer
# cause an atexit error # cause an atexit error
imshp = self.imshp imshp = self.imshp
if imshp is None or any([x is None for x in imshp]): if any(x is None for x in imshp):
imshp = tuple(img2d.shape[1:]) imshp = tuple(img2d.shape[1:])
kshp = self.kshp kshp = self.kshp
if kshp is None or any([x is None for x in kshp]): if any(x is None for x in kshp):
kshp = tuple(filtersflipped.shape[2:]) kshp = tuple(filtersflipped.shape[2:])
bsize = self.bsize bsize = self.bsize
if bsize is None: if bsize is None:
...@@ -687,24 +676,22 @@ class ConvOp(OpenMPOp): ...@@ -687,24 +676,22 @@ class ConvOp(OpenMPOp):
nkern = filtersflipped.shape[0] nkern = filtersflipped.shape[0]
imshp_logical = self.imshp_logical imshp_logical = self.imshp_logical
if imshp_logical is None: if imshp_logical[0] is None:
imshp_logical = imshp imshp_logical = (imshp[0],) + imshp_logical[1:]
if numpy.any([x is None for x in imshp_logical]): if imshp_logical[1] is None:
imshp_logical = tuple(img2d.shape[1:]) imshp_logical = (imshp_logical[0], imshp[1], imshp_logical[2])
if imshp_logical[2] is None:
imshp_logical = imshp_logical[:2] + (imshp[2],)
assert all(x is not None for x in imshp_logical)
kshp_logical = self.kshp_logical kshp_logical = self.kshp_logical
if kshp_logical is None: if kshp_logical[0] is None:
kshp_logical = kshp kshp_logical = (kshp[0], kshp_logical[1])
else: if kshp_logical[1] is None:
if kshp_logical[0] is None: kshp_logical = (kshp_logical[0], kshp[1])
kshp_logical = (kshp[0], kshp_logical[1]) assert all(x is not None for x in kshp_logical)
if kshp_logical[1] is None:
kshp_logical = (kshp_logical[0], kshp[1])
if numpy.any([x is None for x in kshp_logical]): if all(shp is not None for shp in self.fulloutshp):
kshp = tuple(filtersflipped.shape[2:])
if self.fulloutshp is not None:
fulloutshp = tuple(self.fulloutshp) fulloutshp = tuple(self.fulloutshp)
else: else:
fulloutshp = tuple(ConvOp.getOutputShape(imshp_logical[ fulloutshp = tuple(ConvOp.getOutputShape(imshp_logical[
...@@ -843,19 +830,14 @@ class ConvOp(OpenMPOp): ...@@ -843,19 +830,14 @@ class ConvOp(OpenMPOp):
newin = inputs.dimshuffle((1, 0, 2, 3)) newin = inputs.dimshuffle((1, 0, 2, 3))
newgz = gz.dimshuffle((1, 0, 2, 3)) newgz = gz.dimshuffle((1, 0, 2, 3))
(bsize, nkern) = None, None
imshp = None
kshp = None
un_p = self.unroll_patch un_p = self.unroll_patch
imshp_logical = None
if self.out_mode == 'valid': if self.out_mode == 'valid':
(img, filters) = (newin, newgz) (img, filters) = (newin, newgz)
kshp_logical = self.fulloutshp kshp_logical = self.fulloutshp
kshp_logical_top_aligned = False kshp_logical_top_aligned = False
if all_shape: imshp_logical = None
(bsize, nkern) = (self.imshp[0], self.nkern) (bsize, nkern) = (self.imshp[0], self.nkern)
imshp = (self.bsize, self.imshp[1], self.imshp[2]) imshp = (self.bsize, self.imshp[1], self.imshp[2])
kshp = self.outshp kshp = self.outshp
un_b = self.unroll_batch un_b = self.unroll_batch
un_k = self.unroll_kern un_k = self.unroll_kern
...@@ -863,13 +845,12 @@ class ConvOp(OpenMPOp): ...@@ -863,13 +845,12 @@ class ConvOp(OpenMPOp):
(img, filters) = (newgz, newin) (img, filters) = (newgz, newin)
kshp_logical = None kshp_logical = None
kshp_logical_top_aligned = True kshp_logical_top_aligned = True
if all_shape: imshp_logical = (self.bsize,
imshp_logical = (self.bsize, self.fulloutshp[0],
self.fulloutshp[0], self.fulloutshp[1])
self.fulloutshp[1]) (bsize, nkern) = (self.nkern, self.imshp[0])
(bsize, nkern) = (self.nkern, self.imshp[0]) imshp = (self.bsize, self.outshp[0], self.outshp[1])
imshp = (self.bsize, self.outshp[0], self.outshp[1]) kshp = self.imshp[1:]
kshp = self.imshp[1:]
un_b = self.unroll_kern un_b = self.unroll_kern
un_k = self.unroll_batch un_k = self.unroll_batch
else: else:
...@@ -920,7 +901,7 @@ class ConvOp(OpenMPOp): ...@@ -920,7 +901,7 @@ class ConvOp(OpenMPOp):
dw = dw(img, filters) dw = dw(img, filters)
if all_shape: if all_shape:
assert (dw.owner.op.outshp == self.kshp).all() assert all(o == k for o, k in zip(dw.owner.op.outshp, self.kshp))
if self.out_mode == 'valid': if self.out_mode == 'valid':
# before DimShuffle, dw is of shape visdim x nkern x kshp[0] x kshp[1] # before DimShuffle, dw is of shape visdim x nkern x kshp[0] x kshp[1]
dw = dw.dimshuffle((1, 0, 2, 3)) dw = dw.dimshuffle((1, 0, 2, 3))
...@@ -933,16 +914,11 @@ class ConvOp(OpenMPOp): ...@@ -933,16 +914,11 @@ class ConvOp(OpenMPOp):
filters = kerns.dimshuffle((1, 0, 2, 3)) filters = kerns.dimshuffle((1, 0, 2, 3))
filters = filters[:, :, ::-1, ::-1] filters = filters[:, :, ::-1, ::-1]
nkern = None
imshp = None
imshp_logical = None
kshp = None
if all_shape: nkern = self.imshp[0]
nkern = self.imshp[0] imshp = (self.nkern, self.outshp[0], self.outshp[1])
imshp = (self.nkern, self.outshp[0], self.outshp[1]) imshp_logical = (self.nkern, self.fulloutshp[0],
imshp_logical = (self.nkern, self.fulloutshp[0], self.fulloutshp[1])
self.fulloutshp[1])
if 0: # hard-code c generation parameters if 0: # hard-code c generation parameters
din = ConvOp(imshp, self.kshp, nkern, self.bsize, din = ConvOp(imshp, self.kshp, nkern, self.bsize,
...@@ -965,9 +941,8 @@ class ConvOp(OpenMPOp): ...@@ -965,9 +941,8 @@ class ConvOp(OpenMPOp):
din = din(gz, filters) din = din(gz, filters)
assert (din.owner.op.outshp is None and self.imshp is None) or \ assert all(o is None or o == i
(din.owner.op.outshp is None) or \ for o, i in zip(din.owner.op.outshp, self.imshp[1:]))
(din.owner.op.outshp == self.imshp[1:]).all()
# din and dw should have the same broadcasting pattern as the # din and dw should have the same broadcasting pattern as the
# parameters they are the gradient of (resp. inputs and kerns). # parameters they are the gradient of (resp. inputs and kerns).
...@@ -1054,8 +1029,9 @@ using namespace std; ...@@ -1054,8 +1029,9 @@ using namespace std;
d = locals() d = locals()
d.update(sub) d.update(sub)
all_shape = self.has_all_shape(self.imshp, self.kshp, all_shape = (self.has_all_shape(self.imshp, self.kshp,
self.nkern, self.bsize) self.nkern, self.bsize) and
self.has_all_shape(self.imshp_logical, self.kshp_logical))
d["self_out_mode"] = self.out_mode d["self_out_mode"] = self.out_mode
d["self_dx"] = self.dx d["self_dx"] = self.dx
...@@ -1075,23 +1051,23 @@ using namespace std; ...@@ -1075,23 +1051,23 @@ using namespace std;
d["self_kshp1"] = "PyArray_DIMS(%(filtersflipped)s)[3]" % d d["self_kshp1"] = "PyArray_DIMS(%(filtersflipped)s)[3]" % d
# Override the default value if we have it # Override the default value if we have it
if self.kshp is not None and self.kshp[0]: if self.kshp[0] is not None:
d["self_kshp0"] = self.kshp[0] d["self_kshp0"] = self.kshp[0]
if self.kshp is not None and self.kshp[1]: if self.kshp[1] is not None:
d["self_kshp1"] = self.kshp[1] d["self_kshp1"] = self.kshp[1]
if self.outshp is not None and self.outshp[0]: if self.outshp[0] is not None:
d["self_outshp0"] = self.outshp[0] d["self_outshp0"] = self.outshp[0]
if self.outshp is not None and self.outshp[1]: if self.outshp[1] is not None:
d["self_outshp1"] = self.outshp[1] d["self_outshp1"] = self.outshp[1]
if self.imshp is not None and self.imshp[0]: if self.imshp[0] is not None:
d["self_imshp0"] = self.imshp[0] d["self_imshp0"] = self.imshp[0]
if self.imshp is not None and self.imshp[1]: if self.imshp[1] is not None:
d["self_imshp1"] = self.imshp[1] d["self_imshp1"] = self.imshp[1]
if self.imshp is not None and self.imshp[2]: if self.imshp[2] is not None:
d["self_imshp2"] = self.imshp[2] d["self_imshp2"] = self.imshp[2]
if self.bsize: if self.bsize is not None:
d["self_bsize"] = self.bsize d["self_bsize"] = self.bsize
if self.nkern: if self.nkern is not None:
d["self_nkern"] = self.nkern d["self_nkern"] = self.nkern
# Other hard coded stuff only if we have all shapes # Other hard coded stuff only if we have all shapes
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论