提交 e7d32ce6 authored 作者: bergstra@ip05.m's avatar bergstra@ip05.m

debugmode now checks strides all the time, with the option of raising an…

debugmode now checks strides all the time, with the option of raising an exception instead of logging to stderr
上级 563b7086
......@@ -371,14 +371,18 @@ def _is_function_output(node):
def _is_used_in_graph(node):
return not(_is_function_output(node) or node.clients==[])
def _check_strides_match(a, b):
def _check_strides_match(a, b, raise_on_err, op):
try:
strides_eq = a.strides == b.strides
except:
return # no strides
if not strides_eq:
raise TypeError('BAD STRIDES', a.strides, b.strides)
e = TypeError('Stride mismatch', (a.shape, b.shape, a.strides, b.strides, str(op)))
if raise_on_err:
raise e
else:
print >> sys.stderr, 'WARNING:', e
def _lessbroken_deepcopy(a):
"""
......@@ -888,8 +892,8 @@ class _Linker(gof.link.LocalLinker):
raise InvalidValueError(r, storage_map[r][0])
# check for stride correctness if we're doing that
if self.maker.mode.require_matching_strides:
_check_strides_match(r_vals[r], storage_map[r][0])
_check_strides_match(r_vals[r], storage_map[r][0],
self.maker.mode.require_matching_strides, node.op)
_check_inputs(node, storage_map, r_vals, dr_vals, active_order_set,
clobber_dr_vals=False)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论