Commit 709a3520 authored by f0k

GpuCorrMM: Added gradients of gradients

Parent c1c7efe7
...@@ -877,6 +877,23 @@ class GpuCorrMM_gradWeights(BaseGpuCorrMM): ...@@ -877,6 +877,23 @@ class GpuCorrMM_gradWeights(BaseGpuCorrMM):
direction = "backprop weights" direction = "backprop weights"
return super(GpuCorrMM_gradWeights, self).c_code_helper(bottom, weights, top, direction, sub, height, width) return super(GpuCorrMM_gradWeights, self).c_code_helper(bottom, weights, top, direction, sub, height, width)
    def grad(self, inp, grads):
        """Return gradients of this op's inputs given the gradient of its output.

        `inp` is (bottom, top) plus, optionally, the symbolic filter
        height/width (len(inp) == 4); `grads` holds the single output
        gradient, i.e. the gradient w.r.t. the computed weights.
        Returns gradients w.r.t. bottom and top, plus disconnected
        placeholders for height/width when they were given.
        """
        bottom, top = inp[:2]
        # The op has one output (the weights), so exactly one gradient.
        weights, = grads
        weights = gpu_contiguous(weights)
        # d(cost)/d(bottom): backprop the weight gradient through the
        # corresponding inputs-gradient op, keeping bottom's spatial shape.
        d_bottom = GpuCorrMM_gradInputs(self.border_mode, self.subsample, self.pad)(
                weights, top, bottom.shape[-2:])
        # d(cost)/d(top): a forward correlation of bottom with the
        # weight gradient (same border mode / subsample / pad).
        d_top = GpuCorrMM(self.border_mode, self.subsample, self.pad)(
                bottom, weights)
        # Optional height/width inputs carry no gradient; mark them
        # disconnected (matches connection_pattern below).
        d_height_width = (theano.gradient.DisconnectedType()(),) * 2 if len(inp) == 4 else ()
        return (d_bottom, d_top) + d_height_width
def connection_pattern(self, node):
if node.nin == 2:
return [[1], [1]]
else:
return [[1], [1], [0], [0]] # no connection to height, width
class GpuCorrMM_gradInputs(BaseGpuCorrMM): class GpuCorrMM_gradInputs(BaseGpuCorrMM):
"""Gradient wrt. inputs for `GpuCorrMM`. """Gradient wrt. inputs for `GpuCorrMM`.
...@@ -911,6 +928,23 @@ class GpuCorrMM_gradInputs(BaseGpuCorrMM): ...@@ -911,6 +928,23 @@ class GpuCorrMM_gradInputs(BaseGpuCorrMM):
direction = "backprop inputs" direction = "backprop inputs"
return super(GpuCorrMM_gradInputs, self).c_code_helper(bottom, weights, top, direction, sub, height, width) return super(GpuCorrMM_gradInputs, self).c_code_helper(bottom, weights, top, direction, sub, height, width)
    def grad(self, inp, grads):
        """Return gradients of this op's inputs given the gradient of its output.

        `inp` is (weights, top) plus, optionally, the symbolic image
        height/width (len(inp) == 4); `grads` holds the single output
        gradient, i.e. the gradient w.r.t. the reconstructed bottom.
        Returns gradients w.r.t. weights and top, plus disconnected
        placeholders for height/width when they were given.
        """
        weights, top = inp[:2]
        # The op has one output (the bottom/input image), so one gradient.
        bottom, = grads
        bottom = gpu_contiguous(bottom)
        # d(cost)/d(weights): backprop the bottom gradient through the
        # weights-gradient op, keeping the filters' spatial shape.
        d_weights = GpuCorrMM_gradWeights(self.border_mode, self.subsample, self.pad)(
                bottom, top, weights.shape[-2:])
        # d(cost)/d(top): a forward correlation of the bottom gradient
        # with the weights (same border mode / subsample / pad).
        d_top = GpuCorrMM(self.border_mode, self.subsample, self.pad)(
                bottom, weights)
        # Optional height/width inputs carry no gradient; mark them
        # disconnected (matches connection_pattern below).
        d_height_width = (theano.gradient.DisconnectedType()(),) * 2 if len(inp) == 4 else ()
        return (d_weights, d_top) + d_height_width
