提交 1957e938 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

got rid of an extra 0 in a sum that was making pattern matching for unit

tests harder
上级 7ca064a9
...@@ -22,6 +22,7 @@ from theano.gof.python25 import all ...@@ -22,6 +22,7 @@ from theano.gof.python25 import all
import theano.gof.utils import theano.gof.utils
tensor = None tensor = None
from theano.gof.nan_type import NaNType from theano.gof.nan_type import NaNType
from theano.printing import min_informative_str
_msg_retType = 'op.grad(...) returned a non-list' _msg_retType = 'op.grad(...) returned a non-list'
_msg_badlen = 'op.grad(...) returned wrong number of gradients' _msg_badlen = 'op.grad(...) returned wrong number of gradients'
...@@ -556,6 +557,16 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore ...@@ -556,6 +557,16 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore
term_dict[node][i] = tensor.zeros_like(node.inputs[i]) term_dict[node][i] = tensor.zeros_like(node.inputs[i])
return term_dict[node] return term_dict[node]
def nonempty_sum(iterable):
    """Sum the elements of a non-empty iterable using ``+`` only.

    The built-in ``sum`` starts its accumulation at the integer 0, which
    injects an extraneous TensorConstant(0) node into the symbolic graph
    and makes pattern matching in the unit tests harder.  We exploit the
    caller's guarantee that there is at least one element and seed the
    accumulation with the first element instead.

    Parameters
    ----------
    iterable : any iterable with at least one element
        Elements must support ``+`` with each other (e.g. theano
        variables, numbers, lists).

    Returns
    -------
    The left-to-right ``+``-fold of the elements:
    ``((e0 + e1) + e2) + ...``.

    Raises
    ------
    ValueError
        If `iterable` yields no elements.
    """
    elements = iter(iterable)
    try:
        rval = next(elements)
    except StopIteration:
        raise ValueError("nonempty_sum requires at least one element")
    for elem in elements:
        rval = rval + elem
    return rval
#populate grad_dict[var] and return it #populate grad_dict[var] and return it
def access_grad_cache(var): def access_grad_cache(var):
if var not in grad_dict: if var not in grad_dict:
...@@ -564,7 +575,7 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore ...@@ -564,7 +575,7 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore
for child in var._children.keys(): for child in var._children.keys():
idx = var._children[child] idx = var._children[child]
terms.append( access_term_cache(child)[idx]) terms.append( access_term_cache(child)[idx])
grad_dict[var] = sum(terms) grad_dict[var] = nonempty_sum(terms)
if cost.name is not None and var.name is not None: if cost.name is not None and var.name is not None:
grad_dict[var].name = '(d%s/d%s)' % (cost.name, var.name) grad_dict[var].name = '(d%s/d%s)' % (cost.name, var.name)
else: else:
......
...@@ -4377,6 +4377,8 @@ class test_grad(unittest.TestCase): ...@@ -4377,6 +4377,8 @@ class test_grad(unittest.TestCase):
o = test_grad.O() o = test_grad.O()
a1 = o.make_node() a1 = o.make_node()
g0,g1 = grad(a1.outputs[0], a1.inputs) g0,g1 = grad(a1.outputs[0], a1.inputs)
g0.name = None
print theano.printing.min_informative_str(g0)
self.assertTrue(o.gval0 is g0) self.assertTrue(o.gval0 is g0)
self.assertTrue(o.gval1 is g1) self.assertTrue(o.gval1 is g1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论