* Recoded portions of _check_view_map (is much simpler now)

* _check_view_map will not raise an error if multiple inputs have the same id, and one of the outputs has a view_map to one of those two inputs
上级 79460528
......@@ -286,7 +286,8 @@ def _check_viewmap(node, storage_map):
"""
for oi, onode in enumerate(node.outputs):
input_alias = None
good_alias, bad_alias = {}, {}
outstorage = storage_map[onode][0]
instorage_id = [id(storage_map[i][0]) for i in node.inputs]
......@@ -295,32 +296,38 @@ def _check_viewmap(node, storage_map):
danger_flag = id(outstorage) in instorage_id or\
(type(outstorage)==numpy.ndarray and
outstorage.flags['OWNDATA']==False)
if danger_flag:
# first find out which input it aliases
view_map = getattr(node.op, 'view_map', {})
destroy_map = getattr(node.op, 'destroy_map', {})
# In theory, theano's view_map only allows for 1 output to alias 1 input
# Checking for multiple aliases just in case...
alias = {}
for ii, inode in enumerate(node.inputs):
if _may_share_memory(outstorage, storage_map[inode][0]):
alias[ii] = (ii,inode)
# if its aliased but its declared in the view/destroy map = OK
viewmapped = False
view_map = getattr(node.op, 'view_map', {})
destroy_map = getattr(node.op, 'destroy_map', {})
for key,val in view_map.items()+destroy_map.items():
val = val[0] # view_map stores a list with single-entries
if key==oi and val in alias.keys():
# pfeew, its viewmapped. we're good
input_alias = alias.pop(val)
# if there's anything left in alias, there's a problem
if len(alias):
raise BadViewMap(node, oi, outstorage, alias.keys())
#need to check output->output aliasing as well
if not input_alias and _is_used_in_graph(onode):
nodeid = id(inode)
bad_alias[nodeid] = ii
# check that the aliasing was declared in [view|destroy]_map
if ([ii]==view_map.get(oi,None) or\
[ii]==destroy_map.get(oi,None)):
good_alias[nodeid] = bad_alias.pop(nodeid)
#TODO: make sure this is correct
# According to OB, duplicate inputs are rejected on build graph time
# if they cause problems. So if they are here it should be ok.
for key,val in good_alias.iteritems():
bad_alias.pop(key, None)
if bad_alias:
raise BadViewMap(node, oi, outstorage, bad_alias.values())
#if its not aliased to input, check output->output aliasing
if not good_alias and _is_used_in_graph(onode):
for other_oi, other_onode in enumerate(node.outputs):
if other_oi==oi: continue
......
......@@ -89,6 +89,9 @@ def uniq(seq):
return [x for i, x in enumerate(seq) if seq.index(x) == i]
def difference(seq1, seq2):
"""
Returns all elements in seq1 which are not in seq2: i.e seq1\seq2
"""
try:
# try to use O(const * len(seq1)) algo
if len(seq2) < 4: # I'm guessing this threshold -JB
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论