提交 44adeac0 authored 作者: James Bergstra's avatar James Bergstra

added option to check strides of arrays returned by C functions. defaults to…

added option to check strides of arrays returned by C functions. defaults to False because it is overly strict
上级 e8e4d2ef
...@@ -54,6 +54,19 @@ class BadCLinkerOutput(DebugModeError): ...@@ -54,6 +54,19 @@ class BadCLinkerOutput(DebugModeError):
"""Return the Op class whose c_code and perform implementations didn't match""" """Return the Op class whose c_code and perform implementations didn't match"""
return type(self.r.owner.op) return type(self.r.owner.op)
def __str__(self):
return self.str_diagnostic()
def str_diagnostic(self):
"""Return a pretty multiline string representating the cause of the exception"""
sio = StringIO()
print >> sio, "BadCLinkerOutput"
print >> sio, " variable:", self.r
print >> sio, " val_py :", self.val_py
print >> sio, " val_c :", self.val_c
print >> sio, " op :", self.offending_op()
return sio.getvalue()
class BadOptimization(DebugModeError): class BadOptimization(DebugModeError):
"""Exception: some variable and its substitute take different runtime values. """Exception: some variable and its substitute take different runtime values.
""" """
...@@ -358,6 +371,14 @@ def _is_function_output(node): ...@@ -358,6 +371,14 @@ def _is_function_output(node):
def _is_used_in_graph(node): def _is_used_in_graph(node):
return not(_is_function_output(node) or node.clients==[]) return not(_is_function_output(node) or node.clients==[])
def _check_strides_match(a, b):
try:
strides_eq = a.strides == b.strides
except:
return # no strides
if not strides_eq:
raise TypeError('BAD STRIDES', a.strides, b.strides)
def _lessbroken_deepcopy(a): def _lessbroken_deepcopy(a):
""" """
...@@ -864,6 +885,10 @@ class _Linker(gof.link.LocalLinker): ...@@ -864,6 +885,10 @@ class _Linker(gof.link.LocalLinker):
if not r.type.is_valid_value(storage_map[r][0]): if not r.type.is_valid_value(storage_map[r][0]):
raise InvalidValueError(r, storage_map[r][0]) raise InvalidValueError(r, storage_map[r][0])
# check for stride correctness if we're doing that
if self.maker.mode.require_matching_strides:
_check_strides_match(r_vals[r], storage_map[r][0])
_check_inputs(node, storage_map, r_vals, dr_vals, active_order_set, _check_inputs(node, storage_map, r_vals, dr_vals, active_order_set,
clobber_dr_vals=False) clobber_dr_vals=False)
...@@ -1206,6 +1231,11 @@ class DebugMode(Mode): ...@@ -1206,6 +1231,11 @@ class DebugMode(Mode):
Should we check for (and complain about) NaN/Inf ndarray elements? Should we check for (and complain about) NaN/Inf ndarray elements?
""" """
require_matching_strides = False
"""
Should we check for (and complain about) Ops whose python and C outputs are ndarrays with
different strides? (This can catch bugs, but is generally overly strict.)
"""
# This function will be used to create a FunctionMaker in # This function will be used to create a FunctionMaker in
# function_module.function # function_module.function
...@@ -1219,7 +1249,8 @@ class DebugMode(Mode): ...@@ -1219,7 +1249,8 @@ class DebugMode(Mode):
stability_patience=None, stability_patience=None,
check_c_code=None, check_c_code=None,
check_py_code=None, check_py_code=None,
check_isfinite=None): check_isfinite=None,
require_matching_strides=None):
"""Initialize member variables. """Initialize member variables.
If any of these arguments (except optimizer) is not None, it overrides the class default. If any of these arguments (except optimizer) is not None, it overrides the class default.
...@@ -1240,6 +1271,9 @@ class DebugMode(Mode): ...@@ -1240,6 +1271,9 @@ class DebugMode(Mode):
if check_isfinite is not None: if check_isfinite is not None:
self.check_isfinite = check_isfinite self.check_isfinite = check_isfinite
if require_matching_strides is not None:
self.require_matching_strides = require_matching_strides
if not (self.check_c_code or self.check_py_code): if not (self.check_c_code or self.check_py_code):
raise ValueError('DebugMode has to check at least one of c and py code') raise ValueError('DebugMode has to check at least one of c and py code')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论