fixed python implement of ConvOp (bug with mutable parametrs being stored in

self). Also did some housecleaning
上级 a78299b3
...@@ -19,13 +19,16 @@ class ConvOp(Op): ...@@ -19,13 +19,16 @@ class ConvOp(Op):
""" """
def __init__(self, imshp, kshp, nkern, bsize, dx, dy, output_mode='valid'): def __init__(self, imshp, kshp, nkern, bsize, dx, dy, output_mode='valid'):
imshp = tuple(imshp)
if len(imshp)==2: if len(imshp)==2:
self.imshp = (1,)+imshp self.imshp = (1,)+imshp
elif len(imshp)==3: elif len(imshp)==3:
self.imshp = imshp self.imshp = imshp
else: else:
raise Exception("bad len for imshp") raise Exception("bad len for imshp")
self.kshp = kshp
self.kshp = tuple(kshp)
self.nkern = nkern self.nkern = nkern
self.bsize=bsize self.bsize=bsize
self.dx=dx self.dx=dx
...@@ -33,15 +36,13 @@ class ConvOp(Op): ...@@ -33,15 +36,13 @@ class ConvOp(Op):
if self.dx!=1 or self.dy!=1: if self.dx!=1 or self.dy!=1:
print "Warning, dx!=1 or dy!=1 only supported in python mode!" print "Warning, dx!=1 or dy!=1 only supported in python mode!"
raise NotImplementedError() raise NotImplementedError()
self.outshp = getFilterOutShp(self.imshp, kshp, (dx,dy), output_mode)
self.out_mode = output_mode self.out_mode = output_mode
if not self.out_mode in ["valid", "full"]: if not self.out_mode in ["valid", "full"]:
raise Exception("Mode %s not implemented"%self.out_mode) raise Exception("Mode %s not implemented"%self.out_mode)
self.fulloutshp = N.array(self.imshp[1:]) - N.array(self.kshp) + 1 \
if self.out_mode=='valid'\
else N.array(self.imshp[1:]) + N.array(self.kshp) - 1
assert ((N.array(self.imshp[1:])-self.kshp)>=0).all() assert (self.outshp >= 0).all()
assert N.prod(self.fulloutshp)>0
# def __eq__(self, other): # def __eq__(self, other):
# raise Error("Not implemented") # raise Error("Not implemented")
...@@ -52,8 +53,6 @@ class ConvOp(Op): ...@@ -52,8 +53,6 @@ class ConvOp(Op):
def make_node(self, inputs, kerns): def make_node(self, inputs, kerns):
#all kernels must have the same shape! #all kernels must have the same shape!
#output_mode only valid and full are supported! #output_mode only valid and full are supported!
self.outshp = getFilterOutShp(self.imshp, self.kshp, (self.dx,self.dy), self.out_mode)
self.dtype = inputs.dtype self.dtype = inputs.dtype
assert kerns.dtype==self.dtype assert kerns.dtype==self.dtype
...@@ -68,24 +67,13 @@ class ConvOp(Op): ...@@ -68,24 +67,13 @@ class ConvOp(Op):
By default if len(img2d.shape)==3, we By default if len(img2d.shape)==3, we
""" """
if z[0] is None: if z[0] is None:
z[0] = N.zeros((self.bsize,)+(self.nkern,)+tuple(self.fulloutshp)) z[0] = N.zeros((self.bsize,)+(self.nkern,)+tuple(self.outshp))
zz=z[0] zz=z[0]
val = _valfrommode(self.out_mode) val = _valfrommode(self.out_mode)
bval = _bvalfromboundary('fill') bval = _bvalfromboundary('fill')
if len(img2d.shape)==2 and self.imshp[0]==1 and self.bsize==1:
img2d = img2d.reshape((1,1)+img2d.shape) img2d = img2d.reshape((self.bsize,)+ self.imshp)
elif len(img2d.shape)==3 and self.imshp[0]==1 and self.bsize!=1: filtersflipped = filtersflipped.reshape((self.nkern,self.imshp[0])+self.kshp)
img2d = img2d.reshape((img2d.shape[0],)+(1,)+img2d.shape[1:])
elif len(img2d.shape)==3:
img2d = img2d.reshape((1,)+(img2d.shape[0],)+img2d.shape[1:])
elif len(img2d.shape)==3 and self.imshp[0]==1 and self.bsize==1:
img2d = img2d.reshape((1,1)+img2d.shape[1:])
elif len(img2d.shape)!=4: raise Exception("bad img2d shape.")
if len(filtersflipped.shape)==3 and self.imshp[0]==1:
assert self.imshp[0]==1
filtersflipped = filtersflipped.reshape((filtersflipped.shape[0],)+(1,)+filtersflipped.shape[1:])
elif len(filtersflipped.shape)!=4: raise Exception("Bad filtersflipped shape")
for b in range(self.bsize): for b in range(self.bsize):
for n in range(self.nkern): for n in range(self.nkern):
...@@ -140,7 +128,6 @@ class ConvOp(Op): ...@@ -140,7 +128,6 @@ class ConvOp(Op):
filters = filters[:,:,::-1,::-1] filters = filters[:,:,::-1,::-1]
nkern = self.imshp[0] nkern = self.imshp[0]
imshp = N.hstack((self.nkern,self.outshp)) imshp = N.hstack((self.nkern,self.outshp))
din = ConvOp(imshp, self.kshp, nkern, self.bsize, din = ConvOp(imshp, self.kshp, nkern, self.bsize,
1,1, output_mode=mode)(gz,filters) 1,1, output_mode=mode)(gz,filters)
...@@ -421,6 +408,11 @@ def convolve2(kerns, kshp, nkern, images, imshp, bsize, step=(1,1), ...@@ -421,6 +408,11 @@ def convolve2(kerns, kshp, nkern, images, imshp, bsize, step=(1,1),
kernrshp = tensor.as_tensor([nkern, nvis_dim] + list(kshp)) kernrshp = tensor.as_tensor([nkern, nvis_dim] + list(kshp))
kerntensor = tensor.reshape(kerns, kernrshp) kerntensor = tensor.reshape(kerns, kernrshp)
print '*** convolve2 ***'
print 'imshp = ', imshp
print 'kshp = ', kshp
print 'nkern = ', nkern
print 'bsize = ', bsize
convop = ConvOp(imshp, kshp, nkern, bsize, 1, 1, output_mode=mode) convop = ConvOp(imshp, kshp, nkern, bsize, 1, 1, output_mode=mode)
convout = convop(imtensor, kerntensor) convout = convop(imtensor, kerntensor)
......
...@@ -308,12 +308,18 @@ class TestConvOp(unittest.TestCase): ...@@ -308,12 +308,18 @@ class TestConvOp(unittest.TestCase):
imgs = T.dmatrix('imgs') imgs = T.dmatrix('imgs')
kerns = T.dmatrix('kerns') kerns = T.dmatrix('kerns')
kshps = [(3,3)]
for mode in 'valid', 'full': for mode in 'valid', 'full':
# 'full' mode should support kernels bigger than the input
if mode == 'full':
kshps.append((12,12))
for imshp in (5,5),(2,5,5),(2,10,10): # (12,10), (3,12,11): for imshp in (5,5),(2,5,5),(2,10,10): # (12,10), (3,12,11):
visdim = 1 if len(imshp)!=3 else imshp[0] visdim = 1 if len(imshp)!=3 else imshp[0]
print 'visdim = ', visdim for kshp in kshps:
for kshp in (3,3),:# (6,7):
imgvals = N.random.random(N.hstack((bsize,imshp))) imgvals = N.random.random(N.hstack((bsize,imshp)))
print 'imgvals.shape = ', imgvals.shape
imgvals = imgvals.reshape(bsize,-1) imgvals = imgvals.reshape(bsize,-1)
if visdim == 1: if visdim == 1:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论