提交 03971aa6 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

modular checker

上级 1ad97e45
...@@ -548,16 +548,22 @@ copy_reg.pickle(slice, _pickle_slice) ...@@ -548,16 +548,22 @@ copy_reg.pickle(slice, _pickle_slice)
__checkers = []
def check_equal(x, y):
for checker in __checkers:
try:
return checker(x, y)
except:
continue
return x == y
#raise Exception('No checker for equality between %s and %s' % (x, y))
def register_checker(checker):
__checkers.insert(0, checker)
def check_equal_numpy(x, y):
"""
Returns True iff x and y are equal (checks the dtype and
shape if x and y are numpy.ndarray instances).
"""
if isinstance(x, numpy.ndarray) and isinstance(y, numpy.ndarray):
return x.dtype == y.dtype and x.shape == y.shape and numpy.any(abs(x - y) < 1e-10)
else:
return x == y
def function(inputs, outputs, mode='FAST_RUN', accept_inplace = False): def function(inputs, outputs, mode='FAST_RUN', accept_inplace = False):
""" """
...@@ -663,7 +669,7 @@ def function(inputs, outputs, mode='FAST_RUN', accept_inplace = False): ...@@ -663,7 +669,7 @@ def function(inputs, outputs, mode='FAST_RUN', accept_inplace = False):
for default in defaults] for default in defaults]
makers = [FunctionMaker(inputs, outputs, m, accept_inplace = accept_inplace) for m in mode[1:]] makers = [FunctionMaker(inputs, outputs, m, accept_inplace = accept_inplace) for m in mode[1:]]
fns = [maker.create(dup_defaults(), trustme = True) for maker in makers] fns = [maker.create(dup_defaults(), trustme = True) for maker in makers]
builder = partial(SanityCheckFunction, fns, check_equal_numpy) builder = partial(SanityCheckFunction, fns, check_equal)
maker1 = FunctionMaker(inputs, outputs, mode[0], accept_inplace = accept_inplace, function_builder = builder) maker1 = FunctionMaker(inputs, outputs, mode[0], accept_inplace = accept_inplace, function_builder = builder)
fn = maker1.create(defaults) fn = maker1.create(defaults)
else: else:
......
...@@ -27,6 +27,21 @@ from .. import compile ...@@ -27,6 +27,21 @@ from .. import compile
from elemwise import Elemwise, DimShuffle, CAReduce, Sum from elemwise import Elemwise, DimShuffle, CAReduce, Sum
def check_equal_numpy(x, y):
"""
Returns True iff x and y are equal (checks the dtype and
shape if x and y are numpy.ndarray instances).
"""
if isinstance(x, numpy.ndarray) and isinstance(y, numpy.ndarray):
return x.dtype == y.dtype and x.shape == y.shape and numpy.any(abs(x - y) < 1e-10)
else:
return x == y
compile.register_checker(check_equal_numpy)
__oplist_constructor_list = [] __oplist_constructor_list = []
"""List of functions to be listed as op constructors in the oplist (`gen_oplist`, doc/oplist.txt).""" """List of functions to be listed as op constructors in the oplist (`gen_oplist`, doc/oplist.txt)."""
def constructor(f): def constructor(f):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论