提交 e3c61377 作者: Ian Goodfellow

Made consider_constant block gradient propagation through constants without setting their gradient to 0.
上级 87cd138e
......@@ -430,21 +430,6 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
if cost.ndim != 0:
raise TypeError("cost must be a scalar.")
if consider_constant is None:
consider_constant = []
else:
# Error checking on consider_constant: verify that it is a collection
# of theano variables.
# This is important: without this check, if someone accidentally passes
# a nested data structure with theano variables at the leaves, only the
# root would be properly considered constant.
if not hasattr(consider_constant, '__iter__'):
raise TypeError('consider_constant must be an iterable collection,'
' got ' + str(type(consider_constant)))
for elem in consider_constant:
if not isinstance(elem, gof.Variable):
raise TypeError('Elements of consider_constant must be '
'variables, but got ' + str(type(elem)))
if isinstance(wrt, set):
raise TypeError("wrt must not be a set. sets have no defined "
......@@ -461,7 +446,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
raise TypeError("Expected Variable, got " + str(elem) +
" of type "+str(type(elem)))
var_to_node_to_idx = _populate_var_to_node_to_idx([cost], wrt)
var_to_node_to_idx = _populate_var_to_node_to_idx([cost], wrt, consider_constant)
# build a dict mapping var to the gradient of cost with respect to var
grad_dict = {}
......@@ -592,7 +577,7 @@ def _node_to_pattern(node):
return connection_pattern
def _populate_var_to_node_to_idx(outputs, wrt):
def _populate_var_to_node_to_idx(outputs, wrt, consider_constant):
"""
Common code shared between grad and grad_sources_inputs
......@@ -601,6 +586,9 @@ def _populate_var_to_node_to_idx(outputs, wrt):
wrt: a list of variables we want to take the gradient with
respect to.
consider_constant: a list of variables not to backpropagate
through.
returns:
var_to_app_to_idx:
......@@ -622,8 +610,28 @@ def _populate_var_to_node_to_idx(outputs, wrt):
This set is exactly the set of variables that connect
the variables in wrt to the cost being differentiated.
(A variable in consider_constant is not a function of
anything)
"""
# Validate and format consider_constant
if consider_constant is None:
consider_constant = []
else:
# Error checking on consider_constant: verify that it is a collection
# of theano variables.
# This is important: without this check, if someone accidentally passes
# a nested data structure with theano variables at the leaves, only the
# root would be properly considered constant.
if not hasattr(consider_constant, '__iter__'):
raise TypeError('consider_constant must be an iterable collection,'
' got ' + str(type(consider_constant)))
for elem in consider_constant:
if not isinstance(elem, gof.Variable):
raise TypeError('Elements of consider_constant must be '
'variables, but got ' + str(type(elem)))
# var_to_app_to_idx[var][node] = [i,j] means node has
# var as input at positions i and j
var_to_app_to_idx = {}
......@@ -638,9 +646,17 @@ def _populate_var_to_node_to_idx(outputs, wrt):
accounted_for = set([])
def account_for(var):
# Don't visit the same variable twice
if var in accounted_for:
return
accounted_for.add(var)
# Constants are not a function of anything
if var in consider_constant:
return
# Recursively add the variables that this variable is
# a function of.
if var.owner is not None:
app = var.owner
......@@ -1066,7 +1082,7 @@ def grad_sources_inputs(sources, graph_inputs):
wrt = graph_inputs
var_to_node_to_idx = _populate_var_to_node_to_idx(outputs, wrt)
var_to_node_to_idx = _populate_var_to_node_to_idx(outputs, wrt, None)
# build a dict mapping var to the gradient of cost with respect to var
grad_dict = {}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论