提交 aa055460 authored 作者: Frederic's avatar Frederic

[CRASH], fix some crash due to DebugMode doing useless check and make

TypeList.may_share_memory more resilent to different input type.
上级 0f89dd53
......@@ -666,14 +666,26 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
return fgraph, map(SymbolicOutput, updates), equivalence_tracker
class DataDestroyed():
# this is a singleton class We put it in the storage_map when the
# variable value was destroyed to prevent reusing bad value for
# it.
pass
data_destroyed = DataDestroyed()
def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes,
clobber_dr_vals=True,
perform=None, warn_input_not_reused=True):
"""
Raise BadDestroyMap if necessary, update dr_vals
"""Raise BadDestroyMap if necessary, update dr_vals
Returns a list of output variables that actually worked inplace
(their value is aliased to the value of at least one input).
It modify the storage_map to remove node.inputs variable that have
been destroyed.
"""
destroyed_idx_list = []
destroy_map = getattr(node.op, 'destroy_map', {})
......@@ -736,7 +748,8 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes,
raise Exception('failure in topological ordering')
if clobber_dr_vals:
dr_vals[r] = (storage_map[r][0], node) #no copy, this is the last use of this variable
storage_map[r][0] = None #make sure that dr_vals[r] doens't get used again
# make sure that dr_vals[r] doens't get used again
storage_map[r][0] = data_destroyed
else:
raise BadDestroyMap(node, r_idx, r_vals[r],
storage_map[r][0], perform)
......@@ -766,8 +779,15 @@ def _check_viewmap(node, storage_map):
# case...
for ii, inode in enumerate(node.inputs):
in_storage = storage_map[inode][0]
if in_storage is data_destroyed:
# If the input have been destroyed, it can't be a
# view. So no need to check. Also, we don't have the
# original value, we we wouldn't be able to do this
# useless check.
continue
if hasattr(inode.type, 'may_share_memory') and\
inode.type.may_share_memory(outstorage, storage_map[inode][0]):
inode.type.may_share_memory(outstorage, in_storage):
nodeid = id(inode)
bad_alias[nodeid] = ii
......
......@@ -79,6 +79,13 @@ class TypedListType(gof.Type):
def may_share_memory(self, a, b):
if a is b:
return True
# As a list contain other element, if a or b isn't a list, we
# still need to check if that element is contained in the
# other list.
if not isinstance(a, list):
a = [a]
if not isinstance(b, list):
b = [b]
for idx1 in range(len(a)):
for idx2 in range(len(b)):
if self.ttype.may_share_memory(a[idx1], b[idx2]):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论