提交 024ec750 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add a test for the gradients of convolution and fix them.

上级 6e500132
...@@ -4,7 +4,7 @@ import numpy ...@@ -4,7 +4,7 @@ import numpy
import theano import theano
from theano import Apply, gof, tensor, config, Variable from theano import Apply, gof, tensor, config, Variable
from theano.scalar import as_scalar, constant from theano.scalar import as_scalar, constant
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType, grad_not_implemented
from theano.gof import Optimizer, local_optimizer, COp from theano.gof import Optimizer, local_optimizer, COp
from theano.gof.type import CDataType, Generic from theano.gof.type import CDataType, Generic
from theano.compat import PY3 from theano.compat import PY3
...@@ -434,13 +434,13 @@ class GpuDnnConv(DnnBase, COp): ...@@ -434,13 +434,13 @@ class GpuDnnConv(DnnBase, COp):
d_img = GpuDnnConvGradI()(kerns, top, img.zeros_like(), desc) d_img = GpuDnnConvGradI()(kerns, top, img.zeros_like(), desc)
d_kerns = GpuDnnConvGradW()(img, top, kerns.zeros_like(), desc) d_kerns = GpuDnnConvGradW()(img, top, kerns.zeros_like(), desc)
d_alpha = grad_not_implemented(self, 4, alpha)
return [d_img, d_kerns, output.zeros_like(), return [d_img, d_kerns, top * alpha, DisconnectedType()(), d_alpha]
DisconnectedType()(), DisconnectedType()()]
def connection_pattern(self, node): def connection_pattern(self, node):
# not connected to desc, alpha # not connected to desc
return [[1], [1], [1], [0], [0]] return [[1], [1], [1], [0], [1]]
@staticmethod @staticmethod
def get_out_shape(ishape, kshape, border_mode, subsample): def get_out_shape(ishape, kshape, border_mode, subsample):
...@@ -509,13 +509,13 @@ class GpuDnnConvGradW(DnnBase, COp): ...@@ -509,13 +509,13 @@ class GpuDnnConvGradW(DnnBase, COp):
d_img = GpuDnnConvGradI()(kerns, top, img.zeros_like(), desc) d_img = GpuDnnConvGradI()(kerns, top, img.zeros_like(), desc)
d_top = GpuDnnConv()(img, kerns, top.zeros_like(), desc) d_top = GpuDnnConv()(img, kerns, top.zeros_like(), desc)
d_alpha = grad_not_implemented(self, 4, alpha)
return (d_img, d_top, output.zeros_like(), return (d_img, d_top, kerns * alpha, DisconnectedType()(), d_alpha)
DisconnectedType()(), DiconnnectedType()())
def connection_pattern(self, node): def connection_pattern(self, node):
# not connected to desc, alpha # not connected to desc
return [[1], [1], [1], [0], [0]] return [[1], [1], [1], [0], [1]]
def get_op_params(self): def get_op_params(self):
if self.inplace: if self.inplace:
...@@ -573,12 +573,13 @@ class GpuDnnConvGradI(DnnBase, COp): ...@@ -573,12 +573,13 @@ class GpuDnnConvGradI(DnnBase, COp):
d_kerns = GpuDnnConvGradW()(img, top, kerns.zeros_like(), desc) d_kerns = GpuDnnConvGradW()(img, top, kerns.zeros_like(), desc)
d_top = GpuDnnConv()(img, kerns, top.zeros_like(), desc) d_top = GpuDnnConv()(img, kerns, top.zeros_like(), desc)
return (d_kerns, d_top, output.zeros_like(), d_alpha = grad_not_implemented(self, 4, alpha)
DisconnectedType()(), DisconnectedType()())
return (d_kerns, d_top, img * alpha, DisconnectedType()(), d_alpha)
def connection_pattern(self, node): def connection_pattern(self, node):
# not connected to desc, alpha # not connected to desc
return [[1], [1], [1], [0], [0]] return [[1], [1], [1], [0], [1]]
def get_op_params(self): def get_op_params(self):
if self.inplace: if self.inplace:
......
...@@ -492,6 +492,41 @@ def test_dnn_conv_merge(): ...@@ -492,6 +492,41 @@ def test_dnn_conv_merge():
utt.assert_allclose(v1, v2) utt.assert_allclose(v1, v2)
def test_dnn_conv_grad():
if dnn.version() == -1:
raise SkipTest('alpha != 1.0 not supported in cudnn v1')
b = 1
c = 4
f = 3
ih = 2
iw = 8
kh = 2
kw = 2
img_val = numpy.random.random((b, c, ih, iw)).astype('float32')
kern_val = numpy.random.random((f, c, kh, kw)).astype('float32')
out_val = numpy.random.random((b, f, ih-kw+1, iw-kw+1)).astype('float32')
def dconv(img, kern, out):
desc = dnn.GpuDnnConvDesc(border_mode='valid', subsample=(1, 1),
conv_mode='conv')(img.shape, kern.shape)
return dnn.GpuDnnConv()(img, kern, out, desc)
def dconvi(img, kern, out):
desc = dnn.GpuDnnConvDesc(border_mode='valid', subsample=(1, 1),
conv_mode='conv')(img.shape, kern.shape)
return dnn.GpuDnnConvGradI()(kern, out, img, desc)
def dconvw(img, kern, out):
desc = dnn.GpuDnnConvDesc(border_mode='valid', subsample=(1, 1),
conv_mode='conv')(img.shape, kern.shape)
return dnn.GpuDnnConvGradW()(img, out, kern, desc)
utt.verify_grad(dconv, [img_val, kern_val, out_val])
utt.verify_grad(dconvi, [img_val, kern_val, out_val])
utt.verify_grad(dconvw, [img_val, kern_val, out_val])
def test_version(): def test_version():
if not cuda.dnn.dnn_available(): if not cuda.dnn.dnn_available():
raise SkipTest(cuda.dnn.dnn_available.msg) raise SkipTest(cuda.dnn.dnn_available.msg)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论