提交 bf226158 authored 作者: James Bergstra's avatar James Bergstra

fix #158

上级 4fc8b90f
...@@ -1692,16 +1692,24 @@ class _test_grad(unittest.TestCase): ...@@ -1692,16 +1692,24 @@ class _test_grad(unittest.TestCase):
"""grad: Test returning a single None from grad""" """grad: Test returning a single None from grad"""
o = _test_grad.O() o = _test_grad.O()
a1 = o.make_node() a1 = o.make_node()
self.failUnless(None is grad(a1.outputs[0], a1.outputs[1])) g = grad(a1.outputs[0], a1.outputs[1])
self.failUnless(None is grad(a1.outputs[0], 'wtf')) self.failUnless(isinstance(g, TensorConstant))
self.failUnless(g.data == 0)
try:
grad(a1.outputs[0], 'wtf')
except AttributeError, e:
return
self.fail()
def test_NNone_rval(self): def test_NNone_rval(self):
"""grad: Test returning some Nones from grad""" """grad: Test returning some Nones from grad"""
o = _test_grad.O() o = _test_grad.O()
a1 = o.make_node() a1 = o.make_node()
g0,g1,g2 = grad(a1.outputs[0], a1.inputs + ['wtf']) g0,g1,g2 = grad(a1.outputs[0], a1.inputs + [scalar('z')])
self.failUnless(o.gval0 is g0) self.failUnless(o.gval0 is g0)
self.failUnless(o.gval1 is g1) self.failUnless(o.gval1 is g1)
self.failUnless(None is g2) self.failUnless(isinstance(g2, TensorConstant))
self.failUnless(g2.data == 0)
...@@ -1714,3 +1722,4 @@ if __name__ == '__main__': ...@@ -1714,3 +1722,4 @@ if __name__ == '__main__':
suite = unittest.TestLoader() suite = unittest.TestLoader()
suite = suite.loadTestsFromTestCase(testcase) suite = suite.loadTestsFromTestCase(testcase)
unittest.TextTestRunner(verbosity=2).run(suite) unittest.TextTestRunner(verbosity=2).run(suite)
...@@ -1232,14 +1232,23 @@ def grad(cost, wrt, g_cost=None): ...@@ -1232,14 +1232,23 @@ def grad(cost, wrt, g_cost=None):
@rtype: L{Result} or list of L{Result}s (depending upon I{wrt}) @rtype: L{Result} or list of L{Result}s (depending upon I{wrt})
@return: symbolic expression of gradient of I{cost} with respect to I{wrt}. @return: symbolic expression of gradient of I{cost} with respect to I{wrt}.
If I{wrt} is a list, then return a list containing the gradient of I{cost} wrt If I{wrt} is a list, then return a list containing the gradient of I{cost} wrt
each element of the list. each element of the list. If an element of I{wrt} is not differentiable
with respect to the output, then a L{TensorConstant} with an appropriate
kind of zero is returned.
""" """
if g_cost is None: if g_cost is None:
g_cost = ones_like(cost) g_cost = ones_like(cost)
inputs = gof.graph.inputs([cost]) inputs = gof.graph.inputs([cost])
gmap = gradient.grad_sources_inputs([(cost, g_cost)], inputs) gmap = gradient.grad_sources_inputs([(cost, g_cost)], inputs)
def zero(p):
return TensorConstant(
Tensor(dtype = p.type.dtype, broadcastable = []),
numpy.asarray(0, dtype=p.type.dtype))
if isinstance(wrt, list): if isinstance(wrt, list):
return [gmap.get(p, None) for p in wrt] return [gmap.get(p, zero(p)) for p in wrt]
else: else:
return gmap.get(wrt, None) return gmap.get(wrt, zero(wrt))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论