提交 e7d32ce6 authored 作者: bergstra@ip05.m's avatar bergstra@ip05.m

debugmode now checks strides all the time, with the option of raising an…

debugmode now checks strides all the time, with the option of raising an exception instead of logging to stderr
上级 563b7086
...@@ -371,14 +371,18 @@ def _is_function_output(node): ...@@ -371,14 +371,18 @@ def _is_function_output(node):
def _is_used_in_graph(node): def _is_used_in_graph(node):
return not(_is_function_output(node) or node.clients==[]) return not(_is_function_output(node) or node.clients==[])
def _check_strides_match(a, b): def _check_strides_match(a, b, raise_on_err, op):
try: try:
strides_eq = a.strides == b.strides strides_eq = a.strides == b.strides
except: except:
return # no strides return # no strides
if not strides_eq: if not strides_eq:
raise TypeError('BAD STRIDES', a.strides, b.strides) e = TypeError('Stride mismatch', (a.shape, b.shape, a.strides, b.strides, str(op)))
if raise_on_err:
raise e
else:
print >> sys.stderr, 'WARNING:', e
def _lessbroken_deepcopy(a): def _lessbroken_deepcopy(a):
""" """
...@@ -888,8 +892,8 @@ class _Linker(gof.link.LocalLinker): ...@@ -888,8 +892,8 @@ class _Linker(gof.link.LocalLinker):
raise InvalidValueError(r, storage_map[r][0]) raise InvalidValueError(r, storage_map[r][0])
# check for stride correctness if we're doing that # check for stride correctness if we're doing that
if self.maker.mode.require_matching_strides: _check_strides_match(r_vals[r], storage_map[r][0],
_check_strides_match(r_vals[r], storage_map[r][0]) self.maker.mode.require_matching_strides, node.op)
_check_inputs(node, storage_map, r_vals, dr_vals, active_order_set, _check_inputs(node, storage_map, r_vals, dr_vals, active_order_set,
clobber_dr_vals=False) clobber_dr_vals=False)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论