提交 248c5921 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Alternative fix for the may_share_memory DebugMode problem.

上级 eebce83c
...@@ -744,7 +744,8 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes, ...@@ -744,7 +744,8 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes,
var = node.outputs[oo] var = node.outputs[oo]
out_var = storage_map[var][0] out_var = storage_map[var][0]
in_var = storage_map[node.inputs[ii[0]]][0] in_var = storage_map[node.inputs[ii[0]]][0]
if var.type.may_share_memory(out_var, in_var): if (hasattr(var.type, 'may_share_memory') and
var.type.may_share_memory(out_var, in_var)):
actually_inplace_outputs.append(node.outputs[oo]) actually_inplace_outputs.append(node.outputs[oo])
if warn_input_not_reused and destroyed_res_list: if warn_input_not_reused and destroyed_res_list:
...@@ -762,7 +763,8 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes, ...@@ -762,7 +763,8 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes,
var = node.outputs[oo] var = node.outputs[oo]
out_var = storage_map[var][0] out_var = storage_map[var][0]
in_var = storage_map[node.inputs[ii[0]]][0] in_var = storage_map[node.inputs[ii[0]]][0]
may_share = var.type.may_share_memory(out_var, in_var) may_share = (hasattr(var.type, 'may_share_memory') and
var.type.may_share_memory(out_var, in_var))
if may_share: if may_share:
actually_inplace_outputs.append(node.outputs[oo]) actually_inplace_outputs.append(node.outputs[oo])
...@@ -831,8 +833,8 @@ def _check_viewmap(node, storage_map): ...@@ -831,8 +833,8 @@ def _check_viewmap(node, storage_map):
# original value, we we wouldn't be able to do this # original value, we we wouldn't be able to do this
# useless check. # useless check.
continue continue
if hasattr(inode.type, 'may_share_memory') and\ if (hasattr(inode.type, 'may_share_memory') and
inode.type.may_share_memory(outstorage, in_storage): inode.type.may_share_memory(outstorage, in_storage)):
nodeid = id(inode) nodeid = id(inode)
bad_alias[nodeid] = ii bad_alias[nodeid] = ii
...@@ -861,6 +863,7 @@ def _check_viewmap(node, storage_map): ...@@ -861,6 +863,7 @@ def _check_viewmap(node, storage_map):
# check to see if we share memory with this other output # check to see if we share memory with this other output
# this is not a problem if the node is not actually used # this is not a problem if the node is not actually used
if (_is_used_in_graph(other_onode) and if (_is_used_in_graph(other_onode) and
hasattr(other_onode.type, 'may_share_memory') and
other_onode.type.may_share_memory(outstorage, other_onode.type.may_share_memory(outstorage,
other_storage)): other_storage)):
raise BadViewMap(node, oi, outstorage, raise BadViewMap(node, oi, outstorage,
......
...@@ -497,9 +497,6 @@ class Generic(SingletonType): ...@@ -497,9 +497,6 @@ class Generic(SingletonType):
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (1,)
def may_share_memory(self, other):
return self is other
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论