Commit 67eba77c, authored by Ian Goodfellow

pep8 ConvTransp3D

Parent commit: b7a65bb5
...@@ -5,9 +5,10 @@ import theano ...@@ -5,9 +5,10 @@ import theano
from theano.gradient import grad_undefined from theano.gradient import grad_undefined
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
class ConvTransp3D(theano.Op): class ConvTransp3D(theano.Op):
""" "Transpose" of Conv3D (Conv3D implements multiplication by an implicitly defined matrix W. This implements multiplication by its transpose) """ """ "Transpose" of Conv3D (Conv3D implements multiplication by an implicitly defined matrix W. This implements multiplication by its transpose) """
def __eq__(self, other):
    # ConvTransp3D carries no per-instance parameters, so any two
    # instances of the exact same class are interchangeable.
    return type(self) == type(other)
def __hash__(self): def __hash__(self):
...@@ -16,7 +17,7 @@ class ConvTransp3D(theano.Op): ...@@ -16,7 +17,7 @@ class ConvTransp3D(theano.Op):
def c_code_cache_version(self):
    """Version tag for the generated C code; bump whenever codegen changes
    so Theano recompiles instead of reusing a stale cached module."""
    version = 3
    return (version,)
def make_node(self, W, b, d, H, RShape = None): def make_node(self, W, b, d, H, RShape=None):
""" """
:param W: Weights, filter :param W: Weights, filter
:param b: bias, shape == (W.shape[0],) :param b: bias, shape == (W.shape[0],)
...@@ -30,7 +31,7 @@ class ConvTransp3D(theano.Op): ...@@ -30,7 +31,7 @@ class ConvTransp3D(theano.Op):
if RShape: if RShape:
RShape_ = T.as_tensor_variable(RShape) RShape_ = T.as_tensor_variable(RShape)
else: else:
RShape_ = T.as_tensor_variable([-1,-1,-1]) RShape_ = T.as_tensor_variable([-1, -1, -1])
return theano.Apply(self, inputs=[W_,b_,d_,H_, RShape_], outputs = [ T.TensorType(H_.dtype, (False,False,False,False,False))() ] ) return theano.Apply(self, inputs=[W_,b_,d_,H_, RShape_], outputs = [ T.TensorType(H_.dtype, (False,False,False,False,False))() ] )
...@@ -38,28 +39,26 @@ class ConvTransp3D(theano.Op): ...@@ -38,28 +39,26 @@ class ConvTransp3D(theano.Op):
flags = ['-Werror'] flags = ['-Werror']
return flags return flags
def infer_shape(self, node, input_shapes):
    """Infer the output shape: (batch, rows, cols, dur, input channels).

    The three spatial sizes are taken from the symbolic RShape *input
    itself* (its values, not its shape); the batch size comes from H's
    shape and the channel count from W's shape.
    """
    rshape = node.inputs[4]
    w_shape = input_shapes[0]
    h_shape = input_shapes[3]
    return [(h_shape[0], rshape[0], rshape[1], rshape[2], w_shape[4])]
def connection_pattern(self, node):
    """Every input except RShape influences the output values;
    RShape (input position 4) only pins the output's shape."""
    RSHAPE_POS = 4
    return [[pos != RSHAPE_POS] for pos in range(5)]
