提交 4e4a21b2 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

improvements to _numpy_checker and make_tester

上级 4e111a5c
...@@ -17,8 +17,9 @@ def _numpy_checker(x, y): ...@@ -17,8 +17,9 @@ def _numpy_checker(x, y):
Checks if x.data and y.data have the same contents. Checks if x.data and y.data have the same contents.
Used in DualLinker to compare C version with Python version. Used in DualLinker to compare C version with Python version.
""" """
if (x.data != y.data).any(): x, y = x.data, y.data
raise Exception("Output mismatch.", {'performlinker': x.data, 'clinker': y.data}) if x.dtype != y.dtype or x.shape != y.shape or numpy.any(abs(x - y) > 1e-10):
raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y})
def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {}, bad_runtime = {}, grad = None): def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {}, bad_runtime = {}, grad = None):
...@@ -39,6 +40,7 @@ def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {} ...@@ -39,6 +40,7 @@ def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {}
def test_good(self): def test_good(self):
for testname, inputs in self.good.items(): for testname, inputs in self.good.items():
inputs = [copy(input) for input in inputs]
try: try:
op = self.op_class(*inputs) op = self.op_class(*inputs)
except: except:
...@@ -74,7 +76,7 @@ def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {} ...@@ -74,7 +76,7 @@ def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {}
if not isinstance(expecteds, (list, tuple)): if not isinstance(expecteds, (list, tuple)):
expecteds = (expecteds, ) expecteds = (expecteds, )
for i, (result, expected) in enumerate(zip(results, expecteds)): for i, (result, expected) in enumerate(zip(results, expecteds)):
if result.dtype != expected.dtype or numpy.any(abs(result - expected) > 1e-10): if result.dtype != expected.dtype or result.shape != expected.shape or numpy.any(abs(result - expected) > 1e-10):
self.fail("With data %s::%s: Output %s of %s gave the wrong value. With inputs %s, expected %s, got %s." self.fail("With data %s::%s: Output %s of %s gave the wrong value. With inputs %s, expected %s, got %s."
% (self.op_class.__name__, testname, i, op, inputs, expected, result)) % (self.op_class.__name__, testname, i, op, inputs, expected, result))
...@@ -85,6 +87,7 @@ def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {} ...@@ -85,6 +87,7 @@ def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {}
def test_bad_build(self): def test_bad_build(self):
for testname, inputs in self.bad_build.items(): for testname, inputs in self.bad_build.items():
inputs = [copy(input) for input in inputs]
try: try:
op = self.op_class(*inputs) op = self.op_class(*inputs)
except: except:
...@@ -94,6 +97,7 @@ def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {} ...@@ -94,6 +97,7 @@ def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {}
def test_bad_runtime(self): def test_bad_runtime(self):
for testname, inputs in self.bad_runtime.items(): for testname, inputs in self.bad_runtime.items():
inputs = [copy(input) for input in inputs]
try: try:
op = self.op_class(*inputs) op = self.op_class(*inputs)
except: except:
...@@ -125,6 +129,7 @@ def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {} ...@@ -125,6 +129,7 @@ def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {}
def test_grad(self): def test_grad(self):
for testname, inputs in self.grad.items(): for testname, inputs in self.grad.items():
inputs = [copy(input) for input in inputs]
try: try:
verify_grad(self, self.op_class, inputs) verify_grad(self, self.op_class, inputs)
except: except:
...@@ -138,9 +143,35 @@ def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {} ...@@ -138,9 +143,35 @@ def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {}
return Checker return Checker
rand = numpy.random.rand rand = lambda *shape: 2 * numpy.random.rand(*shape) - 1
randint = lambda *shape: numpy.random.random_integers(-10, 10, shape) randint = lambda *shape: numpy.random.random_integers(-10, 10, shape)
def randint_notzero(*shape):
r = numpy.random.random_integers(-10, 9, shape)
return r + (r == 0) * 10
randplus = numpy.random.rand
_good_broadcast = dict(same_shapes = (rand(2, 3), rand(2, 3)),
scalar = (rand(2, 3), rand(1, 1)),
row = (rand(2, 3), rand(1, 3)),
column = (rand(2, 3), rand(2, 1)),
integers = (randint(2, 3), randint(2, 3)),
dtype_mixup = (rand(2, 3), randint(2, 3)))
_bad_build_broadcast = dict(not_same_dimensions = (rand(2), rand(2, 2)))
_bad_runtime_broadcast = dict(not_same_dimensions = (rand(2), rand(2, 2)))
_grad_broadcast = _good_broadcast
AddTester = make_tester(name = 'AddTester', AddTester = make_tester(name = 'AddTester',
op_class = Add, op_class = Add,
......
...@@ -32,16 +32,24 @@ inplace_optimizer = InplaceOptimizer() ...@@ -32,16 +32,24 @@ inplace_optimizer = InplaceOptimizer()
# self. # self.
# def find_elemwise_cliques(env): # def find_elemwise_cliques(env, cross_broadcast = False):
# def synchronize(env1, env2, equiv, transform):
# class Synchronize(Listener, Constraint):
# def on_import(self, op1):
# if op1 not in equiv:
# equiv[op1] = transform(op1)
# def on_prune(self, op1):
# if op1 in equiv:
# del equiv[op1]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论