提交 a9f805bd authored 作者: Nicolas Ballas's avatar Nicolas Ballas 提交者: Pascal Lamblin

fix codes

上级 957e3fae
...@@ -264,14 +264,12 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d): ...@@ -264,14 +264,12 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d):
d_bottom = AbstractConv2d_gradInputs(self.imshp, self.kshp, d_bottom = AbstractConv2d_gradInputs(self.imshp, self.kshp,
self.bsize, self.bsize,
self.border_mode, self.border_mode,
self.subsample)( self.subsample)(weights, top, bottom.shape[-2:])
weights, top, bottom.shape[-2:]) d_top = AbstractConv2d(self.imshp,
d_top = AbstractConv2d(self.imshp, self.kshp,
self.kshp, self.bsize,
self.bsize, self.border_mode,
self.border_mode, self.subsample)(bottom, weights)
self.subsample)(
bottom, weights)
d_height_width = (theano.gradient.DisconnectedType()(),) * 2 if len(inp) == 4 else () d_height_width = (theano.gradient.DisconnectedType()(),) * 2 if len(inp) == 4 else ()
return (d_bottom, d_top) + d_height_width return (d_bottom, d_top) + d_height_width
...@@ -282,7 +280,7 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d): ...@@ -282,7 +280,7 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d):
return [[1], [1], [0], [0]] # no connection to height, width return [[1], [1], [0], [0]] # no connection to height, width
class AbstractConv2d_gradInputs(Conv2d): class AbstractConv2d_gradInputs(BaseAbstractConv2d):
"""Gradient wrt. inputs for `AbstractConv2d`. """Gradient wrt. inputs for `AbstractConv2d`.
:note: You will not want to use this directly, but rely on :note: You will not want to use this directly, but rely on
...@@ -326,11 +324,9 @@ class AbstractConv2d_gradInputs(Conv2d): ...@@ -326,11 +324,9 @@ class AbstractConv2d_gradInputs(Conv2d):
d_weights = AbstractConv2d_gradWeights(self.imshp, self.kshp, d_weights = AbstractConv2d_gradWeights(self.imshp, self.kshp,
self.bsize, self.bsize,
self.border_mode, self.border_mode,
self.subsample)( self.subsample)(bottom, top, weights.shape[-2:])
bottom, top, weights.shape[-2:])
d_top = AbstractConv2d(self.imshp, self.filter_shape, self.bsize, d_top = AbstractConv2d(self.imshp, self.filter_shape, self.bsize,
self.border_mode, self.subsample)( self.border_mode, self.subsample)(bottom, weights)
bottom, weights)
d_height_width = (theano.gradient.DisconnectedType()(),) * 2 if len(inp) == 4 else () d_height_width = (theano.gradient.DisconnectedType()(),) * 2 if len(inp) == 4 else ()
return (d_weights, d_top) + d_height_width return (d_weights, d_top) + d_height_width
...@@ -447,10 +443,10 @@ def local_conv2d_cudnn(node): ...@@ -447,10 +443,10 @@ def local_conv2d_cudnn(node):
subsample=node.op.subsample, subsample=node.op.subsample,
direction_hint='bprop inputs') direction_hint='bprop inputs')
return rval return rval
register_specialize_device()(local_conv2d_cudnn) register_specialize_device(local_conv2d_cudnn)
@local_optimizer(AbstractConv2d) @local_optimizer([AbstractConv2d])
def local_conv2d_corrmm(convop, inputs): def local_conv2d_corrmm(convop, inputs):
img, kern = node.inputs img, kern = node.inputs
...@@ -505,9 +501,9 @@ def local_conv2d_corrmm(convop, inputs): ...@@ -505,9 +501,9 @@ def local_conv2d_corrmm(convop, inputs):
rval = GpuCorrMM_gradInputs('valid', subsample)( rval = GpuCorrMM_gradInputs('valid', subsample)(
gpu_contiguous(kern), gpu_contiguous(img)) gpu_contiguous(kern), gpu_contiguous(img))
return rval return rval
register_specialize_device()(local_conv2d_corrmm) register_specialize_device(local_conv2d_corrmm)
@local_optimizer(AbstractConv2d_gradWeights) @local_optimizer([AbstractConv2d_gradWeights])
def local_conv2d_gradweight_corrmm(node): def local_conv2d_gradweight_corrmm(node):
img, topgrad, shape = node.inputs img, topgrad, shape = node.inputs
...@@ -518,9 +514,9 @@ def local_conv2d_gradweight_corrmm(node): ...@@ -518,9 +514,9 @@ def local_conv2d_gradweight_corrmm(node):
subsample=node.op.subsample)( subsample=node.op.subsample)(
gpu_contiguous(img), gpu_contiguous(topgrad), shape) gpu_contiguous(img), gpu_contiguous(topgrad), shape)
return rval return rval
register_specialize_device()(local_conv2d_gradweight_corrmm) register_specialize_device(local_conv2d_gradweight_corrmm)
@local_optimizer(AbstractConv2d_gradInputs) @local_optimizer([AbstractConv2d_gradInputs])
def local_conv2d_gradinputs_corrmm(node): def local_conv2d_gradinputs_corrmm(node):
kern, topgrad, shape = node.inputs kern, topgrad, shape = node.inputs
...@@ -531,7 +527,7 @@ def local_conv2d_gradinputs_corrmm(node): ...@@ -531,7 +527,7 @@ def local_conv2d_gradinputs_corrmm(node):
subsample=node.op.subsample)( subsample=node.op.subsample)(
gpu_contiguous(kern), gpu_contiguous(topgrad), shape) gpu_contiguous(kern), gpu_contiguous(topgrad), shape)
return rval return rval
register_specialize_device()(local_conv2d_gradinputs_corrmm) register_specialize_device(local_conv2d_gradinputs_corrmm)
......
...@@ -7,7 +7,7 @@ from theano.tests import unittest_tools as utt ...@@ -7,7 +7,7 @@ from theano.tests import unittest_tools as utt
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
import theano.tensor.nnet.conv as conv_ref import theano.tensor.nnet.conv as conv_ref
import theano.tensor.nnet.conv2d as conv import theano.tensor.nnet.abstract_conv2d as conv
from theano.sandbox.cuda import float32_shared_constructor as shared from theano.sandbox.cuda import float32_shared_constructor as shared
if theano.config.mode == 'FAST_COMPILE': if theano.config.mode == 'FAST_COMPILE':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论