提交 4dd936f9 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Merge pull request #737 from goodfeli/fix_grad_name

fixed gradient name bug
...@@ -179,8 +179,17 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True): ...@@ -179,8 +179,17 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
_logger.warning('%s.grad returned a different type (%s) ' _logger.warning('%s.grad returned a different type (%s) '
'for input %i of type (%s)', 'for input %i of type (%s)',
node.op, g_r_type, ii, r_type) node.op, g_r_type, ii, r_type)
if g_r and len(sources) == 1 and sources[0][0].name and r.name: #The following name assignment code is broken
g_r.name = "(d%s/d%s)" % (sources[0][0].name, r.name) #for example, when you call
#f = T.dot(x,T.dot(A,x))
#f.name = 'f'
#T.grad( f, x)
#the result has no name, and is composed of
# A x + A^T x
# with both terms named "(df/dx)"
#if g_r is not None and len(sources) == 1 and sources[0][0].name \
# and r.name:
# g_r.name = "(d%s/d%s)" % (sources[0][0].name, r.name)
if g_r is not None: if g_r is not None:
assert r is not None assert r is not None
if r in gmap: if r in gmap:
...@@ -519,6 +528,10 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False, ...@@ -519,6 +528,10 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
"'ignore', 'warn' and 'raise'.") "'ignore', 'warn' and 'raise'.")
ret.append(p.zeros_like()) ret.append(p.zeros_like())
if cost.name is not None and p.name is not None \
and ret[-1].name is None:
ret[-1].name = '(d%s/d%s)' % (cost.name, p.name)
return format_as(using_list, using_tuple, ret) return format_as(using_list, using_tuple, ret)
......
...@@ -261,5 +261,13 @@ def test_unimplemented_grad(): ...@@ -261,5 +261,13 @@ def test_unimplemented_grad():
except NotImplementedError: except NotImplementedError:
pass pass
def test_grad_name():
    """Check that the gradient of a named cost w.r.t. a named input is named.

    Builds the scalar expression ``f = x . (A x)``, names it ``'f'``, and
    asserts that the variable returned by ``theano.tensor.grad`` for ``x``
    is named ``'(df/dx)'``.
    """
    tensor = theano.tensor
    mat = tensor.matrix('A')
    vec = tensor.vector('x')

    # x appears twice in the expression, so the gradient w.r.t. x is the
    # sum of two contributions; the name must be on the final result.
    cost = tensor.dot(vec, tensor.dot(mat, vec))
    cost.name = 'f'

    gradient = tensor.grad(cost, vec)
    assert gradient.name == '(df/dx)'
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论