fixed grad tests

上级 9b86e097
......@@ -16,7 +16,7 @@ class _test_grad_sources_inputs(unittest.TestCase):
def __init__(self, arg):
self.inputs = [gof.result.ResultBase()]
self.outputs = [gof.result.ResultBase()]
def grad(self, x, gz):
def grad(self, (x, ), (gz, )):
pass
a = retNone(5)
try:
......@@ -31,7 +31,7 @@ class _test_grad_sources_inputs(unittest.TestCase):
def __init__(self, arg):
self.inputs = arg
self.outputs = [gof.result.ResultBase()]
def grad(self, x, gz):
def grad(self, (x, ), (gz, )):
return [None]
i = gof.result.ResultBase()
a = retNone([i])
......@@ -44,7 +44,7 @@ class _test_grad_sources_inputs(unittest.TestCase):
def __init__(self, arg):
self.inputs = arg
self.outputs = [gof.result.ResultBase()]
def grad(self, inputs, gz):
def grad(self, inputs, (gz, )):
return [None]
i = gof.result.ResultBase()
......@@ -67,7 +67,7 @@ class _test_grad_sources_inputs(unittest.TestCase):
self.inputs = arg
self.outputs = [gof.result.ResultBase()]
self.tst = tst
def grad(self, inputs, gz):
def grad(self, inputs, (gz, )):
self.tst.fail()
i = gof.result.ResultBase()
......@@ -81,8 +81,8 @@ class _test_grad_sources_inputs(unittest.TestCase):
def __init__(self):
self.inputs = [gof.result.ResultBase()]
self.outputs = [gof.result.ResultBase()]
def grad(self, x, gz):
return gval
def grad(self, (x, ), (gz, )):
return gval,
a1 = O()
g = grad_sources_inputs([(a1.outputs[0], 1)], None)
self.failUnless(g[a1.inputs[0]] is gval)
......@@ -94,8 +94,8 @@ class _test_grad_sources_inputs(unittest.TestCase):
def __init__(self):
self.inputs = [gof.result.ResultBase()]
self.outputs = [gof.result.ResultBase(),gof.result.ResultBase()]
def grad(self, x, (gz1, gz2)):
return gval
def grad(self, (x, ), (gz1, gz2)):
return gval,
a1 = O()
g = grad_sources_inputs([(a1.outputs[0], 1)], None)
self.failUnless(g[a1.inputs[0]] is gval)
......@@ -107,7 +107,7 @@ class _test_grad_sources_inputs(unittest.TestCase):
def __init__(self):
self.inputs = [gof.result.ResultBase(),gof.result.ResultBase()]
self.outputs = [gof.result.ResultBase()]
def grad(self, (x0,x1), gz):
def grad(self, (x0,x1), (gz, )):
return (gval0, gval1)
a1 = O()
g = grad_sources_inputs([(a1.outputs[0], 1)], None)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论