Commit 9118d7f0 authored by Ian Goodfellow

fix non-determinism bug in gradient.grad by changing dict to OrderedDict

Parent 64d1ae90
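Why the change matters, sketched outside Theano: grad() adds up a variable's gradient terms by iterating over a dict keyed by apply nodes, and on the Python 2 interpreters Theano targets a plain dict keyed by such objects iterates in an order derived from their memory addresses, so the order of the additions built into the gradient graph can change between runs. An OrderedDict iterates in insertion order, which follows the graph traversal and is reproducible. The snippet below is a minimal illustration only; the Apply stand-in class and the accumulate helper are invented for the example and are not Theano APIs.

    from collections import OrderedDict

    class Apply(object):
        """Stand-in for theano.gof.Apply; hashed by identity like the real class."""
        pass

    def accumulate(app_to_idx, term_of):
        # Mirrors the accumulation loop in access_grad_cache: walk the apply
        # nodes that consume a variable and collect their gradient terms in
        # iteration order. This list order is the order of the additions that
        # end up in the gradient graph.
        return [term_of[app] for app in app_to_idx]

    apps = [Apply() for _ in range(4)]
    term_of = dict((app, i) for i, app in enumerate(apps))

    plain = dict((app, []) for app in apps)           # Python 2: key order tied to id() hashing
    ordered = OrderedDict((app, []) for app in apps)  # key order = insertion order

    print(accumulate(ordered, term_of))  # always [0, 1, 2, 3]
    print(accumulate(plain, term_of))    # on Python 2 this may vary between runs;
                                         # CPython 3.7+ dicts preserve insertion order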
@@ -20,6 +20,7 @@ import theano
from itertools import izip
from theano import gof
from theano.gof import Variable
+from theano.gof.python25 import OrderedDict
from theano.gof.python25 import all
import theano.gof.utils
from theano.gof.null_type import NullType
@@ -211,7 +212,7 @@ def Rop(f, wrt, eval_points):
# Tensor, Sparse and CudaNdArray have the ndim attribute
pass
-seen_nodes = {}
+seen_nodes = OrderedDict()
def _traverse(node):
""" TODO: writeme """
@@ -432,14 +433,14 @@ def grad(cost, wrt, consider_constant=None,
if known_grads is not None:
outputs.extend(known_grads.keys())
-var_to_node_to_idx = _populate_var_to_node_to_idx(
+var_to_app_to_idx = _populate_var_to_app_to_idx(
outputs, wrt, consider_constant)
# build a dict mapping var to the gradient of cost with respect to var
-grad_dict = {}
+grad_dict = OrderedDict()
if known_grads is None:
-known_grads = {}
+known_grads = OrderedDict()
# The gradient of the cost is 1 unless specified otherwise by known_grads.
if cost is not None:
@@ -501,10 +502,10 @@ def grad(cost, wrt, consider_constant=None,
# variables that do not influence the cost have zero gradient.
# if wrt is such a variable, populate the grad_dict with this info
-# so that wrt not being in var_to_node_to_idx won't cause an error below
+# so that wrt not being in var_to_app_to_idx won't cause an error below
# according to the flag, possibly raise an error if wrt is disconnected
for elem in wrt:
-if elem not in var_to_node_to_idx and elem is not cost \
+if elem not in var_to_app_to_idx and elem is not cost \
and elem not in grad_dict:
handle_disconnected(elem)
grad_dict[elem] = DisconnectedType()()
@@ -521,8 +522,8 @@ def grad(cost, wrt, consider_constant=None,
if hasattr(g.type, 'dtype'):
assert g.type.dtype in tensor.float_dtypes
-rval = _populate_grad_dict(var_to_node_to_idx,
+rval = _populate_grad_dict(var_to_app_to_idx,
    grad_dict, wrt, cost_name)
for i in xrange(len(rval)):
if isinstance(rval[i].type, DisconnectedType):
@@ -579,7 +580,7 @@ def _node_to_pattern(node):
return connection_pattern
-def _populate_var_to_node_to_idx(outputs, wrt, consider_constant):
+def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
"""
Helper function for grad function.
@@ -638,7 +639,7 @@ def _populate_var_to_node_to_idx(outputs, wrt, consider_constant):
# var_to_app_to_idx[var][node] = [i,j] means node has
# var as input at positions i and j
-var_to_app_to_idx = {}
+var_to_app_to_idx = OrderedDict()
# Set of variables that have been added to their true parents
# ('true' here means that the elements of the variable are a function
@@ -676,7 +677,13 @@ def _populate_var_to_node_to_idx(outputs, wrt, consider_constant):
continue
if ipt not in var_to_app_to_idx:
-var_to_app_to_idx[ipt] = {}
+# This object here *must* be an OrderedDict, because
+# we iterate over its keys when adding up the terms of
+# the gradient on ipt. If it is a regular dict, the grad
+# method will return something that is analytically correct,
+# but whose order of doing additions depends on the memory
+# location of the apply nodes.
+var_to_app_to_idx[ipt] = OrderedDict()
app_to_idx = var_to_app_to_idx[ipt]
if app not in app_to_idx:
app_to_idx[app] = []
@@ -731,12 +738,12 @@ class DisconnectedInputError(ValueError):
disconnected_inputs='raise'.
"""
-def _populate_grad_dict(var_to_node_to_idx,
+def _populate_grad_dict(var_to_app_to_idx,
    grad_dict, wrt, cost_name=None):
"""
Helper function for grad function.
-var_to_node_to_idx: a dictionary mapping a variable to
+var_to_app_to_idx: a dictionary mapping a variable to
a second dictionary.
the second dictionary maps apply nodes acting on
this variable to the variable's index in the apply
@@ -761,7 +768,7 @@ def _populate_grad_dict(var_to_node_to_idx,
"""
# build a dict mapping node to the terms node contributes to each of
# its inputs' gradients
-term_dict = {}
+term_dict = OrderedDict()
def access_term_cache(node):
""" Populates term_dict[node] and returns it """
@@ -1001,15 +1008,17 @@ def _populate_grad_dict(var_to_node_to_idx,
#cache the result
term_dict[node] = input_grads
return term_dict[node]
# populate grad_dict[var] and return it
def access_grad_cache(var):
if var not in grad_dict:
# If var is not in grad_dict already, we must compute it
-if var in var_to_node_to_idx:
+if var in var_to_app_to_idx:
terms = []
-node_to_idx = var_to_node_to_idx[var]
+node_to_idx = var_to_app_to_idx[var]
for node in node_to_idx:
for idx in node_to_idx[node]:
......
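For readers unfamiliar with the renamed helper, a hedged sketch of the structure _populate_var_to_app_to_idx builds: a two-level mapping from a variable to each apply node that consumes it, and from that node to the list of input positions at which it consumes the variable. After this commit both levels are OrderedDicts, so later walks over them are deterministic. The record_input helper and the string stand-ins for variables and apply nodes below are invented for the example and are not Theano API.

    from collections import OrderedDict

    var_to_app_to_idx = OrderedDict()

    def record_input(var, app, position):
        # Same bookkeeping as the inner loop shown in the diff above.
        if var not in var_to_app_to_idx:
            # Must be an OrderedDict: its keys are iterated when summing
            # the gradient terms contributed to `var`.
            var_to_app_to_idx[var] = OrderedDict()
        app_to_idx = var_to_app_to_idx[var]
        if app not in app_to_idx:
            app_to_idx[app] = []
        app_to_idx[app].append(position)

    record_input('x', 'mul_apply', 0)
    record_input('x', 'mul_apply', 1)  # x consumed twice, e.g. x * x
    record_input('x', 'add_apply', 0)

    print(var_to_app_to_idx['x'])
    # OrderedDict([('mul_apply', [0, 1]), ('add_apply', [0])])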