added grad tests

上级 fa35dffc
...@@ -217,6 +217,9 @@ class T_transpose(unittest.TestCase): ...@@ -217,6 +217,9 @@ class T_transpose(unittest.TestCase):
#test aliasing #test aliasing
tval += 55.0 tval += 55.0
self.failUnless(n.data[0,0,0] == 56.0) self.failUnless(n.data[0,0,0] == 56.0)
def test_grad(self):
verify_grad(self, TransposeInplace, [numpy.random.rand(2, 3)])
verify_grad(self, TransposeInplace, [numpy.ones(3)])
class T_subtensor(unittest.TestCase): class T_subtensor(unittest.TestCase):
def test0_err_invalid(self): def test0_err_invalid(self):
...@@ -648,6 +651,9 @@ class t_dot(unittest.TestCase): ...@@ -648,6 +651,9 @@ class t_dot(unittest.TestCase):
def test_align_3_2(self): self.not_aligned(self.rand(5,4,3), self.rand(6,7)) def test_align_3_2(self): self.not_aligned(self.rand(5,4,3), self.rand(6,7))
def test_align_3_3(self): self.not_aligned(self.rand(5,4,3), self.rand(6,7,8)) def test_align_3_3(self): self.not_aligned(self.rand(5,4,3), self.rand(6,7,8))
def test_grad(self):
verify_grad(self, Dot, [self.rand(2,3), self.rand(3,2)])
class t_gemm(unittest.TestCase): class t_gemm(unittest.TestCase):
def setUp(self): def setUp(self):
numpy.random.seed(44) numpy.random.seed(44)
......
...@@ -239,7 +239,7 @@ class TransposeInplace(_Op, Viewer): ...@@ -239,7 +239,7 @@ class TransposeInplace(_Op, Viewer):
return [rval] return [rval]
def impl(self, x): def impl(self, x):
return x.T #numpy's transpose return x.T #numpy's transpose
def grad(self, (x,), (gz),): def grad(self, (x,), (gz,)):
return transpose(gz), return transpose(gz),
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论