def connection_pattern(self, node):
if node.nin == 2:
return [[1], [1]]
else:
return [[1], [1], [0], [0]] # no connection to height, width
## ##
# Not really a BLAS operation, but whatever. # Not really a BLAS operation, but whatever.
......
...@@ -895,13 +895,13 @@ def test_gemm_directly(): ...@@ -895,13 +895,13 @@ def test_gemm_directly():
def test_gemm_grads(): def test_gemm_grads():
for mode in 'valid', 'full': for mode in 'valid', 'full':
for bs in range(1, 5): for bs in [1, 4, 5]:
for ch in range(1,4): for ch in range(1,4):
for nf in range(1,4): for nf in range(1,4):
for rImg1 in range(5, 9): for rImg1 in [2, 5, 8]:
for rImg2 in range(5, 9): for rImg2 in [2, 5, 8]:
for rFlt1 in range(2, 4): for rFlt1 in [1, 2]:
for rFlt2 in range(2, 4): for rFlt2 in [1, 2]:
for subsx in range(1, 3): for subsx in range(1, 3):
for subsy in range(1, 3): for subsy in range(1, 3):
ishape = (bs, ch, rImg1, rImg2) ishape = (bs, ch, rImg1, rImg2)
...@@ -920,19 +920,33 @@ def test_gemm_grads(): ...@@ -920,19 +920,33 @@ def test_gemm_grads():
'valid', subsample, pad)(i, k) 'valid', subsample, pad)(i, k)
conv_op = tensor.nnet.conv2d(i, k[:,:,::-1,::-1], conv_op = tensor.nnet.conv2d(i, k[:,:,::-1,::-1],
ishape, kshape, mode, subsample) ishape, kshape, mode, subsample)
corr_op_di = theano.grad(corr_op.sum(), i)
f = theano.function([i, k], conv_op_di = theano.grad(conv_op.sum(), i)
[corr_op, corr_op_dk = theano.grad(corr_op.sum(), k)
theano.grad(corr_op.sum(), i), conv_op_dk = theano.grad(conv_op.sum(), k)
theano.grad(corr_op.sum(), k), outputs = [corr_op, conv_op,
conv_op, corr_op_di, conv_op_di,
theano.grad(conv_op.sum(), i), corr_op_dk, conv_op_dk]
theano.grad(conv_op.sum(), k)], try:
mode=theano_mode) conv_op_dik = theano.grad(conv_op_di.sum(), k)
conv_op_dki = theano.grad(conv_op_dk.sum(), i)
except Exception:
# skip if the reference implementation can't do it
print ".",
else:
corr_op_dik = theano.grad(corr_op_di.sum(), k)
corr_op_dki = theano.grad(corr_op_dk.sum(), i)
outputs.extend([corr_op_dik, conv_op_dik,
corr_op_dki, conv_op_dki])
print ":",
f = theano.function([i, k], outputs, mode=theano_mode)
allvals = f(npy_img, npy_kern) allvals = f(npy_img, npy_kern)
for a, b, p in zip(allvals[:3], allvals[3:], ('fprop', 'bprop img', 'bprop kern')): for a, b, p in zip(allvals[::2], allvals[1::2],
('top', 'dtop/dbottom', 'dtop/dweight',
'dtop/dbottom/dweight', 'dtop/dweight/dbottom')):
if (a.shape != b.shape) or not numpy.allclose(a, b, rtol=1e-4): if (a.shape != b.shape) or not numpy.allclose(a, b, rtol=1e-4):
print "Test failed for", p print "Test failed for", p
print "mode: ", mode print "mode: ", mode
...@@ -940,6 +954,7 @@ def test_gemm_grads(): ...@@ -940,6 +954,7 @@ def test_gemm_grads():
print "kshape: ", kshape print "kshape: ", kshape
print "subsample: ", subsample print "subsample: ", subsample
assert False assert False
sys.stdout.flush()
def benchmark(): def benchmark():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论