提交 cf5653fe authored 作者: fsavard's avatar fsavard

Added very simple grad for neighbours op, plus one unit test for a basic case for it.

上级 925a4eb6
...@@ -46,8 +46,8 @@ class Images2Neibs(Op): ...@@ -46,8 +46,8 @@ class Images2Neibs(Op):
return Apply(self, [ten4, neib_shape,neib_step], [T.matrix(dtype=ten4.type.dtype)]) return Apply(self, [ten4, neib_shape,neib_step], [T.matrix(dtype=ten4.type.dtype)])
def grad(self, (pvals, unis), (gz,)): def grad(self, (x, neib_shape, neib_step), (gz,)):
return [None, None] return [neibs2images(gz, neib_shape, x.shape), None, None]
def c_code_cache_version(self): def c_code_cache_version(self):
return (3,) return (3,)
......
...@@ -328,8 +328,56 @@ def speed_neibs_wrap_centered(): ...@@ -328,8 +328,56 @@ def speed_neibs_wrap_centered():
for i in range(1000): for i in range(1000):
f() f()
def test_neibs_grad():
shape = (2,3,4,4)
images = T.shared(numpy.arange(numpy.prod(shape), dtype='float32').reshape(shape))
cost = T.sum(T.sqr(images2neibs(images, (2,2))), axis=[0,1])
grad = T.grad(cost, images)
f = theano.function([], [cost, grad], mode=mode_without_gpu)
got = f()
should_get = [numpy.asarray(290320.0, dtype=numpy.float32),
numpy.asarray([[[[ 0., 2., 4., 6.],
[ 8., 10., 12., 14.],
[ 16., 18., 20., 22.],
[ 24., 26., 28., 30.]],
[[ 32., 34., 36., 38.],
[ 40., 42., 44., 46.],
[ 48., 50., 52., 54.],
[ 56., 58., 60., 62.]],
[[ 64., 66., 68., 70.],
[ 72., 74., 76., 78.],
[ 80., 82., 84., 86.],
[ 88., 90., 92., 94.]]],
[[[ 96., 98., 100., 102.],
[ 104., 106., 108., 110.],
[ 112., 114., 116., 118.],
[ 120., 122., 124., 126.]],
[[ 128., 130., 132., 134.],
[ 136., 138., 140., 142.],
[ 144., 146., 148., 150.],
[ 152., 154., 156., 158.]],
[[ 160., 162., 164., 166.],
[ 168., 170., 172., 174.],
[ 176., 178., 180., 182.],
[ 184., 186., 188., 190.]]]], dtype=numpy.float32)]
assert numpy.allclose(got[0], should_get[0])
assert numpy.allclose(got[1], should_get[1])
if __name__ == '__main__': if __name__ == '__main__':
test_neibs_gpu() #test_neibs_gpu()
test_neibs() #test_neibs()
test_neibs_grad()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论