提交 4faa1f3b authored 作者: Frederic Bastien's avatar Frederic Bastien

Added Theano flags DebugMode.warn_input_not_reused that default to True. This…

Added Theano flags DebugMode.warn_input_not_reused that default to True. This generate a warning in debug mode when an op was declared inplace but was not working inplace.
上级 7097093a
......@@ -45,6 +45,11 @@ AddConfigVar('DebugMode.check_strides',
"On difference: (0) - ignore, (1) warn, or (2) raise error"),
IntParam(1, lambda i: i in (0,1,2)))
AddConfigVar('DebugMode.warn_input_not_reused',
("Generate a warning when the destroy_map tell that an op work inplace, but the op did not reuse the input for its output."
),
BoolParam(True))
import logging
_logger=logging.getLogger("theano.compile.debugmode")
_logger.setLevel(logging.WARNING)
......@@ -500,7 +505,7 @@ def _optcheck_env(input_specs, output_specs, accept_inplace = False):
return env, map(SymbolicOutput, updates), equivalence_tracker
def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes, clobber_dr_vals=True,
perform=None):
perform=None, warn_input_not_reused=True):
"""Raise BadDestroyMap if necessary, update dr_vals"""
destroyed_idx_list = []
destroy_map = getattr(node.op, 'destroy_map', {})
......@@ -508,6 +513,12 @@ def _check_inputs(node, storage_map, r_vals, dr_vals, active_nodes, clobber_dr_v
destroyed_idx_list.extend(i_pos_list)
destroyed_res_list = [node.inputs[i] for i in destroyed_idx_list]
if warn_input_not_reused and destroyed_res_list:
dmap=getattr(node.op,'destroy_map',{})
for oo,ii in dmap.iteritems():
if storage_map[node.outputs[oo]][0] is not storage_map[node.inputs[ii[0]]][0]:
warning("input idx %d marked as destroyed was not changed for node '%s'"%(ii[0],str(node)))
for r_idx, r in enumerate(node.inputs):
if not r.type.values_eq(r_vals[r], storage_map[r][0]):
# some input node 'r' got changed by running the node
......@@ -1123,7 +1134,8 @@ class _Linker(gof.link.LocalLinker):
#if r in r_vals:
_check_inputs(node, storage_map, r_vals, dr_vals, active_order_set,
clobber_dr_vals=True, perform='py')
clobber_dr_vals=True, perform='py',
warn_input_not_reused=config.DebugMode.warn_input_not_reused)
_check_viewmap(node, storage_map)
......@@ -1181,7 +1193,8 @@ class _Linker(gof.link.LocalLinker):
self.maker.mode.require_matching_strides, node.op)
_check_inputs(node, storage_map, r_vals, dr_vals, active_order_set,
clobber_dr_vals=clobber, perform='c')
clobber_dr_vals=clobber, perform='c',
warn_input_not_reused=config.DebugMode.warn_input_not_reused)
_check_viewmap(node, storage_map)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论