fixed python implement of ConvOp (bug with mutable parametrs being stored in

self). Also did some housecleaning
上级 a78299b3
......@@ -19,13 +19,16 @@ class ConvOp(Op):
"""
def __init__(self, imshp, kshp, nkern, bsize, dx, dy, output_mode='valid'):
imshp = tuple(imshp)
if len(imshp)==2:
self.imshp = (1,)+imshp
elif len(imshp)==3:
self.imshp = imshp
else:
raise Exception("bad len for imshp")
self.kshp = kshp
self.kshp = tuple(kshp)
self.nkern = nkern
self.bsize=bsize
self.dx=dx
......@@ -33,15 +36,13 @@ class ConvOp(Op):
if self.dx!=1 or self.dy!=1:
print "Warning, dx!=1 or dy!=1 only supported in python mode!"
raise NotImplementedError()
self.outshp = getFilterOutShp(self.imshp, kshp, (dx,dy), output_mode)
self.out_mode = output_mode
if not self.out_mode in ["valid", "full"]:
raise Exception("Mode %s not implemented"%self.out_mode)
self.fulloutshp = N.array(self.imshp[1:]) - N.array(self.kshp) + 1 \
if self.out_mode=='valid'\
else N.array(self.imshp[1:]) + N.array(self.kshp) - 1
assert ((N.array(self.imshp[1:])-self.kshp)>=0).all()
assert N.prod(self.fulloutshp)>0
assert (self.outshp >= 0).all()
# def __eq__(self, other):
# raise Error("Not implemented")
......@@ -52,8 +53,6 @@ class ConvOp(Op):
def make_node(self, inputs, kerns):
#all kernels must have the same shape!
#output_mode only valid and full are supported!
self.outshp = getFilterOutShp(self.imshp, self.kshp, (self.dx,self.dy), self.out_mode)
self.dtype = inputs.dtype
assert kerns.dtype==self.dtype
......@@ -68,25 +67,14 @@ class ConvOp(Op):
By default if len(img2d.shape)==3, we
"""
if z[0] is None:
z[0] = N.zeros((self.bsize,)+(self.nkern,)+tuple(self.fulloutshp))
z[0] = N.zeros((self.bsize,)+(self.nkern,)+tuple(self.outshp))
zz=z[0]
val = _valfrommode(self.out_mode)
bval = _bvalfromboundary('fill')
if len(img2d.shape)==2 and self.imshp[0]==1 and self.bsize==1:
img2d = img2d.reshape((1,1)+img2d.shape)
elif len(img2d.shape)==3 and self.imshp[0]==1 and self.bsize!=1:
img2d = img2d.reshape((img2d.shape[0],)+(1,)+img2d.shape[1:])
elif len(img2d.shape)==3:
img2d = img2d.reshape((1,)+(img2d.shape[0],)+img2d.shape[1:])
elif len(img2d.shape)==3 and self.imshp[0]==1 and self.bsize==1:
img2d = img2d.reshape((1,1)+img2d.shape[1:])
elif len(img2d.shape)!=4: raise Exception("bad img2d shape.")
if len(filtersflipped.shape)==3 and self.imshp[0]==1:
assert self.imshp[0]==1
filtersflipped = filtersflipped.reshape((filtersflipped.shape[0],)+(1,)+filtersflipped.shape[1:])
elif len(filtersflipped.shape)!=4: raise Exception("Bad filtersflipped shape")
img2d = img2d.reshape((self.bsize,)+ self.imshp)
filtersflipped = filtersflipped.reshape((self.nkern,self.imshp[0])+self.kshp)
for b in range(self.bsize):
for n in range(self.nkern):
zz[b,n,...].fill(0)
......@@ -140,7 +128,6 @@ class ConvOp(Op):
filters = filters[:,:,::-1,::-1]
nkern = self.imshp[0]
imshp = N.hstack((self.nkern,self.outshp))
din = ConvOp(imshp, self.kshp, nkern, self.bsize,
1,1, output_mode=mode)(gz,filters)
......@@ -420,7 +407,12 @@ def convolve2(kerns, kshp, nkern, images, imshp, bsize, step=(1,1),
kernrshp = tensor.as_tensor([nkern, nvis_dim] + list(kshp))
kerntensor = tensor.reshape(kerns, kernrshp)
print '*** convolve2 ***'
print 'imshp = ', imshp
print 'kshp = ', kshp
print 'nkern = ', nkern
print 'bsize = ', bsize
convop = ConvOp(imshp, kshp, nkern, bsize, 1, 1, output_mode=mode)
convout = convop(imtensor, kerntensor)
......
......@@ -307,13 +307,19 @@ class TestConvOp(unittest.TestCase):
bsize = 2
imgs = T.dmatrix('imgs')
kerns = T.dmatrix('kerns')
kshps = [(3,3)]
for mode in 'valid', 'full':
# 'full' mode should support kernels bigger than the input
if mode == 'full':
kshps.append((12,12))
for imshp in (5,5),(2,5,5),(2,10,10): # (12,10), (3,12,11):
visdim = 1 if len(imshp)!=3 else imshp[0]
print 'visdim = ', visdim
for kshp in (3,3),:# (6,7):
for kshp in kshps:
imgvals = N.random.random(N.hstack((bsize,imshp)))
print 'imgvals.shape = ', imgvals.shape
imgvals = imgvals.reshape(bsize,-1)
if visdim == 1:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论