made _test_tensor.make_tester to facilitate thorough testing of valid/invalid inputs/outputs

上级 fa35dffc
......@@ -8,6 +8,7 @@ import gradient
import gof, gof.graph
from gof.python25 import any
import gof
from gof.utils import AbstractFunctionError
from elemwise import DimShuffle
......@@ -20,6 +21,164 @@ def _numpy_checker(x, y):
raise Exception("Output mismatch.", {'performlinker': x.data, 'clinker': y.data})
def make_tester(name, op_class, expected, checks = {}, good = {}, bad_build = {}, bad_runtime = {}, grad = None):
if grad is None:
grad = good
_op_class, _expected, _checks, _good, _bad_build, _bad_runtime, _grad = op_class, expected, checks, good, bad_build, bad_runtime, grad
class Checker(unittest.TestCase):
op_class = _op_class
expected = staticmethod(_expected)
checks = _checks
good = _good
bad_build = _bad_build
bad_runtime = _bad_runtime
grad = _grad
def test_good(self):
for testname, inputs in self.good.items():
try:
op = self.op_class(*inputs)
except:
type, value, traceback = sys.exc_info()
err_msg = "With data %s::%s: This error occurred while trying to build a %s instance with inputs %s" \
% (self.op_class.__name__, testname, self.op_class, inputs)
value.args = value.args + (err_msg, )
raise type, value, traceback
try:
f = Function(op.inputs, op.outputs,
linker_cls = lambda env: gof.DualLinker(env, checker = _numpy_checker),
unpack_single = False,
optimizer = None)
except:
type, value, traceback = sys.exc_info()
err_msg = "With data %s::%s: This error occurred while trying to make a function out of %s" \
% (self.op_class.__name__, testname, op)
value.args = value.args + (err_msg, )
raise type, value, traceback
expecteds = self.expected(*inputs)
try:
results = f(*inputs)
except:
type, value, traceback = sys.exc_info()
err_msg = "With data %s::%s: This error occurred while calling %s on the inputs %s" \
% (self.op_class.__name__, testname, op, inputs)
value.args = value.args + (err_msg, )
raise type, value, traceback
if not isinstance(expecteds, (list, tuple)):
expecteds = (expecteds, )
for i, (result, expected) in enumerate(zip(results, expecteds)):
if result.dtype != expected.dtype or numpy.any(abs(result - expected) > 1e-10):
self.fail("With data %s::%s: Output %s of %s gave the wrong value. With inputs %s, expected %s, got %s."
% (self.op_class.__name__, testname, i, op, inputs, expected, result))
for description, check in self.checks.items():
if not check(inputs, results):
self.fail("With data %s::%s: %s failed the following check: %s (inputs were %s)"
% (self.op_class.__name__, testname, op, description, inputs))
def test_bad_build(self):
for testname, inputs in self.bad_build.items():
try:
op = self.op_class(*inputs)
except:
return
self.fail("With data %s::%s: %s was successfully instantiated on the following bad inputs: %s"
% (self.op_class.__name__, testname, op, inputs))
def test_bad_runtime(self):
for testname, inputs in self.bad_runtime.items():
try:
op = self.op_class(*inputs)
except:
type, value, traceback = sys.exc_info()
err_msg = "With data %s::%s: This error occurred while trying to build a %s instance with inputs %s" \
% (self.op_class.__name__, testname, self.op_class, inputs)
value.args = value.args + (err_msg, )
raise type, value, traceback
try:
f = Function(op.inputs, op.outputs,
linker_cls = lambda env: gof.DualLinker(env, checker = _numpy_checker),
unpack_single = False,
optimizer = None)
except:
type, value, traceback = sys.exc_info()
err_msg = "With data %s::%s: This error occurred while trying to make a function out of %s" \
% (self.op_class.__name__, testname, op)
value.args = value.args + (err_msg, )
raise type, value, traceback
try:
results = f(*inputs)
except:
return
self.fail("With data %s::%s: %s was successfully called on the following bad inputs: %s"
% (self.op_class.__name__, testname, op, inputs))
def test_grad(self):
for testname, inputs in self.grad.items():
try:
verify_grad(self, self.op_class, inputs)
except:
type, value, traceback = sys.exc_info()
err_msg = "With data %s::%s: This error occurred while computing the gradient for %s" \
% (self.op_class.__name__, testname, self.op_class)
value.args = value.args + (err_msg, )
raise type, value, traceback
Checker.__name__ = name
return Checker
rand = numpy.random.rand
randint = lambda *shape: numpy.random.random_integers(-10, 10, shape)
AddTester = make_tester(name = 'AddTester',
op_class = Add,
expected = lambda x, y: x + y,
checks = {},
good = dict(same_shapes = (rand(5, 6), rand(5, 6)),
scalar = (rand(5, 6), rand(1, 1)),
row = (rand(5, 6), rand(1, 6)),
column = (rand(5, 6), rand(5, 1)),
integers = (randint(5, 6), randint(5, 6)),
dtype_mixup = (rand(5, 6), randint(5, 6))),
bad_build = dict(not_same_dimensions = (rand(5), rand(5, 5))),
bad_runtime = dict(bad_shapes = (rand(5, 6), rand(6, 5)),
bad_row = (rand(5, 6), rand(1, 5))),
grad = {})
AddInplaceTester = make_tester(name = 'AddInplaceTester',
op_class = AddInplace,
expected = lambda x, y: numpy.array(x + y, dtype = x.dtype),
checks = dict(inplace_check = lambda (x, y), (z, ): x is z),
good = dict(same_shapes = (rand(5, 6), rand(5, 6)),
dtype_mixup = (randint(5, 6), rand(5, 6))),
bad_build = dict(not_same_dimensions = (rand(5), rand(5, 5))),
bad_runtime = dict(bad_shapes = (rand(5, 6), rand(6, 5)),
bad_row = (rand(5, 6), rand(1, 5))),
grad = {})
DotTester = make_tester(name = 'DotTester',
op_class = Dot,
expected = lambda x, y: numpy.dot(x, y),
checks = {},
good = dict(correct1 = (rand(5, 7), rand(7, 5)),
correct2 = (rand(5, 7), rand(7, 9))),
bad_build = dict(),
bad_runtime = dict(bad1 = (rand(5, 7), rand(5, 7)),
bad2 = (rand(5, 7), rand(8, 3))))
#TODO: consider moving this function / functionality to gradient.py
# rationale: it's tricky, and necessary everytime you want to verify
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论