made tests use DualLinker

上级 e22a5f23
...@@ -9,6 +9,16 @@ import gof, gof.graph ...@@ -9,6 +9,16 @@ import gof, gof.graph
from gof.python25 import any from gof.python25 import any
def _numpy_checker(x, y):
"""
Checks if x.data and y.data have the same contents.
Used in DualLinker to compare C version with Python version.
"""
if (x.data != y.data).any():
raise Exception("Output mismatch.", {'performlinker': x.data, 'clinker': y.data})
#TODO: consider moving this function / functionality to gradient.py #TODO: consider moving this function / functionality to gradient.py
# rationale: it's tricky, and necessary everytime you want to verify # rationale: it's tricky, and necessary everytime you want to verify
# gradient numerically # gradient numerically
...@@ -82,6 +92,11 @@ def check_eq2_c(self, inputs, output, args_in, arg_out): ...@@ -82,6 +92,11 @@ def check_eq2_c(self, inputs, output, args_in, arg_out):
val = fn(*args_in) val = fn(*args_in)
self.failUnless( numpy.all(val == arg_out), (val, arg_out)) self.failUnless( numpy.all(val == arg_out), (val, arg_out))
def check_eq2_both(self, inputs, output, args_in, arg_out):
fn = Function(inputs, [output], linker_cls = lambda env: gof.DualLinker(env, _numpy_checker))
val = fn(*args_in)
self.failUnless( numpy.all(val == arg_out), (val, arg_out))
class T_argmax(unittest.TestCase): class T_argmax(unittest.TestCase):
def setUp(self): def setUp(self):
numpy.random.seed(123784) numpy.random.seed(123784)
...@@ -408,24 +423,24 @@ class T_mul(unittest.TestCase): ...@@ -408,24 +423,24 @@ class T_mul(unittest.TestCase):
def test_elemwise(self): def test_elemwise(self):
a = astensor(0.0) a = astensor(0.0)
b = astensor(0.0) b = astensor(0.0)
check_eq2(self, [a,b], mul_elemwise(a,b), [3.0, 4.0], 12.0) check_eq2_both(self, [a,b], mul_elemwise(a,b), [3.0, 4.0], 12.0)
check_eq2(self, [a,b], mul_elemwise(b,a), [-1.0,2.0], -2.0) check_eq2_both(self, [a,b], mul_elemwise(b,a), [-1.0,2.0], -2.0)
self.failUnless(isinstance(mul(a,b).owner, Scale)) self.failUnless(isinstance(mul(a,b).owner, Scale))
a = astensor(numpy.ones(2)) a = astensor(numpy.ones(2))
b = astensor(numpy.ones(2)) b = astensor(numpy.ones(2))
aa = numpy.asarray([-0.5, 4.0]) aa = numpy.asarray([-0.5, 4.0])
bb = numpy.asarray([-0.5, 2.0]) bb = numpy.asarray([-0.5, 2.0])
check_eq2(self, [a,b], mul_elemwise(a,b), [aa,bb], numpy.asarray([0.25, 8.0])) check_eq2_both(self, [a,b], mul_elemwise(a,b), [aa,bb], numpy.asarray([0.25, 8.0]))
check_eq2(self, [a,b], mul_elemwise(a,b), [bb,aa], numpy.asarray([0.25, 8.0])) check_eq2_both(self, [a,b], mul_elemwise(a,b), [bb,aa], numpy.asarray([0.25, 8.0]))
self.failUnless(isinstance(mul(a,b).owner, MulElemwise)) self.failUnless(isinstance(mul(a,b).owner, MulElemwise))
def test_scalar(self): def test_scalar(self):
r = numpy.random.rand(2,3) r = numpy.random.rand(2,3)
a = astensor(r) a = astensor(r)
b = astensor(2.0) b = astensor(2.0)
check_eq2(self, [a,b], scale(a,b), [r, 2.0], r*2.0) check_eq2_both(self, [a,b], scale(a,b), [r, 2.0], r*2.0)
check_eq2(self, [a,b], scale(a,b), [r, 4.0], r*4.0) check_eq2_both(self, [a,b], scale(a,b), [r, 4.0], r*4.0)
self.failUnless(b.data == 2.0) self.failUnless(b.data == 2.0)
def test_operator(self): def test_operator(self):
...@@ -443,10 +458,16 @@ class T_mul(unittest.TestCase): ...@@ -443,10 +458,16 @@ class T_mul(unittest.TestCase):
try: try:
check_eq2(self, [a,b], MulElemwise(a,b).out, check_eq2(self, [a,b], MulElemwise(a,b).out,
[numpy.ones(3), numpy.ones(4)], 1.0) [numpy.ones(3), numpy.ones(4)], 1.0)
self.fail()
except ValueError, e: except ValueError, e:
self.failUnless(e[0] is tensor._assert_same_shapes.E_shape) self.failUnless(e[0] is tensor._assert_same_shapes.E_shape)
return
try:
check_eq2_c(self, [a,b], MulElemwise(a,b).out,
[numpy.ones(3), numpy.ones(4)], 1.0)
self.fail() self.fail()
except ValueError, e:
pass
class T_div(unittest.TestCase): class T_div(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论