def grad(self,inputs, output_gradients): def grad(self, inputs, output_gradients):
W,b,d,H, RShape = inputs W, b, d, H, RShape = inputs
dCdR ,= output_gradients dCdR, = output_gradients
dCdH = conv3D( dCdR, W, T.zeros_like(H[0,0,0,0,:]), d) dCdH = conv3D(dCdR, W, T.zeros_like(H[0, 0, 0, 0, :]), d)
WShape = W.shape WShape = W.shape
dCdW = convGrad3D(dCdR,d,WShape,H) dCdW = convGrad3D(dCdR, d, WShape, H)
dCdb = T.sum(dCdR,axis=(0,1,2,3)) dCdb = T.sum(dCdR, axis=(0, 1, 2, 3))
# not differentiable, since d affects the output elements # not differentiable, since d affects the output elements
dCdd = grad_undefined(self,2,d) dCdd = grad_undefined(self, 2, d)
# disconnected, since RShape just determines the output shape # disconnected, since RShape just determines the output shape
dCdRShape = DisconnectedType()() dCdRShape = DisconnectedType()()
if 'name' in dir(dCdR) and dCdR.name is not None: if 'name' in dir(dCdR) and dCdR.name is not None:
dCdR_name = dCdR.name dCdR_name = dCdR.name
else: else:
...@@ -83,15 +82,14 @@ class ConvTransp3D(theano.Op): ...@@ -83,15 +82,14 @@ class ConvTransp3D(theano.Op):
dCdW.name = 'ConvTransp3D_dCdW.H='+H_name+',dCdR='+dCdR_name+',W='+W_name dCdW.name = 'ConvTransp3D_dCdW.H='+H_name+',dCdR='+dCdR_name+',W='+W_name
dCdb.name = 'ConvTransp3D_dCdb.H='+H_name+',dCdR='+dCdR_name+',W='+W_name+',b='+b_name dCdb.name = 'ConvTransp3D_dCdb.H='+H_name+',dCdR='+dCdR_name+',W='+W_name+',b='+b_name
dCdH.name = 'ConvTransp3D_dCdH.H='+H_name+',dCdR='+dCdR_name dCdH.name = 'ConvTransp3D_dCdH.H=' + H_name + ',dCdR=' + dCdR_name
return [ dCdW, dCdb, dCdd, dCdH, dCdRShape ]
return [dCdW, dCdb, dCdd, dCdH, dCdRShape]
def perform(self, node, inputs, output_storage):
    """Python fallback implementation: delegate to computeR and place
    the reconstructed video in the output storage cell."""
    W, b, d, H, RShape = inputs
    result = computeR(W, b, d, H, RShape)
    output_storage[0][0] = result
