提交 33de45ca authored 作者: Frederic's avatar Frederic

speed up debugmode. Don't use intermediate call.

Also, make less call to may_share_memory.
上级 848848fc
...@@ -685,9 +685,10 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes, ...@@ -685,9 +685,10 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes,
actually_inplace_outputs = [] actually_inplace_outputs = []
dmap = getattr(node.op, 'destroy_map', {}) dmap = getattr(node.op, 'destroy_map', {})
for oo, ii in dmap.iteritems(): for oo, ii in dmap.iteritems():
out_var = storage_map[node.outputs[oo]][0] var = node.outputs[oo]
out_var = storage_map[var][0]
in_var = storage_map[node.inputs[ii[0]]][0] in_var = storage_map[node.inputs[ii[0]]][0]
if _may_share_memory(out_var, in_var): if var.may_share_memory(out_var, in_var):
actually_inplace_outputs.append(node.outputs[oo]) actually_inplace_outputs.append(node.outputs[oo])
if warn_input_not_reused and destroyed_res_list: if warn_input_not_reused and destroyed_res_list:
...@@ -702,9 +703,11 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes, ...@@ -702,9 +703,11 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes,
vmap = getattr(node.op, 'view_map', {}) vmap = getattr(node.op, 'view_map', {})
for oo, ii in vmap.iteritems(): for oo, ii in vmap.iteritems():
out_var = storage_map[node.outputs[oo]][0] var = node.outputs[oo]
out_var = storage_map[var][0]
in_var = storage_map[node.inputs[ii[0]]][0] in_var = storage_map[node.inputs[ii[0]]][0]
if _may_share_memory(out_var, in_var): may_share = var.may_share_memory(out_var, in_var)
if may_share:
actually_inplace_outputs.append(node.outputs[oo]) actually_inplace_outputs.append(node.outputs[oo])
if warn_input_not_reused: if warn_input_not_reused:
...@@ -717,7 +720,7 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes, ...@@ -717,7 +720,7 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes,
if isinstance(node.op, OutputGuard): if isinstance(node.op, OutputGuard):
# This class is not in the final graph. # This class is not in the final graph.
continue continue
if not _may_share_memory(out_var, in_var): if not may_share:
_logger.warning("Optimization Warning: input idx %d marked " _logger.warning("Optimization Warning: input idx %d marked "
"as viewed but new memory allocated by node '%s'", "as viewed but new memory allocated by node '%s'",
ii[0], str(node)) ii[0], str(node))
...@@ -766,7 +769,7 @@ def _check_viewmap(node, storage_map): ...@@ -766,7 +769,7 @@ def _check_viewmap(node, storage_map):
for ii, inode in enumerate(node.inputs): for ii, inode in enumerate(node.inputs):
if _may_share_memory(outstorage, storage_map[inode][0]): if ii.may_share_memory(outstorage, storage_map[inode][0]):
nodeid = id(inode) nodeid = id(inode)
bad_alias[nodeid] = ii bad_alias[nodeid] = ii
...@@ -794,17 +797,12 @@ def _check_viewmap(node, storage_map): ...@@ -794,17 +797,12 @@ def _check_viewmap(node, storage_map):
other_storage = storage_map[other_onode][0] other_storage = storage_map[other_onode][0]
# check to see if we share memory with this other output # check to see if we share memory with this other output
# this is not a problem if the node is not actually used # this is not a problem if the node is not actually used
if _is_used_in_graph(other_onode) and \ if (_is_used_in_graph(other_onode) and
_may_share_memory(outstorage, other_storage): other_onode.may_share_memory(outstorage, other_storage)):
raise BadViewMap(node, oi, outstorage, raise BadViewMap(node, oi, outstorage,
out_alias_idx=other_oi) out_alias_idx=other_oi)
def _may_share_memory(a, b):
from theano.misc.may_share_memory import may_share_memory
return may_share_memory(a, b, False)
def _is_function_output(node): def _is_function_output(node):
""" """
Returns True if the node in question is the a final output of the graph Returns True if the node in question is the a final output of the graph
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论