提交 f3a80dcb authored 作者: abergeron's avatar abergeron

Merge pull request #2276 from f0k/dnn-grad-grads

Add grad() to GpuDnnConvGrad* ops
...@@ -449,6 +449,21 @@ class GpuDnnConvGradW(GpuDnnConvBase): ...@@ -449,6 +449,21 @@ class GpuDnnConvGradW(GpuDnnConvBase):
path_flag = 'CUDNN_CONVOLUTION_WEIGHT_GRAD' path_flag = 'CUDNN_CONVOLUTION_WEIGHT_GRAD'
conv_op = 'cudnnConvolutionBackwardFilter' conv_op = 'cudnnConvolutionBackwardFilter'
def grad(self, inp, grads):
img, top, desc = inp
kerns, = grads
kerns = gpu_contiguous(kerns)
d_img = GpuDnnConvGradI()(kerns, top, desc)
d_top = GpuDnnConv()(img, kerns, desc)
return d_img, d_top, theano.gradient.DisconnectedType()()
def connection_pattern(self, node):
# not connected to desc
return [[1], [1], [0]]
class GpuDnnConvGradI(GpuDnnConvBase): class GpuDnnConvGradI(GpuDnnConvBase):
""" """
...@@ -466,6 +481,21 @@ class GpuDnnConvGradI(GpuDnnConvBase): ...@@ -466,6 +481,21 @@ class GpuDnnConvGradI(GpuDnnConvBase):
path_flag = 'CUDNN_CONVOLUTION_DATA_GRAD' path_flag = 'CUDNN_CONVOLUTION_DATA_GRAD'
conv_op = 'cudnnConvolutionBackwardData' conv_op = 'cudnnConvolutionBackwardData'
def grad(self, inp, grads):
kerns, top, desc = inp
img, = grads
img = gpu_contiguous(img)
d_kerns = GpuDnnConvGradW()(img, top, desc)
d_top = GpuDnnConv()(img, kerns, desc)
return d_kerns, d_top, theano.gradient.DisconnectedType()()
def connection_pattern(self, node):
# not connected to desc
return [[1], [1], [0]]
def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
conv_mode='conv', direction_hint=None): conv_mode='conv', direction_hint=None):
......
...@@ -1005,7 +1005,7 @@ def conv_grad(mode, bs, ch, nf, rImg1, rImg2, rFlt1, rFlt2, subsample, op): ...@@ -1005,7 +1005,7 @@ def conv_grad(mode, bs, ch, nf, rImg1, rImg2, rFlt1, rFlt2, subsample, op):
# skip if the reference implementation can't do it # skip if the reference implementation can't do it
pass pass
f = theano.function([i, k], outputs, mode=theano_mode) f = theano.function([i, k], outputs, mode=theano_mode.excluding('conv_dnn', 'conv_gemm'))
allvals = f(npy_img, npy_kern) allvals = f(npy_img, npy_kern)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论