* Recoded portions of _check_view_map (is much simpler now)

* _check_view_map will not raise an error if multiple inputs have the same id, and one of the outputs has a view_map to one of those two inputs
上级 79460528
...@@ -286,41 +286,48 @@ def _check_viewmap(node, storage_map): ...@@ -286,41 +286,48 @@ def _check_viewmap(node, storage_map):
""" """
for oi, onode in enumerate(node.outputs): for oi, onode in enumerate(node.outputs):
input_alias = None
good_alias, bad_alias = {}, {}
outstorage = storage_map[onode][0] outstorage = storage_map[onode][0]
instorage_id = [id(storage_map[i][0]) for i in node.inputs] instorage_id = [id(storage_map[i][0]) for i in node.inputs]
# TODO: investigate ways in which other Types may be aliased # TODO: investigate ways in which other Types may be aliased
# TODO: consider adding a function to Type to detect aliasing # TODO: consider adding a function to Type to detect aliasing
danger_flag = id(outstorage) in instorage_id or\ danger_flag = id(outstorage) in instorage_id or\
(type(outstorage)==numpy.ndarray and (type(outstorage)==numpy.ndarray and
outstorage.flags['OWNDATA']==False) outstorage.flags['OWNDATA']==False)
if danger_flag: if danger_flag:
# first find out which input it aliases # first find out which input it aliases
view_map = getattr(node.op, 'view_map', {})
destroy_map = getattr(node.op, 'destroy_map', {})
# In theory, theano's view_map only allows for 1 output to alias 1 input # In theory, theano's view_map only allows for 1 output to alias 1 input
# Checking for multiple aliases just in case... # Checking for multiple aliases just in case...
alias = {}
for ii, inode in enumerate(node.inputs): for ii, inode in enumerate(node.inputs):
if _may_share_memory(outstorage, storage_map[inode][0]): if _may_share_memory(outstorage, storage_map[inode][0]):
alias[ii] = (ii,inode)
# if its aliased but its declared in the view/destroy map = OK nodeid = id(inode)
viewmapped = False bad_alias[nodeid] = ii
view_map = getattr(node.op, 'view_map', {})
destroy_map = getattr(node.op, 'destroy_map', {}) # check that the aliasing was declared in [view|destroy]_map
for key,val in view_map.items()+destroy_map.items(): if ([ii]==view_map.get(oi,None) or\
val = val[0] # view_map stores a list with single-entries [ii]==destroy_map.get(oi,None)):
if key==oi and val in alias.keys():
# pfeew, its viewmapped. we're good good_alias[nodeid] = bad_alias.pop(nodeid)
input_alias = alias.pop(val)
#TODO: make sure this is correct
# if there's anything left in alias, there's a problem # According to OB, duplicate inputs are rejected on build graph time
if len(alias): # if they cause problems. So if they are here it should be ok.
raise BadViewMap(node, oi, outstorage, alias.keys()) for key,val in good_alias.iteritems():
bad_alias.pop(key, None)
if bad_alias:
raise BadViewMap(node, oi, outstorage, bad_alias.values())
#need to check output->output aliasing as well #if its not aliased to input, check output->output aliasing
if not input_alias and _is_used_in_graph(onode): if not good_alias and _is_used_in_graph(onode):
for other_oi, other_onode in enumerate(node.outputs): for other_oi, other_onode in enumerate(node.outputs):
if other_oi==oi: continue if other_oi==oi: continue
......
...@@ -89,6 +89,9 @@ def uniq(seq): ...@@ -89,6 +89,9 @@ def uniq(seq):
return [x for i, x in enumerate(seq) if seq.index(x) == i] return [x for i, x in enumerate(seq) if seq.index(x) == i]
def difference(seq1, seq2): def difference(seq1, seq2):
"""
Returns all elements in seq1 which are not in seq2: i.e seq1\seq2
"""
try: try:
# try to use O(const * len(seq1)) algo # try to use O(const * len(seq1)) algo
if len(seq2) < 4: # I'm guessing this threshold -JB if len(seq2) < 4: # I'm guessing this threshold -JB
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论