Commit bf0e90ff authored by Ian Goodfellow

made gradient.grad use the upgraded connection_pattern method

Parent ad865eb3
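
The change below makes gradient.grad consult each Op's connection_pattern method. As the node_to_pattern helper in the diff checks, connection_pattern() is expected to return one list per input, each containing one boolean per output; True means the elements of that output are a function of the elements of that input, and ops that do not implement the method default to everything being connected. A minimal sketch of that contract in plain Python (the class name and the default_pattern helper are hypothetical, for illustration only, not part of the commit):

    # Sketch of the connection_pattern contract, matching the shape checks
    # that node_to_pattern performs in the diff below (illustrative only).
    class HypotheticalSecondLikeOp(object):
        """Toy stand-in for an op computing z = second(x, y) = y."""

        def connection_pattern(self):
            # one row per input (x, y), one column per output (z,);
            # z's elements never depend on x, but are copies of y
            return [[False], [True]]

    def default_pattern(n_inputs, n_outputs):
        # fallback used when an op has no connection_pattern method:
        # assume every output depends on every input
        return [[True] * n_outputs for _ in range(n_inputs)]

    print(HypotheticalSecondLikeOp().connection_pattern())  # [[False], [True]]
    print(default_pattern(2, 1))                            # [[True], [True]]
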
@@ -438,7 +438,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
     if not using_list and not using_tuple:
         wrt = [wrt]

-    var_to_node_to_idx = _populate_var_to_node_to_idx([cost])
+    var_to_node_to_idx = _populate_var_to_node_to_idx([cost], wrt)

     # build a dict mapping var to the gradient of cost with respect to var
     grad_dict = {}
@@ -492,27 +492,78 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
     return rval


-def _populate_var_to_node_to_idx(outputs):
+def _populate_var_to_node_to_idx(outputs, wrt):
     """
     Common code shared between grad and grad_sources_inputs

-    outputs: a list of nodes we want to take gradients of
+    outputs: a list of variables we want to take gradients of
+    wrt: a list of variables we want to take the gradient with
+        respect to.

     returns:
         var_to_node_to_idx: a dictionary mapping a variable to
             a second dictionary.
             the second dictionary maps apply nodes acting on
             this variable to the variable's index in the apply
             node's input list
+
+            This dictionary will only contain variables that
+            meet two criteria:
+                1) The elements of at least one output are a
+                    function of the elements of the variable
+                2) The elements of the variable are a function
+                    of the elements of at least one member of
+                    wrt
+
+            This set is exactly the set of variables that
+            connect the variables in wrt to the cost being
+            differentiated.
     """

+    def node_to_pattern(node):
+        # given an apply node, obtain its connection pattern
+        # this is just a wrapper around Op.connection_pattern
+        # that does type checking and supplies the default value
+        # if the method is not implemented
+        if hasattr(node.op, 'connection_pattern'):
+            connection_pattern = node.op.connection_pattern()
+
+            if not isinstance(connection_pattern, list):
+                raise TypeError("Op.connection_pattern should return " + \
+                        ("list of list of bool, but for Op=%s" % node.op) + \
+                        "got %s with type %s." % (connection_pattern,
+                            type(connection_pattern)))
+            if len(connection_pattern) != len(node.inputs):
+                raise ValueError('%s.connection_pattern should have %d' %
+                        (node.op, len(node.inputs)) + 'rows but has %d.' %
+                        len(connection_pattern))
+            for ii, output_pattern in enumerate(connection_pattern):
+                if not isinstance(output_pattern, list):
+                    raise TypeError('%s.connection_pattern should return' %
+                            node.op + ' a list of lists, but element %d' % ii \
+                            + 'is %s of type %s.' % (output_pattern,
+                                type(output_pattern)))
+        else:
+            connection_pattern = \
+                [[True for output in node.outputs]
+                        for ipt in node.inputs]
+
+        assert isinstance(connection_pattern, list)
+        assert len(connection_pattern) == len(node.inputs)
+        for ii in xrange(len(node.inputs)):
+            assert isinstance(connection_pattern[ii], list)
+            assert len(connection_pattern[ii]) == \
+                    len(node.outputs)
+
+        return connection_pattern
+
     # var_to_node_to_idx[var][node] = [i,j] means node has
     # var as input at positions i and j
     var_to_node_to_idx = {}

-    # set of variables or nodes that have been added to their parents
+    # set of variables or nodes that have been added to their true parents
+    # ('true' here means that the elements of the variable are a function
+    # of the elements of the parent, according to the op's
+    # connection_pattern)
     accounted_for = set([])

     def account_for(var):
         if var in accounted_for:
             return
@@ -521,7 +572,18 @@ def _populate_var_to_node_to_idx(outputs):
         node = var.owner
         if node not in accounted_for:
             accounted_for.add(node)
+
+            connection_pattern = node_to_pattern(node)
+
+            var_idx = node.outputs.index(var)
+
             for i, ipt in enumerate(node.inputs):
+                #don't process ipt if it is not a true
+                #parent of var
+                if not connection_pattern[i][var_idx]:
+                    continue
+
                 if ipt not in var_to_node_to_idx:
                     var_to_node_to_idx[ipt] = {}
                 node_to_idx = var_to_node_to_idx[ipt]
@@ -532,9 +594,38 @@ def _populate_var_to_node_to_idx(outputs):
                 idx.append(i)
                 account_for(ipt)

+    # add all variables that are true ancestors of the cost
     for output in outputs:
         account_for(output)

+    # determine which variables have elements of wrt as a true
+    # ancestor. Do this with an upward pass starting from wrt,
+    # following only true connections
+    visited = set([])
+
+    def visit(var):
+        if var in visited:
+            return
+        if var not in var_to_node_to_idx:
+            return
+        visited.add(var)
+        nodes = var_to_node_to_idx[var]
+        for node in nodes:
+            connection_pattern = node_to_pattern(node)
+            for idx in nodes[node]:
+                for ii, output in enumerate(node.outputs):
+                    if connection_pattern[idx][ii]:
+                        visit(output)
+
+    for elem in wrt:
+        visit(elem)
+
+    # Remove variables that don't have wrt as a true ancestor
+    orig_vars = list(var_to_node_to_idx.keys())
+
+    for var in orig_vars:
+        if var not in visited:
+            del var_to_node_to_idx[var]
+
     return var_to_node_to_idx
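
Taken together, the two passes above perform a bidirectional reachability check: account_for walks down from the cost recording only "true" parents, visit walks back up from wrt following the same connection patterns, and any variable not reached by both passes is dropped from var_to_node_to_idx. A toy, Theano-free sketch of that pruning idea on a single apply node (the encoding here is invented purely for illustration):

    # Toy model of one apply node computing z = second(x, y): rows index the
    # inputs (x, y), the single column indexes the output z.
    inputs = ['x', 'y']
    connection_pattern = [[False], [True]]

    # downward pass from the cost z: record only "true" parents
    true_parents = set(inp for i, inp in enumerate(inputs)
                       if connection_pattern[i][0])

    # upward pass from wrt: keep only the wrt variables that reach the cost
    wrt = ['x', 'y']
    kept = [v for v in wrt if v in true_parents]

    print(sorted(true_parents))  # ['y'] -- x never enters the dictionary
    print(kept)                  # ['y'] -- so the gradient on x is disconnected
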
@@ -664,11 +755,6 @@ def _populate_grad_dict(var_to_node_to_idx,
             for node in node_to_idx:
                 for idx in node_to_idx[node]:

-                    if hasattr(node.op, 'connection_pattern'):
-                        pattern = node.op.connection_pattern()
-
-                        if not pattern[idx]:
-                            continue
-
                     term = access_term_cache(node)[idx]

                     if not isinstance(term, gof.Variable):
@@ -686,9 +772,15 @@ def _populate_grad_dict(var_to_node_to_idx,
                         continue
                     terms.append(term)

-            #the next line is like sum(terms) but doesn't add an
-            #extraneous TensorConstant(0)
-            grad_dict[var] = reduce(lambda x,y: x+y, terms)
+            # Add up the terms to get the total gradient on this variable
+            if len(terms) > 0:
+                # the next line is like sum(terms) but doesn't add an
+                # extraneous TensorConstant(0)
+                grad_dict[var] = reduce(lambda x,y: x+y, terms)
+            else:
+                grad_dict[var] = DisconnectedType()()
+
             if cost_name is not None and var.name is not None:
                 grad_dict[var].name = '(d%s/d%s)' % (cost_name, var.name)
             else:
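
One detail worth noting in the summation above: Python's built-in sum starts from the integer 0, so using it on symbolic terms would splice an extra constant zero into the graph, which is exactly what the reduce call avoids; the new else branch then covers the case where no connected terms remain, returning a DisconnectedType variable instead of an empty sum. A quick plain-Python illustration of the sum-versus-reduce point (strings stand in for symbolic gradient terms here):

    from functools import reduce  # reduce is a builtin in Python 2, where this code lives

    # strings stand in for gradient terms; '+' "builds the graph" by concatenation
    terms = ['g0', 'g1']

    print(reduce(lambda x, y: x + y, terms))  # 'g0g1' -- just the terms
    # sum(terms) would start from the integer 0: with these strings it raises
    # TypeError, and with tensors it would fold an extraneous constant 0 into
    # the symbolic sum -- the TensorConstant(0) the comment refers to.
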
@@ -774,7 +866,7 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
     wrt = graph_inputs

-    var_to_node_to_idx = _populate_var_to_node_to_idx(outputs)
+    var_to_node_to_idx = _populate_var_to_node_to_idx(outputs, wrt)

     # build a dict mapping var to the gradient of cost with respect to var
     grad_dict = {}
......
@@ -1547,11 +1547,25 @@ class Second(BinaryScalarOp):
     def c_code(self, node, name, (x, y), (z, ), sub):
         return "%(z)s = %(y)s;" % locals()

+    def connection_pattern(self):
+        # x is never connected because its elements are never used
+        # y is connected because its elements are copied over
+        return [[False], [True]]
+
     def grad(self, (x, y), (gz, )):
         if y.type in continuous_types:
-            return None, gz
+            # x is disconnected because the elements of x are not used
+            return DisconnectedType()(), gz
         else:
-            return None, None
+            #when y is discrete, we assume the function can be extended
+            #to deal with real-valued inputs by rounding them to the
+            #nearest integer. f(x+eps) thus equals f(x) so the gradient
+            #is zero, not disconnected or undefined
+            return DisconnectedType()(), y.zeros_like()

 second = Second(transfer_type(1), name='second')
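
The comment in the discrete branch argues that second can be extended to real-valued y by rounding, so nudging y by a small epsilon leaves the output unchanged and the gradient is a genuine zero rather than a disconnected or undefined one. A quick numeric sanity check of that argument in plain Python (not Theano code; the helper name is made up):

    # second(x, y) returns y; extend it to real-valued y by rounding to the
    # nearest integer, as the grad comment suggests for discrete inputs.
    def second_rounded(x, y):
        return float(round(y))

    eps = 1e-4
    x, y = 3.0, 7.0   # y stands in for a discrete (integer-valued) input
    assert second_rounded(x, y + eps) == second_rounded(x, y)
    # the finite difference is exactly zero, matching y.zeros_like()
    print((second_rounded(x, y + eps) - second_rounded(x, y)) / eps)  # 0.0
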
......