提交 0d06246c authored 作者: Frederic's avatar Frederic

Check the view_map for all type of variable.

finish fix gh-842
上级 ddcd96d0
......@@ -780,46 +780,40 @@ def _check_viewmap(node, storage_map):
outstorage = storage_map[onode][0]
instorage_id = [id(storage_map[i][0]) for i in node.inputs]
# TODO: investigate ways in which other Types may be aliased
# TODO: consider adding a function to Type to detect aliasing
danger_flag = id(outstorage) in instorage_id or\
(type(outstorage)==numpy.ndarray and
outstorage.flags['OWNDATA']==False)
# first find out which input it aliases
view_map = getattr(node.op, 'view_map', {})
destroy_map = getattr(node.op, 'destroy_map', {})
if danger_flag:
# first find out which input it aliases
view_map = getattr(node.op, 'view_map', {})
destroy_map = getattr(node.op, 'destroy_map', {})
# In theory, theano's view_map only allows for 1 output to
# alias 1 input. Checking for multiple aliases just in
# case...
# In theory, theano's view_map only allows for 1 output to
# alias 1 input. Checking for multiple aliases just in
# case...
for ii, inode in enumerate(node.inputs):
for ii, inode in enumerate(node.inputs):
if _may_share_memory(outstorage, storage_map[inode][0]):
if _may_share_memory(outstorage, storage_map[inode][0]):
nodeid = id(inode)
bad_alias[nodeid] = ii
nodeid = id(inode)
bad_alias[nodeid] = ii
# check that the aliasing was declared in [view|destroy]_map
if ([ii] == view_map.get(oi, None) or
[ii] == destroy_map.get(oi, None)):
# check that the aliasing was declared in [view|destroy]_map
if ([ii]==view_map.get(oi,None) or\
[ii]==destroy_map.get(oi,None)):
good_alias[nodeid] = bad_alias.pop(nodeid)
good_alias[nodeid] = bad_alias.pop(nodeid)
#TODO: make sure this is correct
# According to OB, duplicate inputs are rejected on build graph time
# if they cause problems. So if they are here it should be ok.
for key, val in good_alias.iteritems():
bad_alias.pop(key, None)
if bad_alias:
raise BadViewMap(node, oi, outstorage, bad_alias.values())
#TODO: make sure this is correct
# According to OB, duplicate inputs are rejected on build graph time
# if they cause problems. So if they are here it should be ok.
for key, val in good_alias.iteritems():
bad_alias.pop(key, None)
if bad_alias:
raise BadViewMap(node, oi, outstorage, bad_alias.values())
#if its not aliased to input, check output->output aliasing
if not good_alias and _is_used_in_graph(onode):
for other_oi, other_onode in enumerate(node.outputs):
if other_oi == oi: continue
if other_oi == oi:
continue
other_storage = storage_map[other_onode][0]
# check to see if we share memory with this other output
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论