made tests use DualLinker

上级 e22a5f23
......@@ -9,6 +9,16 @@ import gof, gof.graph
from gof.python25 import any
def _numpy_checker(x, y):
"""
Checks if x.data and y.data have the same contents.
Used in DualLinker to compare C version with Python version.
"""
if (x.data != y.data).any():
raise Exception("Output mismatch.", {'performlinker': x.data, 'clinker': y.data})
#TODO: consider moving this function / functionality to gradient.py
# rationale: it's tricky, and necessary everytime you want to verify
# gradient numerically
......@@ -82,6 +92,11 @@ def check_eq2_c(self, inputs, output, args_in, arg_out):
val = fn(*args_in)
self.failUnless( numpy.all(val == arg_out), (val, arg_out))
def check_eq2_both(self, inputs, output, args_in, arg_out):
fn = Function(inputs, [output], linker_cls = lambda env: gof.DualLinker(env, _numpy_checker))
val = fn(*args_in)
self.failUnless( numpy.all(val == arg_out), (val, arg_out))
class T_argmax(unittest.TestCase):
def setUp(self):
numpy.random.seed(123784)
......@@ -408,24 +423,24 @@ class T_mul(unittest.TestCase):
def test_elemwise(self):
a = astensor(0.0)
b = astensor(0.0)
check_eq2(self, [a,b], mul_elemwise(a,b), [3.0, 4.0], 12.0)
check_eq2(self, [a,b], mul_elemwise(b,a), [-1.0,2.0], -2.0)
check_eq2_both(self, [a,b], mul_elemwise(a,b), [3.0, 4.0], 12.0)
check_eq2_both(self, [a,b], mul_elemwise(b,a), [-1.0,2.0], -2.0)
self.failUnless(isinstance(mul(a,b).owner, Scale))
a = astensor(numpy.ones(2))
b = astensor(numpy.ones(2))
aa = numpy.asarray([-0.5, 4.0])
bb = numpy.asarray([-0.5, 2.0])
check_eq2(self, [a,b], mul_elemwise(a,b), [aa,bb], numpy.asarray([0.25, 8.0]))
check_eq2(self, [a,b], mul_elemwise(a,b), [bb,aa], numpy.asarray([0.25, 8.0]))
check_eq2_both(self, [a,b], mul_elemwise(a,b), [aa,bb], numpy.asarray([0.25, 8.0]))
check_eq2_both(self, [a,b], mul_elemwise(a,b), [bb,aa], numpy.asarray([0.25, 8.0]))
self.failUnless(isinstance(mul(a,b).owner, MulElemwise))
def test_scalar(self):
r = numpy.random.rand(2,3)
a = astensor(r)
b = astensor(2.0)
check_eq2(self, [a,b], scale(a,b), [r, 2.0], r*2.0)
check_eq2(self, [a,b], scale(a,b), [r, 4.0], r*4.0)
check_eq2_both(self, [a,b], scale(a,b), [r, 2.0], r*2.0)
check_eq2_both(self, [a,b], scale(a,b), [r, 4.0], r*4.0)
self.failUnless(b.data == 2.0)
def test_operator(self):
......@@ -442,11 +457,17 @@ class T_mul(unittest.TestCase):
b = astensor(numpy.ones(4))
try:
check_eq2(self, [a,b], MulElemwise(a,b).out,
[numpy.ones(3), numpy.ones(4)], 1.0)
[numpy.ones(3), numpy.ones(4)], 1.0)
self.fail()
except ValueError, e:
self.failUnless(e[0] is tensor._assert_same_shapes.E_shape)
return
self.fail()
try:
check_eq2_c(self, [a,b], MulElemwise(a,b).out,
[numpy.ones(3), numpy.ones(4)], 1.0)
self.fail()
except ValueError, e:
pass
class T_div(unittest.TestCase):
def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论