提交 04d73513 authored 作者: abergeron's avatar abergeron

Merge pull request #16 from lamblin/abergeron-fix_gpuadvsub1

Add tests for gradient in broadcastable cases
...@@ -500,9 +500,25 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -500,9 +500,25 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
self.ignore_topo)] self.ignore_topo)]
assert len(topo_) == 1 assert len(topo_) == 1
self.assertTrue(isinstance(topo_[0].op, self.adv_sub1)) self.assertTrue(isinstance(topo_[0].op, self.adv_sub1))
self.assertTrue(numpy.allclose(f([0]), ones[0] * 5)) f_0 = f([0])
self.assertTrue(f_0.shape == (1, 3))
self.assertTrue(numpy.allclose(f_0, ones[0] * 5))
f_00 = f([0, 0])
self.assertTrue(f_00.shape == (2, 3))
self.assertTrue(numpy.allclose(f_00, 5))
self.assertRaises(IndexError, f, [0, 1]) self.assertRaises(IndexError, f, [0, 1])
# Test the gradient
c = t.sum()
gn = theano.grad(c, n)
g = self.function([idx], gn, op=self.adv_incsub1)
g_0 = g([0])
self.assertTrue(g_0.shape == (1, 3))
self.assertTrue(numpy.allclose(g_0, 1))
g_00 = g([0, 0])
self.assertTrue(g_00.shape == (1, 3))
self.assertTrue(numpy.allclose(g_00, 2))
def test_adv_sub1_idx_broadcast(self): def test_adv_sub1_idx_broadcast(self):
# The idx can be a broadcastable vector. # The idx can be a broadcastable vector.
ones = numpy.ones((4, 3), dtype=self.dtype) ones = numpy.ones((4, 3), dtype=self.dtype)
...@@ -518,7 +534,18 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -518,7 +534,18 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
self.ignore_topo)] self.ignore_topo)]
assert len(topo_) == 1 assert len(topo_) == 1
self.assertTrue(isinstance(topo_[0].op, self.adv_sub1)) self.assertTrue(isinstance(topo_[0].op, self.adv_sub1))
self.assertTrue(numpy.allclose(f([0]), ones[0] * 5)) f_0 = f([0])
self.assertTrue(f_0.shape == (1, 3))
self.assertTrue(numpy.allclose(f_0, 5))
# Test the gradient
c = t.sum()
gn = theano.grad(c, n)
g = self.function([idx], gn, op=self.adv_incsub1)
g_0 = g([0])
self.assertTrue(g_0.shape == (4, 3))
self.assertTrue(numpy.allclose(g_0[0], 1))
self.assertTrue(numpy.allclose(g_0[1:], 0))
@attr('slow') @attr('slow')
def test_shape_i_const(self): def test_shape_i_const(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论