Commit 67eba77c authored by Ian Goodfellow

pep8 ConvTransp3D

Parent b7a65bb5
@@ -5,9 +5,10 @@ import theano
 from theano.gradient import grad_undefined
 from theano.gradient import DisconnectedType
 class ConvTransp3D(theano.Op):
     """ "Transpose" of Conv3D (Conv3D implements multiplication by an implicitly defined matrix W. This implements multiplication by its transpose) """
-    def __eq__(self,other):
+    def __eq__(self, other):
         return type(self) == type(other)
     def __hash__(self):
@@ -16,7 +17,7 @@ class ConvTransp3D(theano.Op):
     def c_code_cache_version(self):
         return (3,)
-    def make_node(self, W, b, d, H, RShape = None):
+    def make_node(self, W, b, d, H, RShape=None):
         """
         :param W: Weights, filter
         :param b: bias, shape == (W.shape[0],)
@@ -30,7 +31,7 @@
         if RShape:
             RShape_ = T.as_tensor_variable(RShape)
         else:
-            RShape_ = T.as_tensor_variable([-1,-1,-1])
+            RShape_ = T.as_tensor_variable([-1, -1, -1])
         return theano.Apply(self, inputs=[W_,b_,d_,H_, RShape_], outputs = [ T.TensorType(H_.dtype, (False,False,False,False,False))() ] )
@@ -38,28 +39,26 @@
         flags = ['-Werror']
         return flags
     def infer_shape(self, node, input_shapes):
-        W,b,d,H,RShape = node.inputs
+        W, b, d, H, RShape = node.inputs
         W_shape, b_shape, d_shape, H_shape, RShape_shape = input_shapes
         return [(H_shape[0], RShape[0], RShape[1], RShape[2], W_shape[4])]
     def connection_pattern(self, node):
         return [[True], [True], [True], [True], [False]]
-    def grad(self,inputs, output_gradients):
-        W,b,d,H, RShape = inputs
-        dCdR ,= output_gradients
-        dCdH = conv3D( dCdR, W, T.zeros_like(H[0,0,0,0,:]), d)
+    def grad(self, inputs, output_gradients):
+        W, b, d, H, RShape = inputs
+        dCdR, = output_gradients
+        dCdH = conv3D(dCdR, W, T.zeros_like(H[0, 0, 0, 0, :]), d)
         WShape = W.shape
-        dCdW = convGrad3D(dCdR,d,WShape,H)
-        dCdb = T.sum(dCdR,axis=(0,1,2,3))
+        dCdW = convGrad3D(dCdR, d, WShape, H)
+        dCdb = T.sum(dCdR, axis=(0, 1, 2, 3))
         # not differentiable, since d affects the output elements
-        dCdd = grad_undefined(self,2,d)
+        dCdd = grad_undefined(self, 2, d)
         # disconnected, since RShape just determines the output shape
         dCdRShape = DisconnectedType()()
         if 'name' in dir(dCdR) and dCdR.name is not None:
             dCdR_name = dCdR.name
         else:
@@ -83,15 +82,14 @@ class ConvTransp3D(theano.Op):
         dCdW.name = 'ConvTransp3D_dCdW.H='+H_name+',dCdR='+dCdR_name+',W='+W_name
         dCdb.name = 'ConvTransp3D_dCdb.H='+H_name+',dCdR='+dCdR_name+',W='+W_name+',b='+b_name
-        dCdH.name = 'ConvTransp3D_dCdH.H='+H_name+',dCdR='+dCdR_name
-        return [ dCdW, dCdb, dCdd, dCdH, dCdRShape ]
+        dCdH.name = 'ConvTransp3D_dCdH.H=' + H_name + ',dCdR=' + dCdR_name
+        return [dCdW, dCdb, dCdd, dCdH, dCdRShape]
     def perform(self, node, inputs, output_storage):
         W, b, d, H, RShape = inputs
         # print "\t\t\t\tConvTransp3D python code"
-        output_storage[0][0] = computeR(W,b,d,H,RShape)
+        output_storage[0][0] = computeR(W, b, d, H, RShape)
     def c_code(self, node, nodename, inputs, outputs, sub):
         W, b, d, H, RShape = inputs
@@ -328,33 +326,35 @@ class ConvTransp3D(theano.Op):
             ///////////// < /code generated by ConvTransp3D >
         """
-        return strutil.renderString(codeSource,locals())
+        return strutil.renderString(codeSource, locals())
 convTransp3D = ConvTransp3D()
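
The class docstring above describes Conv3D as multiplication by an implicitly defined matrix, with this Op multiplying by that matrix's transpose. A minimal 1-D NumPy sketch of the same relationship (illustrative only, not part of this commit; all sizes are made up):

    # 1-D analogy for the docstring's "implicit matrix" claim.
    import numpy as np

    inputLen, filterLen = 6, 3
    outputLen = inputLen - filterLen + 1      # valid convolution, stride 1

    x = np.random.randn(inputLen)             # input signal
    w = np.random.randn(filterLen)            # filter

    # M is the implicit matrix: row k holds w shifted to offset k.
    M = np.zeros((outputLen, inputLen))
    for k in range(outputLen):
        M[k, k:k + filterLen] = w

    h = M.dot(x)    # forward op (what Conv3D does, in 1-D)
    r = M.T.dot(h)  # transpose op (what ConvTransp3D does, in 1-D)
    assert r.shape == x.shape                 # back to input-sized space
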
 #If the input size wasn't a multiple of D we may need to cause some automatic padding to get the right size of reconstruction
-def computeR(W,b,d,H,Rshape = None):
+def computeR(W, b, d, H, Rshape=None):
     assert len(W.shape) == 5
     assert len(H.shape) == 5
     assert len(b.shape) == 1
     assert len(d) == 3
-    outputChannels, filterHeight, filterWidth, filterDur, inputChannels = W.shape
-    batchSize, outputHeight, outputWidth, outputDur, outputChannelsAgain = H.shape
+    outputChannels, filterHeight, filterWidth, filterDur, \
+        inputChannels = W.shape
+    batchSize, outputHeight, outputWidth, outputDur, \
+        outputChannelsAgain = H.shape
     assert outputChannelsAgain == outputChannels
     assert b.shape[0] == inputChannels
-    dr,dc,dt = d
+    dr, dc, dt = d
     assert dr > 0
     assert dc > 0
     assert dt > 0
-    videoHeight = (outputHeight-1) * dr + filterHeight
-    videoWidth = (outputWidth-1) * dc + filterWidth
-    videoDur = (outputDur-1) * dt + filterDur
+    videoHeight = (outputHeight - 1) * dr + filterHeight
+    videoWidth = (outputWidth - 1) * dc + filterWidth
+    videoDur = (outputDur - 1) * dt + filterDur
     if Rshape is not None and Rshape[0] != -1:
         if Rshape[0] < videoHeight:
@@ -371,24 +371,27 @@ def computeR(W,b,d,H,Rshape = None):
     #print "video size: "+str((videoHeight, videoWidth, videoDur))
-    R = N.zeros( (batchSize, videoHeight,
-                 videoWidth, videoDur, inputChannels ) , dtype=H.dtype)
+    R = N.zeros((batchSize, videoHeight,
+                 videoWidth, videoDur, inputChannels), dtype=H.dtype)
     #R[i,j,r,c,t] = b_j + sum_{rc,rk | d \circ rc + rk = r} sum_{cc,ck | ...} sum_{tc,tk | ...} sum_k W[k, j, rk, ck, tk] * H[i,k,rc,cc,tc]
-    for i in xrange(0,batchSize):
+    for i in xrange(0, batchSize):
         #print '\texample '+str(i+1)+'/'+str(batchSize)
-        for j in xrange(0,inputChannels):
+        for j in xrange(0, inputChannels):
             #print '\t\tfeature map '+str(j+1)+'/'+str(inputChannels)
-            for r in xrange(0,videoHeight):
+            for r in xrange(0, videoHeight):
                 #print '\t\t\trow '+str(r+1)+'/'+str(videoHeight)
-                for c in xrange(0,videoWidth):
-                    for t in xrange(0,videoDur):
-                        R[i,r,c,t,j] = b[j]
+                for c in xrange(0, videoWidth):
+                    for t in xrange(0, videoDur):
+                        R[i, r, c, t, j] = b[j]
-                        ftc = max([0, int(N.ceil(float(t-filterDur +1 )/float(dt))) ])
-                        fcc = max([0, int(N.ceil(float(c-filterWidth +1)/float(dc))) ])
+                        ftc = max([0, int(N.ceil(
+                            float(t - filterDur + 1) / float(dt)))])
+                        fcc = max([0, int(N.ceil(
+                            float(c - filterWidth + 1) / float(dc)))])
-                        rc = max([0, int(N.ceil(float(r-filterHeight+1)/float(dr))) ])
+                        rc = max([0, int(N.ceil(
+                            float(r - filterHeight + 1) / float(dr)))])
                         while rc < outputHeight:
                             rk = r - rc * dr
                             if rk < 0:
@@ -406,20 +409,21 @@ def computeR(W,b,d,H,Rshape = None):
                                     if tk < 0:
                                         break
-                                    R[i,r,c,t,j] += N.dot(W[:,rk,ck,tk,j], H[i,rc,cc,tc,:] )
+                                    R[i, r, c, t, j] += N.dot(
+                                        W[:, rk, ck, tk, j], H[i, rc, cc, tc, :])
                                     tc += 1
-                                "" #close loop over tc
+                                "" # close loop over tc
                                 cc += 1
-                            "" #close loop over cc
+                            "" # close loop over cc
                             rc += 1
-                        "" #close loop over rc
-                    "" #close loop over t
-                "" #close loop over c
-            "" #close loop over r
-        "" #close loop over j
-    "" #close loop over i
+                        "" # close loop over rc
+                    "" # close loop over t
+                "" # close loop over c
+            "" # close loop over r
+        "" # close loop over j
+    "" # close loop over i
     return R
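
As a quick sanity check of computeR's shape arithmetic (videoHeight = (outputHeight - 1) * dr + filterHeight, and likewise for width and duration), a minimal sketch with computeR from this file in scope. It assumes Python 2, since the implementation uses xrange, and NumPy imported as N as in the module; the sizes are made up:

    # Hypothetical smoke test for computeR's output shape.
    import numpy as N

    batchSize, outputChannels, inputChannels = 1, 2, 3
    filterHeight = filterWidth = filterDur = 2
    outputHeight = outputWidth = outputDur = 4
    d = (1, 1, 1)                              # strides (dr, dc, dt)

    W = N.random.randn(outputChannels, filterHeight, filterWidth,
                       filterDur, inputChannels)
    b = N.zeros(inputChannels)                 # matches computeR's assert
    H = N.random.randn(batchSize, outputHeight, outputWidth,
                       outputDur, outputChannels)

    R = computeR(W, b, d, H)
    # videoHeight = (4 - 1) * 1 + 2 = 5, likewise width and duration
    assert R.shape == (batchSize, 5, 5, 5, inputChannels)

Note that computeR asserts b.shape[0] == inputChannels, while make_node's docstring says b has shape (W.shape[0],), i.e. outputChannels; the sketch follows the assertion.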