提交 9e669ec0 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

incorporating py.test

上级 7cd0a3b9
import unittest, os, sys, traceback, commands import unittest, os, sys, traceback, commands
sys.path[0:0] = [os.path.realpath("..")] theano_path = os.path.realpath("%s/.." % sys.path[0])
sys.path[0:0] = [theano_path]
def test_module(module_path, debugmode = False): def test_module(module_path, debugmode = False):
files = commands.getoutput("find %s -name test_*.py" % module_path) files = commands.getoutput("find %s -name _test_*.py" % module_path)
suite = None suite = None
tocut = len("/".join(module_path.split("/")[:-1])) + 1 tocut = len("/".join(module_path.split("/")[:-1])) + 1
for file in files.split("\n"): for file in files.split("\n"):
...@@ -25,11 +26,59 @@ def test_module(module_path, debugmode = False): ...@@ -25,11 +26,59 @@ def test_module(module_path, debugmode = False):
suite = tests suite = tests
else: else:
suite.addTests(tests) suite.addTests(tests)
if suite is None:
return
if debugmode: if debugmode:
suite.debug() suite.debug()
else: else:
unittest.TextTestRunner(verbosity=1).run(suite) unittest.TextTestRunner(verbosity=1).run(suite)
def py_test(module_path):
py.test.cmdline.main([module_path])
def nopy_test(module_path):
print >>sys.stderr, "py.test is not installed!"
print >>sys.stderr, " easy_install py"
print >>sys.stderr, "or if you are installing locally"
print >>sys.stderr, " easy_install --prefix=/some/local/dir py"
return None
files = commands.getoutput("find %s -name test_*.py" % module_path)
suite = None
tocut = len("/".join(module_path.split("/")[:-1])) + 1
for file in files.split("\n"):
file = file[tocut:]
try:
module = __import__(file[:-3])
except Exception, e:
print >>sys.stderr, "===================================================="
print >>sys.stderr, "Failed to load %s" % file
print >>sys.stderr, "===================================================="
traceback.print_exc()
print >>sys.stderr, "===================================================="
continue
for method in dir(module):
if method.startswith("test_"):
method = getattr(module, method)
try:
r = method()
except Exception, e:
print >>sys.stderr, "===================================================="
print >>sys.stderr, "Exception in %s.%s" % (file, method.__name__)
print >>sys.stderr, "===================================================="
traceback.print_exc()
print >>sys.stderr, "===================================================="
if hasattr(r, 'next'):
for fargs in r:
try:
fargs[0](*fargs[1:])
except Exception, e:
print >>sys.stderr, "===================================================="
print >>sys.stderr, "Exception in %s.%s, %s%s" % (file, method.__name__, fargs[0], fargs[1:])
print >>sys.stderr, "===================================================="
traceback.print_exc()
print >>sys.stderr, "===================================================="
if __name__ == '__main__': if __name__ == '__main__':
...@@ -48,6 +97,12 @@ if __name__ == '__main__': ...@@ -48,6 +97,12 @@ if __name__ == '__main__':
elif len(sys.argv)>2: elif len(sys.argv)>2:
printUsage() printUsage()
test_module(os.path.realpath("../theano")) mname = os.path.join(theano_path, "theano")
test_module(mname)
try:
import py.test
py_test(mname)
except ImportError:
nopy_test(mname)
...@@ -135,77 +135,76 @@ def Env(inputs, outputs): ...@@ -135,77 +135,76 @@ def Env(inputs, outputs):
return e return e
class _test_CLinker(unittest.TestCase): ################
# Test CLinker #
################
def test_straightforward(self): def test_clinker_straightforward():
x, y, z = inputs() x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker().accept(Env([x, y, z], [e])) lnk = CLinker().accept(Env([x, y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(fn(2.0, 2.0, 2.0) == 2.0) assert fn(2.0, 2.0, 2.0) == 2.0
# def test_orphan(self): def test_clinker_literal_inlining():
# x, y, z = inputs()
# z = Constant(tdouble, 4.12345678)
# e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
# lnk = CLinker(Env([x, y], [e]))
# fn = lnk.make_function()
# self.failUnless(abs(fn(2.0, 2.0) + 0.12345678) < 1e-9)
# print lnk.code_gen()
# self.failUnless("4.12345678" not in lnk.code_gen()) # we do not expect the number to be inlined
def test_literal_inlining(self):
x, y, z = inputs() x, y, z = inputs()
z = Constant(tdouble, 4.12345678) z = Constant(tdouble, 4.12345678)
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker().accept(Env([x, y], [e])) lnk = CLinker().accept(Env([x, y], [e]))
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(abs(fn(2.0, 2.0) + 0.12345678) < 1e-9) assert abs(fn(2.0, 2.0) + 0.12345678) < 1e-9
self.failUnless("4.12345678" in lnk.code_gen()) # we expect the number to be inlined code = lnk.code_gen()
print "=== Code generated ==="
print code
assert "4.12345678" in code # we expect the number to be inlined
def test_single_node(self): def test_clinker_single_node():
x, y, z = inputs() x, y, z = inputs()
node = add.make_node(x, y) node = add.make_node(x, y)
lnk = CLinker().accept(Env(node.inputs, node.outputs)) lnk = CLinker().accept(Env(node.inputs, node.outputs))
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(fn(2.0, 7.0) == 9) assert fn(2.0, 7.0) == 9
def test_dups(self): def test_clinker_dups():
# Testing that duplicate inputs are allowed. # Testing that duplicate inputs are allowed.
x, y, z = inputs() x, y, z = inputs()
e = add(x, x) e = add(x, x)
lnk = CLinker().accept(Env([x, x], [e])) lnk = CLinker().accept(Env([x, x], [e]))
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(fn(2.0, 2.0) == 4) assert fn(2.0, 2.0) == 4
# note: for now the behavior of fn(2.0, 7.0) is undefined # note: for now the behavior of fn(2.0, 7.0) is undefined
def test_dups_inner(self): def test_clinker_dups_inner():
# Testing that duplicates are allowed inside the graph # Testing that duplicates are allowed inside the graph
x, y, z = inputs() x, y, z = inputs()
e = add(mul(y, y), add(x, z)) e = add(mul(y, y), add(x, z))
lnk = CLinker().accept(Env([x, y, z], [e])) lnk = CLinker().accept(Env([x, y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(fn(1.0, 2.0, 3.0) == 8.0) assert fn(1.0, 2.0, 3.0) == 8.0
class _test_OpWiseCLinker(unittest.TestCase): ######################
# Test OpWiseCLinker #
######################
def test_straightforward(self): def test_opwiseclinker_straightforward():
x, y, z = inputs() x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = OpWiseCLinker().accept(Env([x, y, z], [e])) lnk = OpWiseCLinker().accept(Env([x, y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(fn(2.0, 2.0, 2.0) == 2.0) assert fn(2.0, 2.0, 2.0) == 2.0
def test_constant(self): def test_opwiseclinker_constant():
x, y, z = inputs() x, y, z = inputs()
x = Constant(tdouble, 7.2, name = 'x') x = Constant(tdouble, 7.2, name = 'x')
e = add(mul(x, y), mul(y, z)) e = add(mul(x, y), mul(y, z))
lnk = OpWiseCLinker().accept(Env([y, z], [e])) lnk = OpWiseCLinker().accept(Env([y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
res = fn(1.5, 3.0) res = fn(1.5, 3.0)
self.failUnless(res == 15.3, res) assert res == 15.3
class MyExc(Exception): class MyExc(Exception):
...@@ -215,49 +214,34 @@ def _my_checker(x, y): ...@@ -215,49 +214,34 @@ def _my_checker(x, y):
raise MyExc("Output mismatch.", {'performlinker': x[0], 'clinker': y[0]}) raise MyExc("Output mismatch.", {'performlinker': x[0], 'clinker': y[0]})
class _test_DualLinker(unittest.TestCase): ###################
# Test DualLinker #
###################
def test_straightforward(self): def test_duallinker_straightforward():
x, y, z = inputs() x, y, z = inputs()
e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python
lnk = DualLinker(checker = _my_checker).accept(Env([x, y, z], [e])) lnk = DualLinker(checker = _my_checker).accept(Env([x, y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
res = fn(7.2, 1.5, 3.0) res = fn(7.2, 1.5, 3.0)
self.failUnless(res == 15.3, res) assert res == 15.3
def test_mismatch(self): def test_duallinker_mismatch():
x, y, z = inputs() x, y, z = inputs()
e = sub(mul(x, y), mul(y, z)) # sub is correct in C but erroneous in Python e = sub(mul(x, y), mul(y, z)) # sub is correct in C but erroneous in Python
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
lnk = DualLinker(checker = _my_checker).accept(g) lnk = DualLinker(checker = _my_checker).accept(g)
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(CLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0) # good assert CLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0 # good
self.failUnless(OpWiseCLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0) # good assert OpWiseCLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0 # good
self.failUnless(PerformLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -10.0) # (purposely) wrong assert PerformLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -10.0 # (purposely) wrong
try: try:
# this runs OpWiseCLinker and PerformLinker in parallel and feeds # this runs OpWiseCLinker and PerformLinker in parallel and feeds
# results of matching operations to _my_checker to verify that they # results of matching operations to _my_checker to verify that they
# are the same. # are the same.
res = fn(1.0, 2.0, 3.0) res = fn(1.0, 2.0, 3.0)
self.fail() raise Exception("An exception should have been raised here!")
except MyExc, e: except MyExc, e:
pass pass
else:
self.fail()
# def test_orphan(self):
# x, y, z = inputs()
# x = Constant(tdouble, 7.2, name = 'x')
# e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python
# lnk = DualLinker(Env([y, z], [e]), checker = _my_checker)
# fn = lnk.make_function()
# res = fn(1.5, 3.0)
# self.failUnless(res == 15.3, res)
if __name__ == '__main__':
unittest.main()
from collections import deque from collections import deque
import unittest
from theano.gof.graph import * from theano.gof.graph import *
from theano.gof.op import Op from theano.gof.op import Op
...@@ -8,13 +7,6 @@ from theano.gof.type import Type ...@@ -8,13 +7,6 @@ from theano.gof.type import Type
from theano.gof.graph import Result from theano.gof.graph import Result
if 1:
testcase = unittest.TestCase
else:
testcase = object
realtestcase = unittest.TestCase
def as_result(x): def as_result(x):
assert isinstance(x, Result) assert isinstance(x, Result)
return x return x
...@@ -55,75 +47,31 @@ class MyOp(Op): ...@@ -55,75 +47,31 @@ class MyOp(Op):
MyOp = MyOp() MyOp = MyOp()
##########
# inputs #
##########
# class MyResult(Result): class TestInputs:
# def __init__(self, thingy):
# self.thingy = thingy
# Result.__init__(self, role = None )
# self.data = [self.thingy]
# def __eq__(self, other):
# return self.same_properties(other)
# def same_properties(self, other):
# return isinstance(other, MyResult) and other.thingy == self.thingy
# def __str__(self):
# return str(self.thingy)
# def __repr__(self):
# return str(self.thingy)
# class MyOp(Op): def test_inputs(self):
# def __init__(self, *inputs):
# for input in inputs:
# if not isinstance(input, MyResult):
# raise Exception("Error 1")
# self.inputs = inputs
# self.outputs = [MyResult(sum([input.thingy for input in inputs]))]
class _test_inputs(testcase):
def test_straightforward(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
node = MyOp.make_node(r1, r2) node = MyOp.make_node(r1, r2)
assert inputs(node.outputs) == [r1, r2] assert inputs(node.outputs) == [r1, r2]
def test_deep(self): def test_inputs_deep(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
node = MyOp.make_node(r1, r2) node = MyOp.make_node(r1, r2)
node2 = MyOp.make_node(node.outputs[0], r5) node2 = MyOp.make_node(node.outputs[0], r5)
i = inputs(node2.outputs) i = inputs(node2.outputs)
self.failUnless(i == [r1, r2, r5], i) assert i == [r1, r2, r5], i
# def test_unreached_inputs(self):
# r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
# op = MyOp(r1, r2)
# op2 = MyOp(op.outputs[0], r5)
# try:
# # function doesn't raise if we put False instead of True
# ro = results_and_orphans([r1, r2, op2.outputs[0]], op.outputs, True)
# except Exception, e:
# if e[0] is results_and_orphans.E_unreached:
# return
# self.fail()
#############
# as_string #
#############
# class _test_orphans(testcase):
# def test_straightforward(self): class X:
# r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
# node = MyOp.make_node(r1, r2)
# node2 = MyOp.make_node(node.outputs[0], r5)
# orph = orphans([r1, r2], node2.outputs)
# self.failUnless(orph == [r5], orph)
class _test_as_string(testcase):
leaf_formatter = lambda self, leaf: str(leaf.type) leaf_formatter = lambda self, leaf: str(leaf.type)
node_formatter = lambda self, node, argstrings: "%s(%s)" % (node.op, node_formatter = lambda self, node, argstrings: "%s(%s)" % (node.op,
...@@ -134,18 +82,21 @@ class _test_as_string(testcase): ...@@ -134,18 +82,21 @@ class _test_as_string(testcase):
leaf_formatter = self.leaf_formatter, leaf_formatter = self.leaf_formatter,
node_formatter = self.node_formatter) node_formatter = self.node_formatter)
def test_straightforward(self):
class TestStr(X):
def test_as_string(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
node = MyOp.make_node(r1, r2) node = MyOp.make_node(r1, r2)
s = self.str([r1, r2], node.outputs) s = self.str([r1, r2], node.outputs)
self.failUnless(s == ["MyOp(R1, R2)"], s) assert s == ["MyOp(R1, R2)"]
def test_deep(self): def test_as_string_deep(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
node = MyOp.make_node(r1, r2) node = MyOp.make_node(r1, r2)
node2 = MyOp.make_node(node.outputs[0], r5) node2 = MyOp.make_node(node.outputs[0], r5)
s = self.str([r1, r2, r5], node2.outputs) s = self.str([r1, r2, r5], node2.outputs)
self.failUnless(s == ["MyOp(MyOp(R1, R2), R5)"], s) assert s == ["MyOp(MyOp(R1, R2), R5)"]
def test_multiple_references(self): def test_multiple_references(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
...@@ -161,16 +112,11 @@ class _test_as_string(testcase): ...@@ -161,16 +112,11 @@ class _test_as_string(testcase):
assert self.str(node2.inputs, node2.outputs) == ["MyOp(R3, R3)"] assert self.str(node2.inputs, node2.outputs) == ["MyOp(R3, R3)"]
class _test_clone(testcase): #########
# clone #
#########
leaf_formatter = lambda self, leaf: str(leaf.type) class TestClone(X):
node_formatter = lambda self, node, argstrings: "%s(%s)" % (node.op,
", ".join(argstrings))
def str(self, inputs, outputs):
return as_string(inputs, outputs,
leaf_formatter = self.leaf_formatter,
node_formatter = self.node_formatter)
def test_accurate(self): def test_accurate(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
...@@ -198,6 +144,11 @@ class _test_clone(testcase): ...@@ -198,6 +144,11 @@ class _test_clone(testcase):
assert self.str(inputs(new_node.outputs), new_node.outputs) == ["MyOp(R7, R8)"] assert self.str(inputs(new_node.outputs), new_node.outputs) == ["MyOp(R7, R8)"]
assert self.str(inputs(node.outputs), node.outputs) == ["MyOp(MyOp(R1, R2), R5)"] assert self.str(inputs(node.outputs), node.outputs) == ["MyOp(MyOp(R1, R2), R5)"]
############
# toposort #
############
def prenode(obj): def prenode(obj):
if isinstance(obj, Result): if isinstance(obj, Result):
if obj.owner: if obj.owner:
...@@ -205,79 +156,64 @@ def prenode(obj): ...@@ -205,79 +156,64 @@ def prenode(obj):
if isinstance(obj, Apply): if isinstance(obj, Apply):
return obj.inputs return obj.inputs
class _test_toposort(testcase): class TestToposort:
def test0(self):
def test_0(self):
"""Test a simple graph""" """Test a simple graph"""
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
o = MyOp.make_node(r1, r2) o = MyOp.make_node(r1, r2)
o2 = MyOp.make_node(o.outputs[0], r5) o2 = MyOp.make_node(o.outputs[0], r5)
all = general_toposort(o2.outputs, prenode) all = general_toposort(o2.outputs, prenode)
self.failUnless(all == [r5, r2, r1, o, o.outputs[0], o2, o2.outputs[0]], all) assert all == [r5, r2, r1, o, o.outputs[0], o2, o2.outputs[0]]
all = io_toposort([r5], o2.outputs) all = io_toposort([r5], o2.outputs)
self.failUnless(all == [o, o2], all) assert all == [o, o2]
def test1(self): def test_1(self):
"""Test a graph with double dependencies""" """Test a graph with double dependencies"""
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
o = MyOp.make_node(r1, r1) o = MyOp.make_node(r1, r1)
o2 = MyOp.make_node(o.outputs[0], r5) o2 = MyOp.make_node(o.outputs[0], r5)
all = general_toposort(o2.outputs, prenode) all = general_toposort(o2.outputs, prenode)
self.failUnless(all == [r5, r1, o, o.outputs[0], o2, o2.outputs[0]], all) assert all == [r5, r1, o, o.outputs[0], o2, o2.outputs[0]]
def test2(self): def test_2(self):
"""Test a graph where the inputs have owners""" """Test a graph where the inputs have owners"""
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
o = MyOp.make_node(r1, r1) o = MyOp.make_node(r1, r1)
r2b = o.outputs[0] r2b = o.outputs[0]
o2 = MyOp.make_node(r2b, r2b) o2 = MyOp.make_node(r2b, r2b)
all = io_toposort([r2b], o2.outputs) all = io_toposort([r2b], o2.outputs)
self.failUnless(all == [o2], all) assert all == [o2]
o2 = MyOp.make_node(r2b, r5) o2 = MyOp.make_node(r2b, r5)
all = io_toposort([r2b], o2.outputs) all = io_toposort([r2b], o2.outputs)
self.failUnless(all == [o2], all) assert all == [o2]
def test3(self): def test_3(self):
"""Test a graph which is not connected""" """Test a graph which is not connected"""
r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4) r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4)
o0 = MyOp.make_node(r1, r2) o0 = MyOp.make_node(r1, r2)
o1 = MyOp.make_node(r3, r4) o1 = MyOp.make_node(r3, r4)
all = io_toposort([r1, r2, r3, r4], o0.outputs + o1.outputs) all = io_toposort([r1, r2, r3, r4], o0.outputs + o1.outputs)
self.failUnless(all == [o1,o0], all) assert all == [o1,o0]
def test4(self): def test_4(self):
"""Test inputs and outputs mixed together in a chain graph""" """Test inputs and outputs mixed together in a chain graph"""
r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4) r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4)
o0 = MyOp.make_node(r1, r2) o0 = MyOp.make_node(r1, r2)
o1 = MyOp.make_node(o0.outputs[0], r1) o1 = MyOp.make_node(o0.outputs[0], r1)
all = io_toposort([r1, o0.outputs[0]], [o0.outputs[0], o1.outputs[0]]) all = io_toposort([r1, o0.outputs[0]], [o0.outputs[0], o1.outputs[0]])
self.failUnless(all == [o1], all) assert all == [o1]
def test5(self): def test_5(self):
"""Test when outputs have clients""" """Test when outputs have clients"""
r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4) r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4)
o0 = MyOp.make_node(r1, r2) o0 = MyOp.make_node(r1, r2)
o1 = MyOp.make_node(o0.outputs[0], r4) o1 = MyOp.make_node(o0.outputs[0], r4)
all = io_toposort([], o0.outputs) all = io_toposort([], o0.outputs)
self.failUnless(all == [o0], all) assert all == [o0]
if __name__ == '__main__':
if 1:
#run all tests
unittest.main()
elif 1:
#load some TestCase classes
suite = unittest.TestLoader()
suite = suite.loadTestsFromTestCase(_test_toposort)
#run just some of them
unittest.TextTestRunner(verbosity=2).run(suite)
else:
#run just a single test
_test_toposort('test0').debug()
import unittest
from theano.gof import graph from theano.gof import graph
from theano.gof.graph import Result, Apply, Constant from theano.gof.graph import Result, Apply, Constant
from theano.gof.type import Type from theano.gof.type import Type
...@@ -78,7 +75,7 @@ def Env(inputs, outputs): ...@@ -78,7 +75,7 @@ def Env(inputs, outputs):
return e return e
class _test_PerformLinker(unittest.TestCase): class TestPerformLinker:
def test_thunk(self): def test_thunk(self):
x, y, z = inputs() x, y, z = inputs()
...@@ -105,14 +102,14 @@ class _test_PerformLinker(unittest.TestCase): ...@@ -105,14 +102,14 @@ class _test_PerformLinker(unittest.TestCase):
def test_input_output_same(self): def test_input_output_same(self):
x, y, z = inputs() x, y, z = inputs()
fn = perform_linker(Env([x], [x])).make_function() fn = perform_linker(Env([x], [x])).make_function()
self.failUnless(1.0 is fn(1.0)) assert 1.0 is fn(1.0)
def test_input_dependency0(self): def test_input_dependency0(self):
x, y, z = inputs() x, y, z = inputs()
a,d = add(x,y), div(x,y) a,d = add(x,y), div(x,y)
e = mul(a,d) e = mul(a,d)
fn = perform_linker(Env(*graph.clone([x, y, a], [e]))).make_function() fn = perform_linker(Env(*graph.clone([x, y, a], [e]))).make_function()
self.failUnless(fn(1.0,2.0,9.0) == 4.5) assert fn(1.0,2.0,9.0) == 4.5
def test_skiphole(self): def test_skiphole(self):
x,y,z = inputs() x,y,z = inputs()
...@@ -120,14 +117,16 @@ class _test_PerformLinker(unittest.TestCase): ...@@ -120,14 +117,16 @@ class _test_PerformLinker(unittest.TestCase):
r = raise_err(a) r = raise_err(a)
e = add(r,a) e = add(r,a)
fn = perform_linker(Env(*graph.clone([x, y,r], [e]))).make_function() fn = perform_linker(Env(*graph.clone([x, y,r], [e]))).make_function()
self.failUnless(fn(1.0,2.0,4.5) == 7.5) assert fn(1.0,2.0,4.5) == 7.5
def wrap_linker(env, linkers, wrapper): def wrap_linker(env, linkers, wrapper):
lnk = WrapLinker(linkers, wrapper).accept(env) lnk = WrapLinker(linkers, wrapper).accept(env)
return lnk return lnk
class _test_WrapLinker(unittest.TestCase):
def test0(self): class TestWrapLinker:
def test_0(self):
nodes = [] nodes = []
def wrap(i, node, th): def wrap(i, node, th):
nodes.append(node.op) nodes.append(node.op)
...@@ -138,10 +137,10 @@ class _test_WrapLinker(unittest.TestCase): ...@@ -138,10 +137,10 @@ class _test_WrapLinker(unittest.TestCase):
i[0].data = 1 i[0].data = 1
i[1].data = 2 i[1].data = 2
fn() fn()
self.failUnless(nodes == [div, add, mul], nodes) assert nodes == [div, add, mul]
self.failUnless(o[0].data is None) assert o[0].data is None
def test1(self): def test_1(self):
nodes = [] nodes = []
def wrap(i, node, th): def wrap(i, node, th):
nodes.append(node.op) nodes.append(node.op)
...@@ -153,44 +152,8 @@ class _test_WrapLinker(unittest.TestCase): ...@@ -153,44 +152,8 @@ class _test_WrapLinker(unittest.TestCase):
i[0].data = 1 i[0].data = 1
i[1].data = 2 i[1].data = 2
fn() fn()
self.failUnless(nodes == [div, add, mul], nodes) assert nodes == [div, add, mul]
self.failUnless(o[0].data == 1.5, o[0].data) assert o[0].data == 1.5
# def test_disconnected_input_output(self):
# x,y,z = inputs()
# a = add(x,y)
# a.data = 3.0 # simulate orphan calculation
# fn = perform_linker(env([z], [a])).make_function(inplace=True)
# self.failUnless(fn(1.0) == 3.0)
# self.failUnless(fn(2.0) == 3.0)
# def test_thunk_inplace(self):
# x, y, z = inputs()
# e = mul(add(x, y), div(x, y))
# fn, i, o = perform_linker(Env([x, y, z], [e])).make_thunk(True)
# fn()
# assert e.data == 1.5
# def test_thunk_not_inplace(self):
# x, y, z = inputs()
# e = mul(add(x, y), div(x, y))
# fn, i, o = perform_linker(env([x, y, z], [e])).make_thunk(False)
# fn()
# assert o[0].data == 1.5
# assert e.data != 1.5
# def test_function(self):
# x, y, z = inputs()
# e = mul(add(x, y), div(x, y))
# fn = perform_linker(env([x, y, z], [e])).make_function()
# assert fn(1.0, 2.0, 3.0) == 1.5
# assert e.data != 1.5 # not inplace
if __name__ == '__main__':
unittest.main()
......
import unittest
from copy import copy from copy import copy
from theano.gof.op import * from theano.gof.op import *
from theano.gof.type import Type, Generic from theano.gof.type import Type, Generic
...@@ -38,7 +37,7 @@ class MyOp(Op): ...@@ -38,7 +37,7 @@ class MyOp(Op):
MyOp = MyOp() MyOp = MyOp()
class _test_Op(unittest.TestCase): class TestOp:
# Sanity tests # Sanity tests
def test_sanity_0(self): def test_sanity_0(self):
......
import unittest
from theano.gof.type import Type from theano.gof.type import Type
from theano.gof.graph import Result, Apply, Constant from theano.gof.graph import Result, Apply, Constant
from theano.gof.op import Op from theano.gof.op import Op
...@@ -73,7 +71,7 @@ def inputs(): ...@@ -73,7 +71,7 @@ def inputs():
PatternOptimizer = lambda p1, p2, ign=False: OpKeyOptimizer(PatternSub(p1, p2), ignore_newtrees=ign) PatternOptimizer = lambda p1, p2, ign=False: OpKeyOptimizer(PatternSub(p1, p2), ignore_newtrees=ign)
TopoPatternOptimizer = lambda p1, p2, ign=True: TopoOptimizer(PatternSub(p1, p2), ignore_newtrees=ign) TopoPatternOptimizer = lambda p1, p2, ign=True: TopoOptimizer(PatternSub(p1, p2), ignore_newtrees=ign)
class _test_PatternOptimizer(unittest.TestCase): class TestPatternOptimizer:
def test_replace_output(self): def test_replace_output(self):
# replacing the whole graph # replacing the whole graph
...@@ -250,7 +248,7 @@ class _test_PatternOptimizer(unittest.TestCase): ...@@ -250,7 +248,7 @@ class _test_PatternOptimizer(unittest.TestCase):
OpSubOptimizer = lambda op1, op2: TopoOptimizer(OpSub(op1, op2)) OpSubOptimizer = lambda op1, op2: TopoOptimizer(OpSub(op1, op2))
OpSubOptimizer = lambda op1, op2: OpKeyOptimizer(OpSub(op1, op2)) OpSubOptimizer = lambda op1, op2: OpKeyOptimizer(OpSub(op1, op2))
class _test_OpSubOptimizer(unittest.TestCase): class TestOpSubOptimizer:
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
...@@ -267,7 +265,7 @@ class _test_OpSubOptimizer(unittest.TestCase): ...@@ -267,7 +265,7 @@ class _test_OpSubOptimizer(unittest.TestCase):
assert str(g) == "[Op1(Op2(x), Op4(y), Op4(z))]" assert str(g) == "[Op1(Op2(x), Op4(y), Op4(z))]"
class _test_MergeOptimizer(unittest.TestCase): class TestMergeOptimizer:
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
...@@ -330,42 +328,6 @@ class _test_MergeOptimizer(unittest.TestCase): ...@@ -330,42 +328,6 @@ class _test_MergeOptimizer(unittest.TestCase):
g = Env([x, y, z], [e1]) g = Env([x, y, z], [e1])
MergeOptimizer().optimize(g) MergeOptimizer().optimize(g)
strg = str(g) strg = str(g)
self.failUnless(strg == '[Op1(y, y)]' or strg == '[Op1(z, z)]', strg) assert strg == '[Op1(y, y)]' or strg == '[Op1(z, z)]'
# def test_identical_constant_args_with_destroymap(self):
# x, y, z = inputs()
# y.data = 2.0
# y.constant = False
# z.data = 2.0
# z.constant = True
# e1 = op_d(y, z)
# g = env([x, y, z], [e1])
# MergeOptimizer().optimize(g)
# strg = str(g)
# self.failUnless(strg == '[OpD(y, z)]', strg)
# def test_merge_with_destroyer_1(self):
# x, y, z = inputs()
# e1 = op_d(op1(x,y), y)
# e2 = op_d(op1(x,y), z)
# g = env([x, y, z], [e1,e2])
# MergeOptimizer().optimize(g)
# strg = str(g)
# self.failUnless(strg == '[OpD(Op1(x, y), y), OpD(Op1(x, y), z)]', strg)
# def test_merge_with_destroyer_2(self):
# x, y, z = inputs()
# e1 = op_d(op1(x,y), z)
# e2 = op_d(op1(x,y), z)
# g = env([x, y, z], [e1,e2])
# MergeOptimizer().optimize(g)
# strg = str(g)
# self.failUnless(strg == '[*1 -> OpD(Op1(x, y), z), *1]', strg)
if __name__ == '__main__':
unittest.main()
import unittest
from theano.gof.graph import Result, Apply from theano.gof.graph import Result, Apply
from theano.gof.type import Type from theano.gof.type import Type
from theano.gof.op import Op from theano.gof.op import Op
...@@ -63,22 +61,8 @@ def inputs(): ...@@ -63,22 +61,8 @@ def inputs():
return x, y, z return x, y, z
# class _test_EquivTool(unittest.TestCase):
# def test_straightforward(self):
# x, y, z = inputs()
# sx = sigmoid(x)
# e = add(sx, sigmoid(y))
# g = Env([x, y, z], [e])
# g.extend(EquivTool(g))
# assert hasattr(g, 'equiv')
# assert g.equiv(sx) is sx
# g.replace(sx, dot(x, z))
# assert g.equiv(sx) is not sx
# assert g.equiv(sx).owner.op is dot
class _test_NodeFinder(unittest.TestCase): class TestNodeFinder:
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
...@@ -89,7 +73,7 @@ class _test_NodeFinder(unittest.TestCase): ...@@ -89,7 +73,7 @@ class _test_NodeFinder(unittest.TestCase):
assert hasattr(g, 'get_nodes') assert hasattr(g, 'get_nodes')
for type, num in ((add, 3), (sigmoid, 3), (dot, 2)): for type, num in ((add, 3), (sigmoid, 3), (dot, 2)):
if not len([x for x in g.get_nodes(type)]) == num: if not len([x for x in g.get_nodes(type)]) == num:
self.fail((type, num)) raise Exception("Expected: %i times %s" % (num, type))
new_e0 = add(y, z) new_e0 = add(y, z)
assert e0.owner in g.get_nodes(dot) assert e0.owner in g.get_nodes(dot)
assert new_e0.owner not in g.get_nodes(add) assert new_e0.owner not in g.get_nodes(add)
...@@ -98,21 +82,7 @@ class _test_NodeFinder(unittest.TestCase): ...@@ -98,21 +82,7 @@ class _test_NodeFinder(unittest.TestCase):
assert new_e0.owner in g.get_nodes(add) assert new_e0.owner in g.get_nodes(add)
for type, num in ((add, 4), (sigmoid, 3), (dot, 1)): for type, num in ((add, 4), (sigmoid, 3), (dot, 1)):
if not len([x for x in g.get_nodes(type)]) == num: if not len([x for x in g.get_nodes(type)]) == num:
self.fail((type, num)) raise Exception("Expected: %i times %s" % (num, type))
# def test_robustness(self):
# # this test used to make sense to have, but it doesn't work like that anymore
# x, y, z = inputs()
# e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), dot(y, z)))
# g = Env([x, y, z], [e])
# g.extend(NodeFinder())
# gen = g.get_nodes(sigmoid) # I want to get Sigmoid instances
# g.replace(e, add(x, y)) # but here I prune them all
# assert len([x for x in gen]) == 0 # the generator should not yield them
if __name__ == '__main__':
unittest.main()
import unittest
from theano.gof.type import * from theano.gof.type import *
# todo: test generic # todo: test generic
if __name__ == '__main__':
unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论