提交 9e669ec0 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

incorporating py.test

上级 7cd0a3b9
import unittest, os, sys, traceback, commands
sys.path[0:0] = [os.path.realpath("..")]
theano_path = os.path.realpath("%s/.." % sys.path[0])
sys.path[0:0] = [theano_path]
def test_module(module_path, debugmode = False):
files = commands.getoutput("find %s -name test_*.py" % module_path)
files = commands.getoutput("find %s -name _test_*.py" % module_path)
suite = None
tocut = len("/".join(module_path.split("/")[:-1])) + 1
for file in files.split("\n"):
......@@ -17,7 +18,7 @@ def test_module(module_path, debugmode = False):
traceback.print_exc()
print >>sys.stderr, "===================================================="
continue
tests = unittest.TestLoader().loadTestsFromModule(module)
if tests.countTestCases() > 0:
print >>sys.stderr, 'Testing', file
......@@ -25,11 +26,59 @@ def test_module(module_path, debugmode = False):
suite = tests
else:
suite.addTests(tests)
if suite is None:
return
if debugmode:
suite.debug()
else:
unittest.TextTestRunner(verbosity=1).run(suite)
def py_test(module_path):
py.test.cmdline.main([module_path])
def nopy_test(module_path):
print >>sys.stderr, "py.test is not installed!"
print >>sys.stderr, " easy_install py"
print >>sys.stderr, "or if you are installing locally"
print >>sys.stderr, " easy_install --prefix=/some/local/dir py"
return None
files = commands.getoutput("find %s -name test_*.py" % module_path)
suite = None
tocut = len("/".join(module_path.split("/")[:-1])) + 1
for file in files.split("\n"):
file = file[tocut:]
try:
module = __import__(file[:-3])
except Exception, e:
print >>sys.stderr, "===================================================="
print >>sys.stderr, "Failed to load %s" % file
print >>sys.stderr, "===================================================="
traceback.print_exc()
print >>sys.stderr, "===================================================="
continue
for method in dir(module):
if method.startswith("test_"):
method = getattr(module, method)
try:
r = method()
except Exception, e:
print >>sys.stderr, "===================================================="
print >>sys.stderr, "Exception in %s.%s" % (file, method.__name__)
print >>sys.stderr, "===================================================="
traceback.print_exc()
print >>sys.stderr, "===================================================="
if hasattr(r, 'next'):
for fargs in r:
try:
fargs[0](*fargs[1:])
except Exception, e:
print >>sys.stderr, "===================================================="
print >>sys.stderr, "Exception in %s.%s, %s%s" % (file, method.__name__, fargs[0], fargs[1:])
print >>sys.stderr, "===================================================="
traceback.print_exc()
print >>sys.stderr, "===================================================="
if __name__ == '__main__':
......@@ -48,6 +97,12 @@ if __name__ == '__main__':
elif len(sys.argv)>2:
printUsage()
test_module(os.path.realpath("../theano"))
mname = os.path.join(theano_path, "theano")
test_module(mname)
try:
import py.test
py_test(mname)
except ImportError:
nopy_test(mname)
......@@ -135,77 +135,76 @@ def Env(inputs, outputs):
return e
class _test_CLinker(unittest.TestCase):
def test_straightforward(self):
x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker().accept(Env([x, y, z], [e]))
fn = lnk.make_function()
self.failUnless(fn(2.0, 2.0, 2.0) == 2.0)
# def test_orphan(self):
# x, y, z = inputs()
# z = Constant(tdouble, 4.12345678)
# e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
# lnk = CLinker(Env([x, y], [e]))
# fn = lnk.make_function()
# self.failUnless(abs(fn(2.0, 2.0) + 0.12345678) < 1e-9)
# print lnk.code_gen()
# self.failUnless("4.12345678" not in lnk.code_gen()) # we do not expect the number to be inlined
def test_literal_inlining(self):
x, y, z = inputs()
z = Constant(tdouble, 4.12345678)
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker().accept(Env([x, y], [e]))
fn = lnk.make_function()
self.failUnless(abs(fn(2.0, 2.0) + 0.12345678) < 1e-9)
self.failUnless("4.12345678" in lnk.code_gen()) # we expect the number to be inlined
def test_single_node(self):
x, y, z = inputs()
node = add.make_node(x, y)
lnk = CLinker().accept(Env(node.inputs, node.outputs))
fn = lnk.make_function()
self.failUnless(fn(2.0, 7.0) == 9)
def test_dups(self):
# Testing that duplicate inputs are allowed.
x, y, z = inputs()
e = add(x, x)
lnk = CLinker().accept(Env([x, x], [e]))
fn = lnk.make_function()
self.failUnless(fn(2.0, 2.0) == 4)
# note: for now the behavior of fn(2.0, 7.0) is undefined
def test_dups_inner(self):
# Testing that duplicates are allowed inside the graph
x, y, z = inputs()
e = add(mul(y, y), add(x, z))
lnk = CLinker().accept(Env([x, y, z], [e]))
fn = lnk.make_function()
self.failUnless(fn(1.0, 2.0, 3.0) == 8.0)
################
# Test CLinker #
################
def test_clinker_straightforward():
x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker().accept(Env([x, y, z], [e]))
fn = lnk.make_function()
assert fn(2.0, 2.0, 2.0) == 2.0
def test_clinker_literal_inlining():
x, y, z = inputs()
z = Constant(tdouble, 4.12345678)
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker().accept(Env([x, y], [e]))
fn = lnk.make_function()
assert abs(fn(2.0, 2.0) + 0.12345678) < 1e-9
code = lnk.code_gen()
print "=== Code generated ==="
print code
assert "4.12345678" in code # we expect the number to be inlined
def test_clinker_single_node():
x, y, z = inputs()
node = add.make_node(x, y)
lnk = CLinker().accept(Env(node.inputs, node.outputs))
fn = lnk.make_function()
assert fn(2.0, 7.0) == 9
def test_clinker_dups():
# Testing that duplicate inputs are allowed.
x, y, z = inputs()
e = add(x, x)
lnk = CLinker().accept(Env([x, x], [e]))
fn = lnk.make_function()
assert fn(2.0, 2.0) == 4
# note: for now the behavior of fn(2.0, 7.0) is undefined
def test_clinker_dups_inner():
# Testing that duplicates are allowed inside the graph
x, y, z = inputs()
e = add(mul(y, y), add(x, z))
lnk = CLinker().accept(Env([x, y, z], [e]))
fn = lnk.make_function()
assert fn(1.0, 2.0, 3.0) == 8.0
class _test_OpWiseCLinker(unittest.TestCase):
######################
# Test OpWiseCLinker #
######################
def test_opwiseclinker_straightforward():
x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = OpWiseCLinker().accept(Env([x, y, z], [e]))
fn = lnk.make_function()
assert fn(2.0, 2.0, 2.0) == 2.0
def test_opwiseclinker_constant():
x, y, z = inputs()
x = Constant(tdouble, 7.2, name = 'x')
e = add(mul(x, y), mul(y, z))
lnk = OpWiseCLinker().accept(Env([y, z], [e]))
fn = lnk.make_function()
res = fn(1.5, 3.0)
assert res == 15.3
def test_straightforward(self):
x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = OpWiseCLinker().accept(Env([x, y, z], [e]))
fn = lnk.make_function()
self.failUnless(fn(2.0, 2.0, 2.0) == 2.0)
def test_constant(self):
x, y, z = inputs()
x = Constant(tdouble, 7.2, name = 'x')
e = add(mul(x, y), mul(y, z))
lnk = OpWiseCLinker().accept(Env([y, z], [e]))
fn = lnk.make_function()
res = fn(1.5, 3.0)
self.failUnless(res == 15.3, res)
class MyExc(Exception):
......@@ -215,49 +214,34 @@ def _my_checker(x, y):
raise MyExc("Output mismatch.", {'performlinker': x[0], 'clinker': y[0]})
class _test_DualLinker(unittest.TestCase):
def test_straightforward(self):
x, y, z = inputs()
e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python
lnk = DualLinker(checker = _my_checker).accept(Env([x, y, z], [e]))
fn = lnk.make_function()
res = fn(7.2, 1.5, 3.0)
self.failUnless(res == 15.3, res)
def test_mismatch(self):
x, y, z = inputs()
e = sub(mul(x, y), mul(y, z)) # sub is correct in C but erroneous in Python
g = Env([x, y, z], [e])
lnk = DualLinker(checker = _my_checker).accept(g)
fn = lnk.make_function()
self.failUnless(CLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0) # good
self.failUnless(OpWiseCLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0) # good
self.failUnless(PerformLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -10.0) # (purposely) wrong
try:
# this runs OpWiseCLinker and PerformLinker in parallel and feeds
# results of matching operations to _my_checker to verify that they
# are the same.
res = fn(1.0, 2.0, 3.0)
self.fail()
except MyExc, e:
pass
else:
self.fail()
# def test_orphan(self):
# x, y, z = inputs()
# x = Constant(tdouble, 7.2, name = 'x')
# e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python
# lnk = DualLinker(Env([y, z], [e]), checker = _my_checker)
# fn = lnk.make_function()
# res = fn(1.5, 3.0)
# self.failUnless(res == 15.3, res)
if __name__ == '__main__':
unittest.main()
###################
# Test DualLinker #
###################
def test_duallinker_straightforward():
x, y, z = inputs()
e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python
lnk = DualLinker(checker = _my_checker).accept(Env([x, y, z], [e]))
fn = lnk.make_function()
res = fn(7.2, 1.5, 3.0)
assert res == 15.3
def test_duallinker_mismatch():
x, y, z = inputs()
e = sub(mul(x, y), mul(y, z)) # sub is correct in C but erroneous in Python
g = Env([x, y, z], [e])
lnk = DualLinker(checker = _my_checker).accept(g)
fn = lnk.make_function()
assert CLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0 # good
assert OpWiseCLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -4.0 # good
assert PerformLinker().accept(g).make_function()(1.0, 2.0, 3.0) == -10.0 # (purposely) wrong
try:
# this runs OpWiseCLinker and PerformLinker in parallel and feeds
# results of matching operations to _my_checker to verify that they
# are the same.
res = fn(1.0, 2.0, 3.0)
raise Exception("An exception should have been raised here!")
except MyExc, e:
pass
from collections import deque
import unittest
from theano.gof.graph import *
from theano.gof.op import Op
......@@ -8,13 +7,6 @@ from theano.gof.type import Type
from theano.gof.graph import Result
if 1:
testcase = unittest.TestCase
else:
testcase = object
realtestcase = unittest.TestCase
def as_result(x):
assert isinstance(x, Result)
return x
......@@ -55,75 +47,31 @@ class MyOp(Op):
MyOp = MyOp()
##########
# inputs #
##########
# class MyResult(Result):
# def __init__(self, thingy):
# self.thingy = thingy
# Result.__init__(self, role = None )
# self.data = [self.thingy]
# def __eq__(self, other):
# return self.same_properties(other)
# def same_properties(self, other):
# return isinstance(other, MyResult) and other.thingy == self.thingy
# def __str__(self):
# return str(self.thingy)
# def __repr__(self):
# return str(self.thingy)
class TestInputs:
# class MyOp(Op):
# def __init__(self, *inputs):
# for input in inputs:
# if not isinstance(input, MyResult):
# raise Exception("Error 1")
# self.inputs = inputs
# self.outputs = [MyResult(sum([input.thingy for input in inputs]))]
class _test_inputs(testcase):
def test_straightforward(self):
def test_inputs(self):
r1, r2 = MyResult(1), MyResult(2)
node = MyOp.make_node(r1, r2)
assert inputs(node.outputs) == [r1, r2]
def test_deep(self):
def test_inputs_deep(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
node = MyOp.make_node(r1, r2)
node2 = MyOp.make_node(node.outputs[0], r5)
i = inputs(node2.outputs)
self.failUnless(i == [r1, r2, r5], i)
# def test_unreached_inputs(self):
# r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
# op = MyOp(r1, r2)
# op2 = MyOp(op.outputs[0], r5)
# try:
# # function doesn't raise if we put False instead of True
# ro = results_and_orphans([r1, r2, op2.outputs[0]], op.outputs, True)
# except Exception, e:
# if e[0] is results_and_orphans.E_unreached:
# return
# self.fail()
# class _test_orphans(testcase):
# def test_straightforward(self):
# r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
# node = MyOp.make_node(r1, r2)
# node2 = MyOp.make_node(node.outputs[0], r5)
# orph = orphans([r1, r2], node2.outputs)
# self.failUnless(orph == [r5], orph)
assert i == [r1, r2, r5], i
#############
# as_string #
#############
class _test_as_string(testcase):
class X:
leaf_formatter = lambda self, leaf: str(leaf.type)
node_formatter = lambda self, node, argstrings: "%s(%s)" % (node.op,
......@@ -133,19 +81,22 @@ class _test_as_string(testcase):
return as_string(inputs, outputs,
leaf_formatter = self.leaf_formatter,
node_formatter = self.node_formatter)
class TestStr(X):
def test_straightforward(self):
def test_as_string(self):
r1, r2 = MyResult(1), MyResult(2)
node = MyOp.make_node(r1, r2)
s = self.str([r1, r2], node.outputs)
self.failUnless(s == ["MyOp(R1, R2)"], s)
assert s == ["MyOp(R1, R2)"]
def test_deep(self):
def test_as_string_deep(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
node = MyOp.make_node(r1, r2)
node2 = MyOp.make_node(node.outputs[0], r5)
s = self.str([r1, r2, r5], node2.outputs)
self.failUnless(s == ["MyOp(MyOp(R1, R2), R5)"], s)
assert s == ["MyOp(MyOp(R1, R2), R5)"]
def test_multiple_references(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
......@@ -161,16 +112,11 @@ class _test_as_string(testcase):
assert self.str(node2.inputs, node2.outputs) == ["MyOp(R3, R3)"]
class _test_clone(testcase):
leaf_formatter = lambda self, leaf: str(leaf.type)
node_formatter = lambda self, node, argstrings: "%s(%s)" % (node.op,
", ".join(argstrings))
#########
# clone #
#########
def str(self, inputs, outputs):
return as_string(inputs, outputs,
leaf_formatter = self.leaf_formatter,
node_formatter = self.node_formatter)
class TestClone(X):
def test_accurate(self):
r1, r2 = MyResult(1), MyResult(2)
......@@ -198,6 +144,11 @@ class _test_clone(testcase):
assert self.str(inputs(new_node.outputs), new_node.outputs) == ["MyOp(R7, R8)"]
assert self.str(inputs(node.outputs), node.outputs) == ["MyOp(MyOp(R1, R2), R5)"]
############
# toposort #
############
def prenode(obj):
if isinstance(obj, Result):
if obj.owner:
......@@ -205,79 +156,64 @@ def prenode(obj):
if isinstance(obj, Apply):
return obj.inputs
class _test_toposort(testcase):
def test0(self):
class TestToposort:
def test_0(self):
"""Test a simple graph"""
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
o = MyOp.make_node(r1, r2)
o2 = MyOp.make_node(o.outputs[0], r5)
all = general_toposort(o2.outputs, prenode)
self.failUnless(all == [r5, r2, r1, o, o.outputs[0], o2, o2.outputs[0]], all)
assert all == [r5, r2, r1, o, o.outputs[0], o2, o2.outputs[0]]
all = io_toposort([r5], o2.outputs)
self.failUnless(all == [o, o2], all)
assert all == [o, o2]
def test1(self):
def test_1(self):
"""Test a graph with double dependencies"""
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
o = MyOp.make_node(r1, r1)
o2 = MyOp.make_node(o.outputs[0], r5)
all = general_toposort(o2.outputs, prenode)
self.failUnless(all == [r5, r1, o, o.outputs[0], o2, o2.outputs[0]], all)
assert all == [r5, r1, o, o.outputs[0], o2, o2.outputs[0]]
def test2(self):
def test_2(self):
"""Test a graph where the inputs have owners"""
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
o = MyOp.make_node(r1, r1)
r2b = o.outputs[0]
o2 = MyOp.make_node(r2b, r2b)
all = io_toposort([r2b], o2.outputs)
self.failUnless(all == [o2], all)
assert all == [o2]
o2 = MyOp.make_node(r2b, r5)
all = io_toposort([r2b], o2.outputs)
self.failUnless(all == [o2], all)
assert all == [o2]
def test3(self):
def test_3(self):
"""Test a graph which is not connected"""
r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4)
o0 = MyOp.make_node(r1, r2)
o1 = MyOp.make_node(r3, r4)
all = io_toposort([r1, r2, r3, r4], o0.outputs + o1.outputs)
self.failUnless(all == [o1,o0], all)
assert all == [o1,o0]
def test4(self):
def test_4(self):
"""Test inputs and outputs mixed together in a chain graph"""
r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4)
o0 = MyOp.make_node(r1, r2)
o1 = MyOp.make_node(o0.outputs[0], r1)
all = io_toposort([r1, o0.outputs[0]], [o0.outputs[0], o1.outputs[0]])
self.failUnless(all == [o1], all)
assert all == [o1]
def test5(self):
def test_5(self):
"""Test when outputs have clients"""
r1, r2, r3, r4 = MyResult(1), MyResult(2), MyResult(3), MyResult(4)
o0 = MyOp.make_node(r1, r2)
o1 = MyOp.make_node(o0.outputs[0], r4)
all = io_toposort([], o0.outputs)
self.failUnless(all == [o0], all)
if __name__ == '__main__':
if 1:
#run all tests
unittest.main()
elif 1:
#load some TestCase classes
suite = unittest.TestLoader()
suite = suite.loadTestsFromTestCase(_test_toposort)
assert all == [o0]
#run just some of them
unittest.TextTestRunner(verbosity=2).run(suite)
else:
#run just a single test
_test_toposort('test0').debug()
import unittest
from theano.gof import graph
from theano.gof.graph import Result, Apply, Constant
from theano.gof.type import Type
......@@ -78,7 +75,7 @@ def Env(inputs, outputs):
return e
class _test_PerformLinker(unittest.TestCase):
class TestPerformLinker:
def test_thunk(self):
x, y, z = inputs()
......@@ -105,14 +102,14 @@ class _test_PerformLinker(unittest.TestCase):
def test_input_output_same(self):
x, y, z = inputs()
fn = perform_linker(Env([x], [x])).make_function()
self.failUnless(1.0 is fn(1.0))
assert 1.0 is fn(1.0)
def test_input_dependency0(self):
x, y, z = inputs()
a,d = add(x,y), div(x,y)
e = mul(a,d)
fn = perform_linker(Env(*graph.clone([x, y, a], [e]))).make_function()
self.failUnless(fn(1.0,2.0,9.0) == 4.5)
assert fn(1.0,2.0,9.0) == 4.5
def test_skiphole(self):
x,y,z = inputs()
......@@ -120,14 +117,16 @@ class _test_PerformLinker(unittest.TestCase):
r = raise_err(a)
e = add(r,a)
fn = perform_linker(Env(*graph.clone([x, y,r], [e]))).make_function()
self.failUnless(fn(1.0,2.0,4.5) == 7.5)
assert fn(1.0,2.0,4.5) == 7.5
def wrap_linker(env, linkers, wrapper):
lnk = WrapLinker(linkers, wrapper).accept(env)
return lnk
class _test_WrapLinker(unittest.TestCase):
def test0(self):
class TestWrapLinker:
def test_0(self):
nodes = []
def wrap(i, node, th):
nodes.append(node.op)
......@@ -138,10 +137,10 @@ class _test_WrapLinker(unittest.TestCase):
i[0].data = 1
i[1].data = 2
fn()
self.failUnless(nodes == [div, add, mul], nodes)
self.failUnless(o[0].data is None)
assert nodes == [div, add, mul]
assert o[0].data is None
def test1(self):
def test_1(self):
nodes = []
def wrap(i, node, th):
nodes.append(node.op)
......@@ -153,44 +152,8 @@ class _test_WrapLinker(unittest.TestCase):
i[0].data = 1
i[1].data = 2
fn()
self.failUnless(nodes == [div, add, mul], nodes)
self.failUnless(o[0].data == 1.5, o[0].data)
# def test_disconnected_input_output(self):
# x,y,z = inputs()
# a = add(x,y)
# a.data = 3.0 # simulate orphan calculation
# fn = perform_linker(env([z], [a])).make_function(inplace=True)
# self.failUnless(fn(1.0) == 3.0)
# self.failUnless(fn(2.0) == 3.0)
# def test_thunk_inplace(self):
# x, y, z = inputs()
# e = mul(add(x, y), div(x, y))
# fn, i, o = perform_linker(Env([x, y, z], [e])).make_thunk(True)
# fn()
# assert e.data == 1.5
# def test_thunk_not_inplace(self):
# x, y, z = inputs()
# e = mul(add(x, y), div(x, y))
# fn, i, o = perform_linker(env([x, y, z], [e])).make_thunk(False)
# fn()
# assert o[0].data == 1.5
# assert e.data != 1.5
# def test_function(self):
# x, y, z = inputs()
# e = mul(add(x, y), div(x, y))
# fn = perform_linker(env([x, y, z], [e])).make_function()
# assert fn(1.0, 2.0, 3.0) == 1.5
# assert e.data != 1.5 # not inplace
if __name__ == '__main__':
unittest.main()
assert nodes == [div, add, mul]
assert o[0].data == 1.5
......
import unittest
from copy import copy
from theano.gof.op import *
from theano.gof.type import Type, Generic
......@@ -38,7 +37,7 @@ class MyOp(Op):
MyOp = MyOp()
class _test_Op(unittest.TestCase):
class TestOp:
# Sanity tests
def test_sanity_0(self):
......
import unittest
from theano.gof.type import Type
from theano.gof.graph import Result, Apply, Constant
from theano.gof.op import Op
......@@ -73,7 +71,7 @@ def inputs():
PatternOptimizer = lambda p1, p2, ign=False: OpKeyOptimizer(PatternSub(p1, p2), ignore_newtrees=ign)
TopoPatternOptimizer = lambda p1, p2, ign=True: TopoOptimizer(PatternSub(p1, p2), ignore_newtrees=ign)
class _test_PatternOptimizer(unittest.TestCase):
class TestPatternOptimizer:
def test_replace_output(self):
# replacing the whole graph
......@@ -250,7 +248,7 @@ class _test_PatternOptimizer(unittest.TestCase):
OpSubOptimizer = lambda op1, op2: TopoOptimizer(OpSub(op1, op2))
OpSubOptimizer = lambda op1, op2: OpKeyOptimizer(OpSub(op1, op2))
class _test_OpSubOptimizer(unittest.TestCase):
class TestOpSubOptimizer:
def test_straightforward(self):
x, y, z = inputs()
......@@ -267,7 +265,7 @@ class _test_OpSubOptimizer(unittest.TestCase):
assert str(g) == "[Op1(Op2(x), Op4(y), Op4(z))]"
class _test_MergeOptimizer(unittest.TestCase):
class TestMergeOptimizer:
def test_straightforward(self):
x, y, z = inputs()
......@@ -330,42 +328,6 @@ class _test_MergeOptimizer(unittest.TestCase):
g = Env([x, y, z], [e1])
MergeOptimizer().optimize(g)
strg = str(g)
self.failUnless(strg == '[Op1(y, y)]' or strg == '[Op1(z, z)]', strg)
# def test_identical_constant_args_with_destroymap(self):
# x, y, z = inputs()
# y.data = 2.0
# y.constant = False
# z.data = 2.0
# z.constant = True
# e1 = op_d(y, z)
# g = env([x, y, z], [e1])
# MergeOptimizer().optimize(g)
# strg = str(g)
# self.failUnless(strg == '[OpD(y, z)]', strg)
# def test_merge_with_destroyer_1(self):
# x, y, z = inputs()
# e1 = op_d(op1(x,y), y)
# e2 = op_d(op1(x,y), z)
# g = env([x, y, z], [e1,e2])
# MergeOptimizer().optimize(g)
# strg = str(g)
# self.failUnless(strg == '[OpD(Op1(x, y), y), OpD(Op1(x, y), z)]', strg)
# def test_merge_with_destroyer_2(self):
# x, y, z = inputs()
# e1 = op_d(op1(x,y), z)
# e2 = op_d(op1(x,y), z)
# g = env([x, y, z], [e1,e2])
# MergeOptimizer().optimize(g)
# strg = str(g)
# self.failUnless(strg == '[*1 -> OpD(Op1(x, y), z), *1]', strg)
if __name__ == '__main__':
unittest.main()
assert strg == '[Op1(y, y)]' or strg == '[Op1(z, z)]'
import unittest
from theano.gof.graph import Result, Apply
from theano.gof.type import Type
from theano.gof.op import Op
......@@ -63,22 +61,8 @@ def inputs():
return x, y, z
# class _test_EquivTool(unittest.TestCase):
# def test_straightforward(self):
# x, y, z = inputs()
# sx = sigmoid(x)
# e = add(sx, sigmoid(y))
# g = Env([x, y, z], [e])
# g.extend(EquivTool(g))
# assert hasattr(g, 'equiv')
# assert g.equiv(sx) is sx
# g.replace(sx, dot(x, z))
# assert g.equiv(sx) is not sx
# assert g.equiv(sx).owner.op is dot
class _test_NodeFinder(unittest.TestCase):
class TestNodeFinder:
def test_straightforward(self):
x, y, z = inputs()
......@@ -89,7 +73,7 @@ class _test_NodeFinder(unittest.TestCase):
assert hasattr(g, 'get_nodes')
for type, num in ((add, 3), (sigmoid, 3), (dot, 2)):
if not len([x for x in g.get_nodes(type)]) == num:
self.fail((type, num))
raise Exception("Expected: %i times %s" % (num, type))
new_e0 = add(y, z)
assert e0.owner in g.get_nodes(dot)
assert new_e0.owner not in g.get_nodes(add)
......@@ -98,21 +82,7 @@ class _test_NodeFinder(unittest.TestCase):
assert new_e0.owner in g.get_nodes(add)
for type, num in ((add, 4), (sigmoid, 3), (dot, 1)):
if not len([x for x in g.get_nodes(type)]) == num:
self.fail((type, num))
# def test_robustness(self):
# # this test used to make sense to have, but it doesn't work like that anymore
# x, y, z = inputs()
# e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), dot(y, z)))
# g = Env([x, y, z], [e])
# g.extend(NodeFinder())
# gen = g.get_nodes(sigmoid) # I want to get Sigmoid instances
# g.replace(e, add(x, y)) # but here I prune them all
# assert len([x for x in gen]) == 0 # the generator should not yield them
raise Exception("Expected: %i times %s" % (num, type))
if __name__ == '__main__':
unittest.main()
import unittest
from theano.gof.type import *
# todo: test generic
if __name__ == '__main__':
unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论