Commit 4b1a5597 authored by Ian Goodfellow

wrote a new grad method that only computes the grads the user requested

Parent ce0a8359
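As a minimal usage sketch (not part of the commit; it assumes grad is exposed as theano.tensor.grad, as in released Theano), the rewritten method is called like the old one, but it only traverses the part of the graph needed for the requested gradients:

import theano.tensor as T

x = T.dscalar('x')
y = T.dscalar('y')
cost = x ** 2 + y

# a single wrt variable returns a single gradient expression
gx = T.grad(cost, x)

# a list of wrt variables returns a list of gradients
gx, gy = T.grad(cost, [x, y])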
@@ -20,6 +20,7 @@ from theano import gof
from theano.gof import Variable
from theano.gof.python25 import all
import theano.gof.utils
tensor = None
_msg_retType = 'op.grad(...) returned a non-list'
_msg_badlen = 'op.grad(...) returned wrong number of gradients'
@@ -496,7 +497,138 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
# Gradient
#########################
def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type='ignored',
         disconnected_inputs='raise'):
    global tensor
    if tensor is None:
        from theano import tensor

    if consider_constant is None:
        consider_constant = []
    else:
        #error checking on consider_constant: verify that it is a collection
        # of theano variables
        # this is important: if someone accidentally passes a nested data
        # structure with theano variables at the leaves, only the root will
        # be properly considered constant
        if not hasattr(consider_constant, '__iter__'):
            raise TypeError('consider_constant must be an iterable collection,'
                            ' got ' + str(type(consider_constant)))
        for elem in consider_constant:
            if not isinstance(elem, gof.Variable):
                raise TypeError('Elements of consider_constant must be '
                                'variables, but got ' + str(type(elem)))

    using_list = isinstance(wrt, list)
    using_tuple = isinstance(wrt, tuple)
    if not using_list and not using_tuple:
        wrt = [wrt]
    #set of variables that have had a _children dict attached to them
    marked = set([])
    #set of variables that have already been visited by account_for
    accounted_for = set([])

    #use a try/finally to make sure we don't leave any marks
    #on the variables
    try:
        #mark the variables in the relevant subgraph with
        #a dictionary called _children
        #var._children[node] gives the indices of var in node.inputs
        def account_for(var):
            if var in accounted_for:
                return
            accounted_for.add(var)
            if var.owner is not None:
                node = var.owner
                for i, ipt in enumerate(node.inputs):
                    if not hasattr(ipt, '_children'):
                        marked.add(ipt)
                        ipt._children = {}
                    #a variable may occur at several positions in
                    #node.inputs, so record every index where it appears
                    if node not in ipt._children:
                        ipt._children[node] = [i]
                    else:
                        ipt._children[node].append(i)
                    account_for(ipt)
        account_for(cost)
        #build a dict mapping var to the gradient of cost with respect to var
        grad_dict = {}

        #by default, the gradient of the cost is 1
        if g_cost is None:
            g_cost = tensor.ones_like(cost)
        grad_dict[cost] = g_cost

        #the gradient of the constants is 0
        for const in consider_constant:
            grad_dict[const] = tensor.zeros_like(const)

        #variables that do not influence the cost have zero gradient.
        #if wrt is such a variable, populate the grad_dict with this info
        #so that wrt not having _children won't cause an error below
        #according to the flag, possibly raise an error if wrt is disconnected
        for elem in wrt:
            if elem not in marked:
                message = ("grad method was asked to compute the gradient "
                           "with respect to a variable that is not part of "
                           "the computational graph of the cost, or is used "
                           "only by a non-differentiable operator: %s" % elem)
                if disconnected_inputs == 'ignore':
                    pass
                elif disconnected_inputs == 'warn':
                    warnings.warn(message, stacklevel=1)
                elif disconnected_inputs == 'raise':
                    raise ValueError(message)
                else:
                    raise ValueError("Invalid value for keyword "
                                     "'disconnected_inputs', valid values are "
                                     "'ignore', 'warn' and 'raise'.")
                grad_dict[elem] = elem.zeros_like()
        #build a dict mapping node to the terms node contributes to each of
        #its inputs' gradients
        term_dict = {}

        #populate term_dict[node] and return it
        def access_term_cache(node):
            if node not in term_dict:
                term_dict[node] = node.op.grad(node.inputs,
                        [access_grad_cache(var) for var in node.outputs])
                #ops return None for the gradient on an input they cannot
                #or need not differentiate; treat such terms as zero
                for i in xrange(len(term_dict[node])):
                    if term_dict[node][i] is None:
                        term_dict[node][i] = tensor.zeros_like(node.inputs[i])
            return term_dict[node]
        #populate grad_dict[var] and return it
        def access_grad_cache(var):
            if var not in grad_dict:
                if hasattr(var, '_children'):
                    #sum the terms contributed by every consumer of var
                    terms = []
                    for child in var._children:
                        for idx in var._children[child]:
                            terms.append(access_term_cache(child)[idx])
                    grad_dict[var] = sum(terms)
                    if cost.name is not None and var.name is not None:
                        grad_dict[var].name = '(d%s/d%s)' % (cost.name,
                                                             var.name)
                else:
                    #this variable is not connected to the cost in the
                    #computational graph so the gradient on it is zero
                    grad_dict[var] = tensor.zeros_like(var)
            return grad_dict[var]
        rval = [access_grad_cache(elem) for elem in wrt]
    finally:
        #take the marks out
        for node in marked:
            del node._children

    if using_tuple:
        rval = tuple(rval)
    elif not using_list:
        rval, = rval
    return rval

def grad_wrong(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
               disconnected_inputs='raise'):
    """
    :type cost: Scalar (0-dimensional) Variable.
@@ -530,6 +662,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
        ``theano.gradient.grad_sources_inputs``.
    """
    if consider_constant is None:
        consider_constant = []
    else:
...
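For reference, a hedged sketch (variable names are illustrative, not from the commit) of what the disconnected_inputs flag in the new method does when a wrt variable never reaches the cost:

import theano.tensor as T

x = T.dscalar('x')
z = T.dscalar('z')      # z is never used in the cost
cost = x ** 2

# per the new code above: 'raise' (the default) raises ValueError for
# the disconnected z, 'warn' issues a warning, and 'ignore' stays
# silent; in the latter two cases the gradient returned is z.zeros_like()
gz = T.grad(cost, z, disconnected_inputs='ignore')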