提交 9e2d8a33 authored 作者: Nicolas Ballas's avatar Nicolas Ballas

Fix errors in the gradient definition

上级 4acf0a40
...@@ -1347,9 +1347,9 @@ class GpuCorr3dMM(BaseGpuCorr3dMM): ...@@ -1347,9 +1347,9 @@ class GpuCorr3dMM(BaseGpuCorr3dMM):
top, = grads top, = grads
top = gpu_contiguous(top) top = gpu_contiguous(top)
d_bottom = GpuCorr3dMM_gradInputs(self.border_mode, self.subsample, self.pad)( d_bottom = GpuCorr3dMM_gradInputs(self.border_mode, self.subsample, self.pad)(
weights, top, bottom.shape[-2:]) weights, top, bottom.shape[-3:])
d_weights = GpuCorr3dMM_gradWeights(self.border_mode, self.subsample, self.pad)( d_weights = GpuCorr3dMM_gradWeights(self.border_mode, self.subsample, self.pad)(
bottom, top, weights.shape[-2:]) bottom, top, weights.shape[-3:])
return d_bottom, d_weights return d_bottom, d_weights
...@@ -1394,7 +1394,7 @@ class GpuCorr3dMM_gradWeights(BaseGpuCorr3dMM): ...@@ -1394,7 +1394,7 @@ class GpuCorr3dMM_gradWeights(BaseGpuCorr3dMM):
bottom, top = inp[:2] bottom, top = inp[:2]
weights, = grads weights, = grads
weights = gpu_contiguous(weights) weights = gpu_contiguous(weights)
d_bottom = GpuCorr3dMM_gradInputs(self.border_mode, self.subsample, self.pad)(weights, top, bottom.shape[-2:]) d_bottom = GpuCorr3dMM_gradInputs(self.border_mode, self.subsample, self.pad)(weights, top, bottom.shape[-3:])
d_top = GpuCorr3dMM(self.border_mode, self.subsample, self.pad)( d_top = GpuCorr3dMM(self.border_mode, self.subsample, self.pad)(
bottom, weights) bottom, weights)
d_height_width_depth = (theano.gradient.DisconnectedType()(),) * 3 if len(inp) == 5 else () d_height_width_depth = (theano.gradient.DisconnectedType()(),) * 3 if len(inp) == 5 else ()
...@@ -1404,7 +1404,7 @@ class GpuCorr3dMM_gradWeights(BaseGpuCorr3dMM): ...@@ -1404,7 +1404,7 @@ class GpuCorr3dMM_gradWeights(BaseGpuCorr3dMM):
if node.nin == 2: if node.nin == 2:
return [[1], [1]] return [[1], [1]]
else: else:
return [[1], [1], [0], [0], [0]] # no connection to height, width return [[1], [1], [0], [0], [0]] # no connection to height, width, depth
class GpuCorr3dMM_gradInputs(BaseGpuCorr3dMM): class GpuCorr3dMM_gradInputs(BaseGpuCorr3dMM):
"""Gradient wrt. inputs for `GpuCorr3dMM`. """Gradient wrt. inputs for `GpuCorr3dMM`.
...@@ -1444,7 +1444,7 @@ class GpuCorr3dMM_gradInputs(BaseGpuCorr3dMM): ...@@ -1444,7 +1444,7 @@ class GpuCorr3dMM_gradInputs(BaseGpuCorr3dMM):
bottom, = grads bottom, = grads
bottom = gpu_contiguous(bottom) bottom = gpu_contiguous(bottom)
d_weights = GpuCorr3dMM_gradWeights(self.border_mode, self.subsample, self.pad)( d_weights = GpuCorr3dMM_gradWeights(self.border_mode, self.subsample, self.pad)(
bottom, top, weights.shape[-2:]) bottom, top, weights.shape[-3:])
d_top = GpuCorr3dMM(self.border_mode, self.subsample, self.pad)( d_top = GpuCorr3dMM(self.border_mode, self.subsample, self.pad)(
bottom, weights) bottom, weights)
d_height_width_depth = (theano.gradient.DisconnectedType()(),) * 3 if len(inp) == 5 else () d_height_width_depth = (theano.gradient.DisconnectedType()(),) * 3 if len(inp) == 5 else ()
...@@ -1454,7 +1454,7 @@ class GpuCorr3dMM_gradInputs(BaseGpuCorr3dMM): ...@@ -1454,7 +1454,7 @@ class GpuCorr3dMM_gradInputs(BaseGpuCorr3dMM):
if node.nin == 2: if node.nin == 2:
return [[1], [1], [1]] return [[1], [1], [1]]
else: else:
return [[1], [1], [0], [0], [0]] # no connection to height, width return [[1], [1], [0], [0], [0]] # no connection to height, width, depth
## ##
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论