def c_code(self, node, nodename, inputs, outputs, sub): def c_code(self, node, nodename, inputs, outputs, sub):
W, b, d, H, RShape = inputs W, b, d, H, RShape = inputs
...@@ -328,33 +326,35 @@ class ConvTransp3D(theano.Op): ...@@ -328,33 +326,35 @@ class ConvTransp3D(theano.Op):
///////////// < /code generated by ConvTransp3D > ///////////// < /code generated by ConvTransp3D >
""" """
return strutil.renderString(codeSource,locals()) return strutil.renderString(codeSource, locals())
convTransp3D = ConvTransp3D() convTransp3D = ConvTransp3D()
#If the input size wasn't a multiple of D we may need to cause some automatic padding to get the right size of reconstruction #If the input size wasn't a multiple of D we may need to cause some automatic padding to get the right size of reconstruction
def computeR(W,b,d,H,Rshape = None):
def computeR(W, b, d, H, Rshape=None):
assert len(W.shape) == 5 assert len(W.shape) == 5
assert len(H.shape) == 5 assert len(H.shape) == 5
assert len(b.shape) == 1 assert len(b.shape) == 1
assert len(d) == 3 assert len(d) == 3
outputChannels, filterHeight, filterWidth, filterDur,
outputChannels, filterHeight, filterWidth, filterDur, inputChannels = W.shape inputChannels = W.shape
batchSize, outputHeight, outputWidth, outputDur, outputChannelsAgain = H.shape batchSize, outputHeight, outputWidth, outputDur,
outputChannelsAgain = H.shape
assert outputChannelsAgain == outputChannels assert outputChannelsAgain == outputChannels
assert b.shape[0] == inputChannels assert b.shape[0] == inputChannels
dr, dc, dt = d
dr,dc,dt = d
assert dr > 0 assert dr > 0
assert dc > 0 assert dc > 0
assert dt > 0 assert dt > 0
videoHeight = (outputHeight-1) * dr + filterHeight videoHeight = (outputHeight - 1) * dr + filterHeight
videoWidth = (outputWidth-1) * dc + filterWidth videoWidth = (outputWidth - 1) * dc + filterWidth
videoDur = (outputDur-1) * dt + filterDur videoDur = (outputDur - 1) * dt + filterDur
if Rshape is not None and Rshape[0] != -1: if Rshape is not None and Rshape[0] != -1:
if Rshape[0] < videoHeight: if Rshape[0] < videoHeight:
...@@ -371,24 +371,27 @@ def computeR(W,b,d,H,Rshape = None): ...@@ -371,24 +371,27 @@ def computeR(W,b,d,H,Rshape = None):
#print "video size: "+str((videoHeight, videoWidth, videoDur)) #print "video size: "+str((videoHeight, videoWidth, videoDur))
R = N.zeros( (batchSize, videoHeight, R = N.zeros((batchSize, videoHeight,
videoWidth, videoDur, inputChannels ) , dtype=H.dtype) videoWidth, videoDur, inputChannels), dtype=H.dtype)
#R[i,j,r,c,t] = b_j + sum_{rc,rk | d \circ rc + rk = r} sum_{cc,ck | ...} sum_{tc,tk | ...} sum_k W[k, j, rk, ck, tk] * H[i,k,rc,cc,tc] #R[i,j,r,c,t] = b_j + sum_{rc,rk | d \circ rc + rk = r} sum_{cc,ck | ...} sum_{tc,tk | ...} sum_k W[k, j, rk, ck, tk] * H[i,k,rc,cc,tc]
for i in xrange(0,batchSize): for i in xrange(0, batchSize):
#print '\texample '+str(i+1)+'/'+str(batchSize) #print '\texample '+str(i+1)+'/'+str(batchSize)
for j in xrange(0,inputChannels): for j in xrange(0, inputChannels):
#print '\t\tfeature map '+str(j+1)+'/'+str(inputChannels) #print '\t\tfeature map '+str(j+1)+'/'+str(inputChannels)
for r in xrange(0,videoHeight): for r in xrange(0, videoHeight):
#print '\t\t\trow '+str(r+1)+'/'+str(videoHeight) #print '\t\t\trow '+str(r+1)+'/'+str(videoHeight)
for c in xrange(0,videoWidth): for c in xrange(0, videoWidth):
for t in xrange(0,videoDur): for t in xrange(0, videoDur):
R[i,r,c,t,j] = b[j] R[i, r, c, t, j] = b[j]
ftc = max([0, int(N.ceil(float(t-filterDur +1 )/float(dt))) ]) ftc = max([0, int(N.ceil(
fcc = max([0, int(N.ceil(float(c-filterWidth +1)/float(dc))) ]) float(t - filterDur + 1) / float(dt)))])
fcc = max([0, int(N.ceil(
float(c - filterWidth + 1) / float(dc)))])
rc = max([0, int(N.ceil(float(r-filterHeight+1)/float(dr))) ]) rc = max([0, int(N.ceil(
float(r - filterHeight + 1) / float(dr)))])
while rc < outputHeight: while rc < outputHeight:
rk = r - rc * dr rk = r - rc * dr
if rk < 0: if rk < 0:
...@@ -406,20 +409,21 @@ def computeR(W,b,d,H,Rshape = None): ...@@ -406,20 +409,21 @@ def computeR(W,b,d,H,Rshape = None):
if tk < 0: if tk < 0:
break break
R[i,r,c,t,j] += N.dot(W[:,rk,ck,tk,j], H[i,rc,cc,tc,:] ) R[
i,r,c,t,j] += N.dot(W[:,rk,ck,tk,j], H[i,rc,cc,tc,:] )
tc += 1 tc += 1
"" #close loop over tc "" # close loop over tc
cc += 1 cc += 1
"" #close loop over cc "" # close loop over cc
rc += 1 rc += 1
"" #close loop over rc "" # close loop over rc
"" #close loop over t "" # close loop over t
"" #close loop over c "" # close loop over c
"" #close loop over r "" # close loop over r
"" #close loop over j "" # close loop over j
"" #close loop over i "" # close loop over i
return R return R
......
Markdown 格式
0%
You added 0 people to this discussion. Please proceed with caution.
请先完成此评论的编辑!
Register or sign in to post a comment.