fixed grad tests

上级 9b86e097
...@@ -16,7 +16,7 @@ class _test_grad_sources_inputs(unittest.TestCase): ...@@ -16,7 +16,7 @@ class _test_grad_sources_inputs(unittest.TestCase):
def __init__(self, arg): def __init__(self, arg):
self.inputs = [gof.result.ResultBase()] self.inputs = [gof.result.ResultBase()]
self.outputs = [gof.result.ResultBase()] self.outputs = [gof.result.ResultBase()]
def grad(self, x, gz): def grad(self, (x, ), (gz, )):
pass pass
a = retNone(5) a = retNone(5)
try: try:
...@@ -31,7 +31,7 @@ class _test_grad_sources_inputs(unittest.TestCase): ...@@ -31,7 +31,7 @@ class _test_grad_sources_inputs(unittest.TestCase):
def __init__(self, arg): def __init__(self, arg):
self.inputs = arg self.inputs = arg
self.outputs = [gof.result.ResultBase()] self.outputs = [gof.result.ResultBase()]
def grad(self, x, gz): def grad(self, (x, ), (gz, )):
return [None] return [None]
i = gof.result.ResultBase() i = gof.result.ResultBase()
a = retNone([i]) a = retNone([i])
...@@ -44,7 +44,7 @@ class _test_grad_sources_inputs(unittest.TestCase): ...@@ -44,7 +44,7 @@ class _test_grad_sources_inputs(unittest.TestCase):
def __init__(self, arg): def __init__(self, arg):
self.inputs = arg self.inputs = arg
self.outputs = [gof.result.ResultBase()] self.outputs = [gof.result.ResultBase()]
def grad(self, inputs, gz): def grad(self, inputs, (gz, )):
return [None] return [None]
i = gof.result.ResultBase() i = gof.result.ResultBase()
...@@ -67,7 +67,7 @@ class _test_grad_sources_inputs(unittest.TestCase): ...@@ -67,7 +67,7 @@ class _test_grad_sources_inputs(unittest.TestCase):
self.inputs = arg self.inputs = arg
self.outputs = [gof.result.ResultBase()] self.outputs = [gof.result.ResultBase()]
self.tst = tst self.tst = tst
def grad(self, inputs, gz): def grad(self, inputs, (gz, )):
self.tst.fail() self.tst.fail()
i = gof.result.ResultBase() i = gof.result.ResultBase()
...@@ -81,8 +81,8 @@ class _test_grad_sources_inputs(unittest.TestCase): ...@@ -81,8 +81,8 @@ class _test_grad_sources_inputs(unittest.TestCase):
def __init__(self): def __init__(self):
self.inputs = [gof.result.ResultBase()] self.inputs = [gof.result.ResultBase()]
self.outputs = [gof.result.ResultBase()] self.outputs = [gof.result.ResultBase()]
def grad(self, x, gz): def grad(self, (x, ), (gz, )):
return gval return gval,
a1 = O() a1 = O()
g = grad_sources_inputs([(a1.outputs[0], 1)], None) g = grad_sources_inputs([(a1.outputs[0], 1)], None)
self.failUnless(g[a1.inputs[0]] is gval) self.failUnless(g[a1.inputs[0]] is gval)
...@@ -94,8 +94,8 @@ class _test_grad_sources_inputs(unittest.TestCase): ...@@ -94,8 +94,8 @@ class _test_grad_sources_inputs(unittest.TestCase):
def __init__(self): def __init__(self):
self.inputs = [gof.result.ResultBase()] self.inputs = [gof.result.ResultBase()]
self.outputs = [gof.result.ResultBase(),gof.result.ResultBase()] self.outputs = [gof.result.ResultBase(),gof.result.ResultBase()]
def grad(self, x, (gz1, gz2)): def grad(self, (x, ), (gz1, gz2)):
return gval return gval,
a1 = O() a1 = O()
g = grad_sources_inputs([(a1.outputs[0], 1)], None) g = grad_sources_inputs([(a1.outputs[0], 1)], None)
self.failUnless(g[a1.inputs[0]] is gval) self.failUnless(g[a1.inputs[0]] is gval)
...@@ -107,7 +107,7 @@ class _test_grad_sources_inputs(unittest.TestCase): ...@@ -107,7 +107,7 @@ class _test_grad_sources_inputs(unittest.TestCase):
def __init__(self): def __init__(self):
self.inputs = [gof.result.ResultBase(),gof.result.ResultBase()] self.inputs = [gof.result.ResultBase(),gof.result.ResultBase()]
self.outputs = [gof.result.ResultBase()] self.outputs = [gof.result.ResultBase()]
def grad(self, (x0,x1), gz): def grad(self, (x0,x1), (gz, )):
return (gval0, gval1) return (gval0, gval1)
a1 = O() a1 = O()
g = grad_sources_inputs([(a1.outputs[0], 1)], None) g = grad_sources_inputs([(a1.outputs[0], 1)], None)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论