提交 2d169da1 authored 作者: Frederic's avatar Frederic

Fix circular import

上级 ebcb4441
...@@ -83,10 +83,11 @@ class Conv3D(theano.Op): ...@@ -83,10 +83,11 @@ class Conv3D(theano.Op):
# Make sure the broadcasting pattern of the gradient is the same # Make sure the broadcasting pattern of the gradient is the same
# as the initial variable # as the initial variable
dCdV = ConvTransp3D.convTransp3D(W, T.zeros_like(V[0, 0, 0, 0, :]), d, dCdH, V.shape[1:4]) dCdV = theano.tensor.nnet.convTransp3D(
W, T.zeros_like(V[0, 0, 0, 0, :]), d, dCdH, V.shape[1:4])
dCdV = T.patternbroadcast(dCdV, V.broadcastable) dCdV = T.patternbroadcast(dCdV, V.broadcastable)
WShape = W.shape WShape = W.shape
dCdW = ConvGrad3D.convGrad3D(V, d, WShape, dCdH) dCdW = theano.tensor.nnet.convGrad3D(V, d, WShape, dCdH)
dCdW = T.patternbroadcast(dCdW, W.broadcastable) dCdW = T.patternbroadcast(dCdW, W.broadcastable)
dCdb = T.sum(dCdH, axis=(0, 1, 2, 3)) dCdb = T.sum(dCdH, axis=(0, 1, 2, 3))
dCdb = T.patternbroadcast(dCdb, b.broadcastable) dCdb = T.patternbroadcast(dCdb, b.broadcastable)
...@@ -620,6 +621,3 @@ def computeH(V, W, b, d): ...@@ -620,6 +621,3 @@ def computeH(V, W, b, d):
# print 'setting H[0] += '+str(w*v)+' W['+str((j,z,k,l,m))+']='+str(w)+' V['+str((i,d[0]*x+k,d[1]*y+l,d[2]*t+m,z))+']='+str(v) # print 'setting H[0] += '+str(w*v)+' W['+str((j,z,k,l,m))+']='+str(w)+' V['+str((i,d[0]*x+k,d[1]*y+l,d[2]*t+m,z))+']='+str(v)
H[i, x, y, t, j] += w * v H[i, x, y, t, j] += w * v
return H return H
from . import ConvGrad3D
from . import ConvTransp3D
...@@ -4,8 +4,6 @@ import numpy as N ...@@ -4,8 +4,6 @@ import numpy as N
import theano import theano
from theano.tensor import basic as T from theano.tensor import basic as T
from theano.tensor.nnet.Conv3D import conv3D
from theano.tensor.nnet.ConvTransp3D import convTransp3D
from theano.misc import strutil from theano.misc import strutil
from theano.gradient import grad_undefined from theano.gradient import grad_undefined
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
...@@ -46,12 +44,12 @@ class ConvGrad3D(theano.Op): ...@@ -46,12 +44,12 @@ class ConvGrad3D(theano.Op):
dLdA, = output_gradients dLdA, = output_gradients
z = T.zeros_like(C[0, 0, 0, 0, :]) z = T.zeros_like(C[0, 0, 0, 0, :])
dLdC = convTransp3D(dLdA, z, d, B, C.shape[1:4]) dLdC = theano.tensor.nnet.convTransp3D(dLdA, z, d, B, C.shape[1:4])
# d actually does affect the outputs, so it's not disconnected # d actually does affect the outputs, so it's not disconnected
dLdd = grad_undefined(self, 1, d) dLdd = grad_undefined(self, 1, d)
# The shape of the weights doesn't affect the output elements # The shape of the weights doesn't affect the output elements
dLdWShape = DisconnectedType()() dLdWShape = DisconnectedType()()
dLdB = conv3D(C, dLdA, T.zeros_like(B[0, 0, 0, 0, :]), d) dLdB = theano.tensor.nnet.conv3D(C, dLdA, T.zeros_like(B[0, 0, 0, 0, :]), d)
return [dLdC, dLdd, dLdWShape, dLdB] return [dLdC, dLdd, dLdWShape, dLdB]
......
...@@ -8,8 +8,6 @@ from theano.tensor import basic as T ...@@ -8,8 +8,6 @@ from theano.tensor import basic as T
from theano.misc import strutil from theano.misc import strutil
from theano.gradient import grad_undefined from theano.gradient import grad_undefined
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.tensor.nnet.Conv3D import conv3D
from theano.tensor.nnet.ConvGrad3D import convGrad3D
class ConvTransp3D(theano.Op): class ConvTransp3D(theano.Op):
...@@ -51,9 +49,9 @@ class ConvTransp3D(theano.Op): ...@@ -51,9 +49,9 @@ class ConvTransp3D(theano.Op):
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
W, b, d, H, RShape = inputs W, b, d, H, RShape = inputs
dCdR, = output_gradients dCdR, = output_gradients
dCdH = conv3D(dCdR, W, T.zeros_like(H[0, 0, 0, 0, :]), d) dCdH = theano.tensor.nnet.conv3D(dCdR, W, T.zeros_like(H[0, 0, 0, 0, :]), d)
WShape = W.shape WShape = W.shape
dCdW = convGrad3D(dCdR, d, WShape, H) dCdW = theano.tensor.nnet.convGrad3D(dCdR, d, WShape, H)
dCdb = T.sum(dCdR, axis=(0, 1, 2, 3)) dCdb = T.sum(dCdR, axis=(0, 1, 2, 3))
# not differentiable, since d affects the output elements # not differentiable, since d affects the output elements
dCdd = grad_undefined(self, 2, d) dCdd = grad_undefined(self, 2, d)
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论