made _test_tensor.make_tester to facilitate thorough testing of valid/invalid inputs/outputs

上级 fa35dffc
...@@ -8,6 +8,7 @@ import gradient ...@@ -8,6 +8,7 @@ import gradient
import gof, gof.graph import gof, gof.graph
from gof.python25 import any from gof.python25 import any
import gof import gof
from gof.utils import AbstractFunctionError
from elemwise import DimShuffle from elemwise import DimShuffle
...@@ -20,6 +21,164 @@ def _numpy_checker(x, y): ...@@ -20,6 +21,164 @@ def _numpy_checker(x, y):
raise Exception("Output mismatch.", {'performlinker': x.data, 'clinker': y.data}) raise Exception("Output mismatch.", {'performlinker': x.data, 'clinker': y.data})
def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {}, bad_runtime = {}, grad = None):
if grad is None:
grad = good
_op_class, _expected, _checks, _good, _bad_build, _bad_runtime, _grad = op_class, expected, checks, good, bad_build, bad_runtime, grad
class Checker(unittest.TestCase):
op_class = _op_class
expected = staticmethod(_expected)
checks = _checks
good = _good
bad_build = _bad_build
bad_runtime = _bad_runtime
grad = _grad
def test_good(self):
for testname, inputs in self.good.items():
try:
op = self.op_class(*inputs)
except:
type, value, traceback = sys.exc_info()
err_msg = "With data %s::%s: This error occurred while trying to build a %s instance with inputs %s" \
% (self.op_class.__name__, testname, self.op_class, inputs)
value.args = value.args + (err_msg, )
raise type, value, traceback
try:
f = Function(op.inputs, op.outputs,
linker_cls = lambda env: gof.DualLinker(env, checker = _numpy_checker),
unpack_single = False,
optimizer = None)
except:
type, value, traceback = sys.exc_info()
err_msg = "With data %s::%s: This error occurred while trying to make a function out of %s" \
% (self.op_class.__name__, testname, op)
value.args = value.args + (err_msg, )
raise type, value, traceback
expecteds = self.expected(*inputs)
try:
results = f(*inputs)
except:
type, value, traceback = sys.exc_info()
err_msg = "With data %s::%s: This error occurred while calling %s on the inputs %s" \
% (self.op_class.__name__, testname, op, inputs)
value.args = value.args + (err_msg, )
raise type, value, traceback
if not isinstance(expecteds, (list, tuple)):
expecteds = (expecteds, )
for i, (result, expected) in enumerate(zip(results, expecteds)):
if result.dtype != expected.dtype or numpy.any(abs(result - expected) > 1e-10):
self.fail("With data %s::%s: Output %s of %s gave the wrong value. With inputs %s, expected %s, got %s."
% (self.op_class.__name__, testname, i, op, inputs, expected, result))
for description, check in self.checks.items():
if not check(inputs, results):
self.fail("With data %s::%s: %s failed the following check: %s (inputs were %s)"
% (self.op_class.__name__, testname, op, description, inputs))
def test_bad_build(self):
for testname, inputs in self.bad_build.items():
try:
op = self.op_class(*inputs)
except:
return
self.fail("With data %s::%s: %s was successfully instantiated on the following bad inputs: %s"
% (self.op_class.__name__, testname, op, inputs))
def test_bad_runtime(self):
for testname, inputs in self.bad_runtime.items():
try:
op = self.op_class(*inputs)
except:
type, value, traceback = sys.exc_info()
err_msg = "With data %s::%s: This error occurred while trying to build a %s instance with inputs %s" \
% (self.op_class.__name__, testname, self.op_class, inputs)
value.args = value.args + (err_msg, )
raise type, value, traceback
try:
f = Function(op.inputs, op.outputs,
linker_cls = lambda env: gof.DualLinker(env, checker = _numpy_checker),
unpack_single = False,
optimizer = None)
except:
type, value, traceback = sys.exc_info()
err_msg = "With data %s::%s: This error occurred while trying to make a function out of %s" \
% (self.op_class.__name__, testname, op)
value.args = value.args + (err_msg, )
raise type, value, traceback
try:
results = f(*inputs)
except:
return
self.fail("With data %s::%s: %s was successfully called on the following bad inputs: %s"
% (self.op_class.__name__, testname, op, inputs))
def test_grad(self):
for testname, inputs in self.grad.items():
try:
verify_grad(self, self.op_class, inputs)
except:
type, value, traceback = sys.exc_info()
err_msg = "With data %s::%s: This error occurred while computing the gradient for %s" \
% (self.op_class.__name__, testname, self.op_class)
value.args = value.args + (err_msg, )
raise type, value, traceback
Checker.__name__ = name
return Checker
rand = numpy.random.rand
randint = lambda *shape: numpy.random.random_integers(-10, 10, shape)
AddTester = make_tester(name = 'AddTester',
op_class = Add,
expected = lambda x, y: x + y,
checks = {},
good = dict(same_shapes = (rand(5, 6), rand(5, 6)),
scalar = (rand(5, 6), rand(1, 1)),
row = (rand(5, 6), rand(1, 6)),
column = (rand(5, 6), rand(5, 1)),
integers = (randint(5, 6), randint(5, 6)),
dtype_mixup = (rand(5, 6), randint(5, 6))),
bad_build = dict(not_same_dimensions = (rand(5), rand(5, 5))),
bad_runtime = dict(bad_shapes = (rand(5, 6), rand(6, 5)),
bad_row = (rand(5, 6), rand(1, 5))),
grad = {})
AddInplaceTester = make_tester(name = 'AddInplaceTester',
op_class = AddInplace,
expected = lambda x, y: numpy.array(x + y, dtype = x.dtype),
checks = dict(inplace_check = lambda (x, y), (z, ): x is z),
good = dict(same_shapes = (rand(5, 6), rand(5, 6)),
dtype_mixup = (randint(5, 6), rand(5, 6))),
bad_build = dict(not_same_dimensions = (rand(5), rand(5, 5))),
bad_runtime = dict(bad_shapes = (rand(5, 6), rand(6, 5)),
bad_row = (rand(5, 6), rand(1, 5))),
grad = {})
DotTester = make_tester(name = 'DotTester',
op_class = Dot,
expected = lambda x, y: numpy.dot(x, y),
checks = {},
good = dict(correct1 = (rand(5, 7), rand(7, 5)),
correct2 = (rand(5, 7), rand(7, 9))),
bad_build = dict(),
bad_runtime = dict(bad1 = (rand(5, 7), rand(5, 7)),
bad2 = (rand(5, 7), rand(8, 3))))
#TODO: consider moving this function / functionality to gradient.py #TODO: consider moving this function / functionality to gradient.py
# rationale: it's tricky, and necessary everytime you want to verify # rationale: it's tricky, and necessary everytime you want to verify
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论