提交 2fd6ea22 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

switching from Op/Result to Op+Apply/Type+Result

上级 df258db8
import gof import gof
import tensor import tensor
import sparse #import sparse
import compile import compile
import gradient import gradient
import tensor_opt #import tensor_opt
import scalar_opt import scalar_opt
from tensor import *
from compile import * from compile import *
from tensor_opt import * #from tensor_opt import *
from scalar_opt import * from scalar_opt import *
from gradient import * from gradient import *
from tensor import *
...@@ -11,14 +11,14 @@ import tensor ...@@ -11,14 +11,14 @@ import tensor
from elemwise import * from elemwise import *
def inputs(): # def inputs():
x = modes.build(Tensor('float64', (0, 0), name = 'x')) # x = modes.build(Tensor('float64', (0, 0), name = 'x'))
y = modes.build(Tensor('float64', (1, 0), name = 'y')) # y = modes.build(Tensor('float64', (1, 0), name = 'y'))
z = modes.build(Tensor('float64', (0, 0), name = 'z')) # z = modes.build(Tensor('float64', (0, 0), name = 'z'))
return x, y, z # return x, y, z
def env(inputs, outputs, validate = True, features = []): # def env(inputs, outputs, validate = True, features = []):
return Env(inputs, outputs, features = features, consistency_check = validate) # return Env(inputs, outputs, features = features, consistency_check = validate)
class _test_DimShuffle(unittest.TestCase): class _test_DimShuffle(unittest.TestCase):
...@@ -31,10 +31,11 @@ class _test_DimShuffle(unittest.TestCase): ...@@ -31,10 +31,11 @@ class _test_DimShuffle(unittest.TestCase):
((2, 3, 4), ('x', 2, 1, 0, 'x'), (1, 4, 3, 2, 1)), ((2, 3, 4), ('x', 2, 1, 0, 'x'), (1, 4, 3, 2, 1)),
((1, 4, 3, 2, 1), (3, 2, 1), (2, 3, 4)), ((1, 4, 3, 2, 1), (3, 2, 1), (2, 3, 4)),
((1, 1, 4), (1, 2), (1, 4))]: ((1, 1, 4), (1, 2), (1, 4))]:
x = modes.build(Tensor('float64', [1 * (entry == 1) for entry in xsh], name = 'x')) ib = [(entry == 1) for entry in xsh]
e = DimShuffle(x, shuffle).out x = Tensor('float64', ib)('x')
e = DimShuffle(ib, shuffle)(x)
# print shuffle, e.owner.grad(e.owner.inputs, e.owner.outputs).owner.new_order # print shuffle, e.owner.grad(e.owner.inputs, e.owner.outputs).owner.new_order
f = linker(env([x], [e])).make_function(inplace=False) f = linker(Env([x], [e])).make_function()
assert f(numpy.ones(xsh)).shape == zsh assert f(numpy.ones(xsh)).shape == zsh
def test_perform(self): def test_perform(self):
...@@ -53,10 +54,10 @@ class _test_Broadcast(unittest.TestCase): ...@@ -53,10 +54,10 @@ class _test_Broadcast(unittest.TestCase):
((2, 3, 4, 5), (1, 3, 1, 5)), ((2, 3, 4, 5), (1, 3, 1, 5)),
((2, 3, 4, 5), (1, 1, 1, 1)), ((2, 3, 4, 5), (1, 1, 1, 1)),
((), ())]: ((), ())]:
x = modes.build(Tensor('float64', [1 * (entry == 1) for entry in xsh], name = 'x')) x = Tensor('float64', [(entry == 1) for entry in xsh])('x')
y = modes.build(Tensor('float64', [1 * (entry == 1) for entry in ysh], name = 'y')) y = Tensor('float64', [(entry == 1) for entry in ysh])('y')
e = Broadcast(Add, (x, y)).out e = Elemwise(add)(x, y)
f = linker(env([x, y], [e])).make_function(inplace = False) f = linker(Env([x, y], [e])).make_function()
# xv = numpy.array(range(numpy.product(xsh))) # xv = numpy.array(range(numpy.product(xsh)))
# xv = xv.reshape(xsh) # xv = xv.reshape(xsh)
# yv = numpy.array(range(numpy.product(ysh))) # yv = numpy.array(range(numpy.product(ysh)))
...@@ -80,10 +81,10 @@ class _test_Broadcast(unittest.TestCase): ...@@ -80,10 +81,10 @@ class _test_Broadcast(unittest.TestCase):
((2, 3, 4, 5), (1, 3, 1, 5)), ((2, 3, 4, 5), (1, 3, 1, 5)),
((2, 3, 4, 5), (1, 1, 1, 1)), ((2, 3, 4, 5), (1, 1, 1, 1)),
((), ())]: ((), ())]:
x = modes.build(Tensor('float64', [1 * (entry == 1) for entry in xsh], name = 'x')) x = Tensor('float64', [(entry == 1) for entry in xsh])('x')
y = modes.build(Tensor('float64', [1 * (entry == 1) for entry in ysh], name = 'y')) y = Tensor('float64', [(entry == 1) for entry in ysh])('y')
e = Broadcast(Add, (x, y), {0:0}).out e = Elemwise(Add(transfer_type(0)), {0:0})(x, y)
f = linker(env([x, y], [e])).make_function(inplace = False) f = linker(Env([x, y], [e])).make_function()
xv = numpy.asarray(numpy.random.rand(*xsh)) xv = numpy.asarray(numpy.random.rand(*xsh))
yv = numpy.asarray(numpy.random.rand(*ysh)) yv = numpy.asarray(numpy.random.rand(*ysh))
zv = xv + yv zv = xv + yv
...@@ -105,29 +106,29 @@ class _test_Broadcast(unittest.TestCase): ...@@ -105,29 +106,29 @@ class _test_Broadcast(unittest.TestCase):
self.with_linker_inplace(gof.CLinker) self.with_linker_inplace(gof.CLinker)
def test_fill(self): def test_fill(self):
x = modes.build(Tensor('float64', [0, 0], name = 'x')) x = Tensor('float64', [0, 0])('x')
y = modes.build(Tensor('float64', [1, 1], name = 'y')) y = Tensor('float64', [1, 1])('y')
e = Broadcast(Second, (x, y), {0:0}).out e = Elemwise(Second(transfer_type(0)), {0:0})(x, y)
f = gof.CLinker(env([x, y], [e])).make_function(inplace = False) f = gof.CLinker(Env([x, y], [e])).make_function()
xv = numpy.ones((5, 5)) xv = numpy.ones((5, 5))
yv = numpy.random.rand(1, 1) yv = numpy.random.rand(1, 1)
f(xv, yv) f(xv, yv)
assert (xv == yv).all() assert (xv == yv).all()
def test_weird_strides(self): def test_weird_strides(self):
x = modes.build(Tensor('float64', [0, 0, 0, 0, 0], name = 'x')) x = Tensor('float64', [0, 0, 0, 0, 0])('x')
y = modes.build(Tensor('float64', [0, 0, 0, 0, 0], name = 'y')) y = Tensor('float64', [0, 0, 0, 0, 0])('y')
e = Broadcast(Add, (x, y)).out e = Elemwise(add)(x, y)
f = gof.CLinker(env([x, y], [e])).make_function(inplace = False) f = gof.CLinker(Env([x, y], [e])).make_function()
xv = numpy.random.rand(2, 2, 2, 2, 2) xv = numpy.random.rand(2, 2, 2, 2, 2)
yv = numpy.random.rand(2, 2, 2, 2, 2).transpose(4, 0, 3, 1, 2) yv = numpy.random.rand(2, 2, 2, 2, 2).transpose(4, 0, 3, 1, 2)
zv = xv + yv zv = xv + yv
assert (f(xv, yv) == zv).all() assert (f(xv, yv) == zv).all()
def test_same_inputs(self): def test_same_inputs(self):
x = modes.build(Tensor('float64', [0, 0], name = 'x')) x = Tensor('float64', [0, 0])('x')
e = Broadcast(Add, (x, x)).out e = Elemwise(add)(x, x)
f = gof.CLinker(env([x], [e])).make_function(inplace = False) f = gof.CLinker(Env([x], [e])).make_function()
xv = numpy.random.rand(2, 2) xv = numpy.random.rand(2, 2)
zv = xv + xv zv = xv + xv
assert (f(xv) == zv).all() assert (f(xv) == zv).all()
...@@ -136,15 +137,17 @@ class _test_Broadcast(unittest.TestCase): ...@@ -136,15 +137,17 @@ class _test_Broadcast(unittest.TestCase):
class _test_CAReduce(unittest.TestCase): class _test_CAReduce(unittest.TestCase):
def with_linker(self, linker): def with_linker(self, linker):
for xsh, tosum in [((5, 6), (0, 1)), for xsh, tosum in [((5, 6), None),
((5, 6), (0, 1)),
((5, 6), (0, )), ((5, 6), (0, )),
((5, 6), (1, )), ((5, 6), (1, )),
((5, 6), ()), ((5, 6), ()),
((2, 3, 4, 5), (0, 1, 3)), ((2, 3, 4, 5), (0, 1, 3)),
((), ())]: ((), ())]:
x = modes.build(Tensor('float64', [1 * (entry == 1) for entry in xsh], name = 'x')) x = Tensor('float64', [(entry == 1) for entry in xsh])('x')
e = CAReduce(Add, [x], axis = tosum).out e = CAReduce(add, axis = tosum)(x)
f = linker(env([x], [e])).make_function(inplace = False) if tosum is None: tosum = range(len(xsh))
f = linker(Env([x], [e])).make_function()
xv = numpy.asarray(numpy.random.rand(*xsh)) xv = numpy.asarray(numpy.random.rand(*xsh))
zv = xv zv = xv
for axis in reversed(sorted(tosum)): for axis in reversed(sorted(tosum)):
......
import unittest import unittest
from gof import Result, Op, Env, modes from gof import Result, Op, Env
import gof import gof
from scalar import * from scalar import *
def inputs(): def inputs():
x = modes.build(as_scalar(1.0, 'x')) return floats('xyz')
y = modes.build(as_scalar(2.0, 'y'))
z = modes.build(as_scalar(3.0, 'z'))
return x, y, z
def env(inputs, outputs, validate = True, features = []):
return Env(inputs, outputs, features = features, consistency_check = validate)
class _test_ScalarOps(unittest.TestCase): class _test_ScalarOps(unittest.TestCase):
...@@ -22,7 +16,7 @@ class _test_ScalarOps(unittest.TestCase): ...@@ -22,7 +16,7 @@ class _test_ScalarOps(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
g = env([x, y], [e]) g = Env([x, y], [e])
fn = gof.DualLinker(g).make_function() fn = gof.DualLinker(g).make_function()
assert fn(1.0, 2.0) == 1.5 assert fn(1.0, 2.0) == 1.5
...@@ -32,25 +26,21 @@ class _test_composite(unittest.TestCase): ...@@ -32,25 +26,21 @@ class _test_composite(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
C = composite([x, y], [e]) C = Composite([x, y], [e])
c = C(x, y) c = C.make_node(x, y)
# print c.c_code(['x', 'y'], ['z'], dict(id = 0)) # print c.c_code(['x', 'y'], ['z'], dict(id = 0))
c.perform() g = Env([x, y], [c.out])
assert c.outputs[0].data == 1.5
g = env([x, y], [c.out])
fn = gof.DualLinker(g).make_function() fn = gof.DualLinker(g).make_function()
assert fn(1.0, 2.0) == 1.5 assert fn(1.0, 2.0) == 1.5
def test_with_constants(self): def test_with_constants(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(70.0, y), div(x, y)) e = mul(add(70.0, y), div(x, y))
C = composite([x, y], [e]) C = Composite([x, y], [e])
c = C(x, y) c = C.make_node(x, y)
assert "70.0" in c.c_code(['x', 'y'], ['z'], dict(id = 0)) assert "70.0" in c.op.c_code(c, 'dummy', ['x', 'y'], ['z'], dict(id = 0))
# print c.c_code(['x', 'y'], ['z'], dict(id = 0)) # print c.c_code(['x', 'y'], ['z'], dict(id = 0))
c.perform() g = Env([x, y], [c.out])
assert c.outputs[0].data == 36.0
g = env([x, y], [c.out])
fn = gof.DualLinker(g).make_function() fn = gof.DualLinker(g).make_function()
assert fn(1.0, 2.0) == 36.0 assert fn(1.0, 2.0) == 36.0
...@@ -59,14 +49,10 @@ class _test_composite(unittest.TestCase): ...@@ -59,14 +49,10 @@ class _test_composite(unittest.TestCase):
e0 = x + y + z e0 = x + y + z
e1 = x + y * z e1 = x + y * z
e2 = x / y e2 = x / y
C = composite([x, y, z], [e0, e1, e2]) C = Composite([x, y, z], [e0, e1, e2])
c = C(x, y, z) c = C.make_node(x, y, z)
# print c.c_code(['x', 'y', 'z'], ['out0', 'out1', 'out2'], dict(id = 0)) # print c.c_code(['x', 'y', 'z'], ['out0', 'out1', 'out2'], dict(id = 0))
c.perform() g = Env([x, y, z], c.outputs)
assert c.outputs[0].data == 6.0
assert c.outputs[1].data == 7.0
assert c.outputs[2].data == 0.5
g = env([x, y, z], c.outputs)
fn = gof.DualLinker(g).make_function() fn = gof.DualLinker(g).make_function()
assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5] assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5]
......
...@@ -10,53 +10,43 @@ from scalar_opt import * ...@@ -10,53 +10,43 @@ from scalar_opt import *
def inputs(): def inputs():
x = Scalar('float64', name = 'x') return floats('xyz')
y = Scalar('float64', name = 'y')
z = Scalar('float64', name = 'z')
a = Scalar('float64', name = 'a')
return x, y, z
def more_inputs(): def more_inputs():
a = Scalar('float64', name = 'a') return floats('abcd')
b = Scalar('float64', name = 'b')
c = Scalar('float64', name = 'c')
d = Scalar('float64', name = 'd')
return a, b, c, d
class _test_opts(unittest.TestCase): class _test_opts(unittest.TestCase):
def test_pow_to_sqr(self): def test_pow_to_sqr(self):
x, y, z = inputs() x, y, z = floats('xyz')
e = x ** 2.0 e = x ** 2.0
g = Env([x], [e]) g = Env([x], [e])
assert str(g) == "[Pow(x, 2.0)]" assert str(g) == "[pow(x, 2.0)]"
gof.ConstantFinder().optimize(g)
pow2sqr_float.optimize(g) pow2sqr_float.optimize(g)
assert str(g) == "[Sqr(x)]" assert str(g) == "[sqr(x)]"
# class _test_canonize(unittest.TestCase): class _test_canonize(unittest.TestCase):
# def test_muldiv(self): # def test_muldiv(self):
# x, y, z = inputs() # x, y, z = inputs()
# a, b, c, d = more_inputs() # a, b, c, d = more_inputs()
# # e = (2.0 * x) / (2.0 * y) # # e = (2.0 * x) / (2.0 * y)
# e = (2.0 * x) / (4.0 * y) # # e = (2.0 * x) / (4.0 * y)
# # e = x / (y / z) # # e = x / (y / z)
# # e = (x * y) / x # # e = (x * y) / x
# # e = (x / y) * (y / z) * (z / x) # # e = (x / y) * (y / z) * (z / x)
# # e = (a / b) * (b / c) * (c / d) # # e = (a / b) * (b / c) * (c / d)
# # e = (a * b) / (b * c) / (c * d) # # e = (a * b) / (b * c) / (c * d)
# # e = 2 * x / 2 # # e = 2 * x / 2
# # e = x / y / x # e = x / y / x
# g = Env([x, y, z, a, b, c, d], [e]) # g = Env([x, y, z, a, b, c, d], [e])
# print g # print g
# gof.ConstantFinder().optimize(g)
# mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs) # mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
# divfn = lambda x, y: x / y # divfn = lambda x, y: x / y
# invfn = lambda x: 1 / x # invfn = lambda x: 1 / x
# Canonizer(Mul, Div, Inv, mulfn, divfn, invfn).optimize(g) # Canonizer(mul, div, inv, mulfn, divfn, invfn).optimize(g)
# print g # print g
# def test_plusmin(self): # def test_plusmin(self):
...@@ -101,35 +91,53 @@ class _test_opts(unittest.TestCase): ...@@ -101,35 +91,53 @@ class _test_opts(unittest.TestCase):
# print g # print g
# def test_group_powers(self): # def test_group_powers(self):
# x, y, z = inputs() # x, y, z, a, b, c, d = floats('xyzabcd')
# a, b, c, d = more_inputs()
###################
# c1, c2 = constant(1.), constant(2.)
# #e = pow(x, c1) * pow(x, y) / pow(x, 7.0) # <-- fucked
# #f = -- moving from div(mul.out, pow.out) to pow(x, sub.out)
# e = div(mul(pow(x, 2.0), pow(x, y)), pow(x, 7.0))
# g = Env([x, y, z, a, b, c, d], [e])
# print g
# print g.inputs, g.outputs, g.orphans
# f = sub(add(2.0, y), add(7.0))
# g.replace(e, pow(x, f))
# print g
# print g.inputs, g.outputs, g.orphans
# g.replace(f, sub(add(2.0, y), add(7.0))) # -- moving from sub(add.out, add.out) to sub(add.out, add.out)
# print g
# print g.inputs, g.outputs, g.orphans
###################
# # e = x * exp(y) * exp(z) # # e = x * exp(y) * exp(z)
# # e = x * pow(x, y) * pow(x, z) # # e = x * pow(x, y) * pow(x, z)
# # e = pow(x, y) / pow(x, z) # # e = pow(x, y) / pow(x, z)
# # e = pow(x, 2.0) * pow(x, y) / pow(x, 7.0) # e = pow(x, 2.0) * pow(x, y) / pow(x, 7.0) # <-- fucked
# # e = pow(x - x, y) # # e = pow(x - x, y)
# # e = pow(x, 2.0 + y - 7.0) # # e = pow(x, 2.0 + y - 7.0)
# # e = pow(x, 2.0) * pow(x, y) / pow(x, 7.0) / pow(x, z) # # e = pow(x, 2.0) * pow(x, y) / pow(x, 7.0) / pow(x, z)
# # e = pow(x, 2.0 + y - 7.0 - z) # # e = pow(x, 2.0 + y - 7.0 - z)
# # e = x ** y / x ** y # # e = x ** y / x ** y
# # e = x ** y / x ** (y - 1.0) # # e = x ** y / x ** (y - 1.0)
# e = exp(x) * a * exp(y) / exp(z) # # e = exp(x) * a * exp(y) / exp(z)
# g = Env([x, y, z, a, b, c, d], [e]) # g = Env([x, y, z, a, b, c, d], [e])
# print g # g.extend(gof.PrintListener(g))
# gof.ConstantFinder().optimize(g) # print g, g.orphans
# mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs) # mulfn = lambda *inputs: reduce(lambda x, y: x * y, (1,) + inputs)
# divfn = lambda x, y: x / y # divfn = lambda x, y: x / y
# invfn = lambda x: 1 / x # invfn = lambda x: 1 / x
# Canonizer(Mul, Div, Inv, mulfn, divfn, invfn, group_powers).optimize(g) # Canonizer(mul, div, inv, mulfn, divfn, invfn, group_powers).optimize(g)
# print g # print g, g.orphans
# addfn = lambda *inputs: reduce(lambda x, y: x + y, (0,) + inputs) # addfn = lambda *inputs: reduce(lambda x, y: x + y, (0,) + inputs)
# subfn = lambda x, y: x - y # subfn = lambda x, y: x - y
# negfn = lambda x: -x # negfn = lambda x: -x
# Canonizer(Add, Sub, Neg, addfn, subfn, negfn).optimize(g) # Canonizer(add, sub, neg, addfn, subfn, negfn).optimize(g)
# print g # print g, g.orphans
# pow2one_float.optimize(g) # pow2one_float.optimize(g)
# pow2x_float.optimize(g) # pow2x_float.optimize(g)
# print g # print g, g.orphans
......
差异被折叠。
...@@ -16,10 +16,10 @@ def exec_opt(inputs, outputs, features=[]): ...@@ -16,10 +16,10 @@ def exec_opt(inputs, outputs, features=[]):
exec_opt.optimizer = None exec_opt.optimizer = None
class _DefaultOptimizer(object): class _DefaultOptimizer(object):
const = gof.opt.ConstantFinder() #const = gof.opt.ConstantFinder()
merge = gof.opt.MergeOptimizer() merge = gof.opt.MergeOptimizer()
def __call__(self, env): def __call__(self, env):
self.const(env) #self.const(env)
self.merge(env) self.merge(env)
default_optimizer = _DefaultOptimizer() default_optimizer = _DefaultOptimizer()
...@@ -31,7 +31,7 @@ def linker_cls_python_and_c(env): ...@@ -31,7 +31,7 @@ def linker_cls_python_and_c(env):
"""Use this as the linker_cls argument to Function.__init__ to compare """Use this as the linker_cls argument to Function.__init__ to compare
python and C implementations""" python and C implementations"""
def checker(x, y): def checker(x, y):
x, y = x.data, y.data x, y = x[0], y[0]
if isinstance(x, numpy.ndarray) or isinstance(y, numpy.ndarray): if isinstance(x, numpy.ndarray) or isinstance(y, numpy.ndarray):
if x.dtype != y.dtype or x.shape != y.shape or numpy.any(abs(x - y) > 1e-10): if x.dtype != y.dtype or x.shape != y.shape or numpy.any(abs(x - y) > 1e-10):
raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y}) raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y})
...@@ -105,17 +105,23 @@ class Function: ...@@ -105,17 +105,23 @@ class Function:
#print 'orphans', orphans #print 'orphans', orphans
#print 'ops', gof.graph.ops(inputs, outputs) #print 'ops', gof.graph.ops(inputs, outputs)
env = gof.env.Env(inputs, outputs, features + [gof.EquivTool], consistency_check = True) env = gof.env.Env(inputs, outputs)
#print 'orphans in env', env.orphans() #print 'orphans in env', env.orphans()
env = env.clone(clone_inputs=True) env, equiv = env.clone_get_equiv(clone_inputs=True)
for feature in features:
env.extend(feature(env))
env.extend(gof.DestroyHandler(env))
#print 'orphans after clone', env.orphans() #print 'orphans after clone', env.orphans()
for d, o in zip(orphan_data, [env.equiv(orphan) for orphan in orphans]): for d, o in zip(orphan_data, [equiv[orphan] for orphan in orphans]):
#print 'assigning orphan value', d #print 'assigning orphan value', d
o.data = d #o.data = d
new_o = gof.Constant(o.type, d)
env.replace(o, new_o)
assert new_o in env.orphans
# optimize and link the cloned env # optimize and link the cloned env
if None is not optimizer: if None is not optimizer:
...@@ -127,11 +133,9 @@ class Function: ...@@ -127,11 +133,9 @@ class Function:
self.__dict__.update(locals()) self.__dict__.update(locals())
if profiler is None: if profiler is None:
self.fn = linker.make_function(inplace=True, self.fn = linker.make_function(unpack_single=unpack_single)
unpack_single=unpack_single)
else: else:
self.fn = linker.make_function(inplace=True, self.fn = linker.make_function(unpack_single=unpack_single,
unpack_single=unpack_single,
profiler=profiler) profiler=profiler)
self.inputs = env.inputs self.inputs = env.inputs
self.outputs = env.outputs self.outputs = env.outputs
...@@ -146,16 +150,6 @@ class Function: ...@@ -146,16 +150,6 @@ class Function:
def __call__(self, *args): def __call__(self, *args):
return self.fn(*args) return self.fn(*args)
def __copy__(self):
return Function(self.inputs, self.outputs,
features = self.features,
optimizer = self.optimizer,
linker_cls = self.linker_cls,
profiler = self.profiler,
unpack_single = self.unpack_single,
except_unreachable_input = self.except_unreachable_input,
keep_locals = self.keep_locals)
def eval_outputs(outputs, def eval_outputs(outputs,
features = [], features = [],
...@@ -171,20 +165,23 @@ def eval_outputs(outputs, ...@@ -171,20 +165,23 @@ def eval_outputs(outputs,
else: else:
return [] return []
inputs = list(gof.graph.inputs(outputs)) inputs = gof.graph.inputs(outputs)
in_data = [i.data for i in inputs if i.data is not None] if any(not isinstance(input, gof.Constant) for input in inputs):
raise TypeError("Cannot evaluate outputs because some of the leaves are not Constant.", outputs)
in_data = [i.data for i in inputs]
#print 'in_data = ', in_data #print 'in_data = ', in_data
if len(inputs) != len(in_data): if len(inputs) != len(in_data):
raise Exception('some input data is unknown') raise Exception('some input data is unknown')
env = gof.env.Env(inputs, outputs, features, consistency_check = True) env = gof.env.Env(inputs, outputs)
env.replace_all(dict([(i, i.type()) for i in inputs]))
env = env.clone(clone_inputs=True) env = env.clone(clone_inputs=True)
_mark_indestructible(env.outputs) _mark_indestructible(env.outputs)
if None is not optimizer: if None is not optimizer:
optimizer(env) optimizer(env)
linker = linker_cls(env) linker = linker_cls(env)
fn = linker.make_function(inplace=True, unpack_single=unpack_single) fn = linker.make_function(unpack_single=unpack_single)
rval = fn(*in_data) rval = fn(*in_data)
return rval return rval
......
差异被折叠。
import op, result, ext, link, env, features, toolbox, graph, cc, opt import op, type, ext, link, env, features, toolbox, graph, cc, opt
from op import * from op import *
from result import * from graph import Apply, Result, Constant, as_apply, as_result
from type import *
from ext import * from ext import *
from link import * from link import *
from env import * from env import *
......
...@@ -3,27 +3,14 @@ import unittest ...@@ -3,27 +3,14 @@ import unittest
from link import PerformLinker, Profiler from link import PerformLinker, Profiler
from cc import * from cc import *
from result import Result from type import Type
from graph import Result, as_result, Apply, Constant
from op import Op from op import Op
from env import Env from env import Env
class Double(Result): class TDouble(Type):
def filter(self, data):
def __init__(self, data, name = "oignon"): return float(data)
Result.__init__(self, role = None, name = name)
assert isinstance(data, float)
self.data = data
def __str__(self):
return self.name
def __repr__(self):
return self.name
def __copy__(self):
return Double(self.data, self.name)
# def c_is_simple(self): return True
def c_declare(self, name, sub): def c_declare(self, name, sub):
return "double %(name)s; void* %(name)s_bad_thing;" % locals() return "double %(name)s; void* %(name)s_bad_thing;" % locals()
...@@ -35,8 +22,8 @@ class Double(Result): ...@@ -35,8 +22,8 @@ class Double(Result):
//printf("Initializing %(name)s\\n"); //printf("Initializing %(name)s\\n");
""" % locals() """ % locals()
def c_literal(self): def c_literal(self, data):
return str(self.data) return str(data)
def c_extract(self, name, sub): def c_extract(self, name, sub):
return """ return """
...@@ -65,116 +52,119 @@ class Double(Result): ...@@ -65,116 +52,119 @@ class Double(Result):
free(%(name)s_bad_thing); free(%(name)s_bad_thing);
""" % locals() """ % locals()
tdouble = TDouble()
def double(name):
return Result(tdouble, None, None, name = name)
class MyOp(Op): class MyOp(Op):
nin = -1 def __init__(self, nin, name):
self.nin = nin
self.name = name
def __init__(self, *inputs): def make_node(self, *inputs):
assert len(inputs) == self.nin assert len(inputs) == self.nin
inputs = map(as_result, inputs)
for input in inputs: for input in inputs:
if not isinstance(input, Double): if input.type is not tdouble:
raise Exception("Error 1") raise Exception("Error 1")
self.inputs = inputs outputs = [double(self.name + "_R")]
self.outputs = [Double(0.0, self.__class__.__name__ + "_R")] return Apply(self, inputs, outputs)
def __str__(self):
return self.name
def perform(self, node, inputs, (out, )):
out[0] = self.impl(*inputs)
class Unary(MyOp): class Unary(MyOp):
nin = 1 def __init__(self):
# def c_var_names(self): MyOp.__init__(self, 1, self.__class__.__name__)
# return [['x'], ['z']]
class Binary(MyOp): class Binary(MyOp):
nin = 2 def __init__(self):
# def c_var_names(self): MyOp.__init__(self, 2, self.__class__.__name__)
# return [['x', 'y'], ['z']]
class Add(Binary): class Add(Binary):
def c_code(self, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = %(x)s + %(y)s;" % locals() return "%(z)s = %(x)s + %(y)s;" % locals()
def perform(self): def impl(self, x, y):
self.outputs[0].data = self.inputs[0].data + self.inputs[1].data return x + y
add = Add()
class Sub(Binary): class Sub(Binary):
def c_code(self, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = %(x)s - %(y)s;" % locals() return "%(z)s = %(x)s - %(y)s;" % locals()
def perform(self): def impl(self, x, y):
self.outputs[0].data = -10 # erroneous return -10 # erroneous (most of the time)
sub = Sub()
class Mul(Binary): class Mul(Binary):
def c_code(self, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return "%(z)s = %(x)s * %(y)s;" % locals() return "%(z)s = %(x)s * %(y)s;" % locals()
def perform(self): def impl(self, x, y):
self.outputs[0].data = self.inputs[0].data * self.inputs[1].data return x * y
mul = Mul()
class Div(Binary): class Div(Binary):
def c_validate_update(self, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
return """
if (%(y)s == 0.0) {
PyErr_SetString(PyExc_ZeroDivisionError, "division by zero");
%(fail)s
}
""" % dict(locals(), **sub)
def c_code(self, (x, y), (z, ), sub):
return "%(z)s = %(x)s / %(y)s;" % locals() return "%(z)s = %(x)s / %(y)s;" % locals()
def perform(self): def impl(self, x, y):
self.outputs[0].data = self.inputs[0].data / self.inputs[1].data return x / y
div = Div()
import modes
modes.make_constructors(globals())
def inputs(): def inputs():
x = modes.build(Double(1.0, 'x')) x = double('x')
y = modes.build(Double(2.0, 'y')) y = double('y')
z = modes.build(Double(3.0, 'z')) z = double('z')
return x, y, z return x, y, z
def env(inputs, outputs, validate = True, features = []):
return Env(inputs, outputs, features = features, consistency_check = validate)
class _test_CLinker(unittest.TestCase): class _test_CLinker(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y, z], [e])) lnk = CLinker(Env([x, y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(fn(2.0, 2.0, 2.0) == 2.0) self.failUnless(fn(2.0, 2.0, 2.0) == 2.0)
def test_orphan(self): # def test_orphan(self):
x, y, z = inputs() # x, y, z = inputs()
z.data = 4.12345678 # z = Constant(tdouble, 4.12345678)
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z)) # e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y], [e])) # lnk = CLinker(Env([x, y], [e]))
fn = lnk.make_function() # fn = lnk.make_function()
self.failUnless(abs(fn(2.0, 2.0) + 0.12345678) < 1e-9) # self.failUnless(abs(fn(2.0, 2.0) + 0.12345678) < 1e-9)
self.failUnless("4.12345678" not in lnk.code_gen()) # we do not expect the number to be inlined # print lnk.code_gen()
# self.failUnless("4.12345678" not in lnk.code_gen()) # we do not expect the number to be inlined
def test_literal_inlining(self): def test_literal_inlining(self):
x, y, z = inputs() x, y, z = inputs()
z.data = 4.12345678 z = Constant(tdouble, 4.12345678)
z.constant = True # this should tell the compiler to inline z as a literal
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y], [e])) lnk = CLinker(Env([x, y], [e]))
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(abs(fn(2.0, 2.0) + 0.12345678) < 1e-9) self.failUnless(abs(fn(2.0, 2.0) + 0.12345678) < 1e-9)
self.failUnless("4.12345678" in lnk.code_gen()) # we expect the number to be inlined self.failUnless("4.12345678" in lnk.code_gen()) # we expect the number to be inlined
def test_single_op(self): def test_single_node(self):
x, y, z = inputs() x, y, z = inputs()
op = Add(x, y) node = add.make_node(x, y)
lnk = CLinker(op) lnk = CLinker(Env(node.inputs, node.outputs))
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(fn(2.0, 7.0) == 9) self.failUnless(fn(2.0, 7.0) == 9)
def test_dups(self): def test_dups(self):
# Testing that duplicate inputs are allowed. # Testing that duplicate inputs are allowed.
x, y, z = inputs() x, y, z = inputs()
op = Add(x, x) e = add(x, x)
lnk = CLinker(op) lnk = CLinker(Env([x, x], [e]))
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(fn(2.0, 2.0) == 4) self.failUnless(fn(2.0, 2.0) == 4)
# note: for now the behavior of fn(2.0, 7.0) is undefined # note: for now the behavior of fn(2.0, 7.0) is undefined
...@@ -183,7 +173,7 @@ class _test_CLinker(unittest.TestCase): ...@@ -183,7 +173,7 @@ class _test_CLinker(unittest.TestCase):
# Testing that duplicates are allowed inside the graph # Testing that duplicates are allowed inside the graph
x, y, z = inputs() x, y, z = inputs()
e = add(mul(y, y), add(x, z)) e = add(mul(y, y), add(x, z))
lnk = CLinker(env([x, y, z], [e])) lnk = CLinker(Env([x, y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(fn(1.0, 2.0, 3.0) == 8.0) self.failUnless(fn(1.0, 2.0, 3.0) == 8.0)
...@@ -194,16 +184,25 @@ class _test_OpWiseCLinker(unittest.TestCase): ...@@ -194,16 +184,25 @@ class _test_OpWiseCLinker(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = OpWiseCLinker(env([x, y, z], [e])) lnk = OpWiseCLinker(Env([x, y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
self.failUnless(fn(2.0, 2.0, 2.0) == 2.0) self.failUnless(fn(2.0, 2.0, 2.0) == 2.0)
def test_constant(self):
x, y, z = inputs()
x = Constant(tdouble, 7.2, name = 'x')
e = add(mul(x, y), mul(y, z))
lnk = OpWiseCLinker(Env([y, z], [e]))
fn = lnk.make_function()
res = fn(1.5, 3.0)
self.failUnless(res == 15.3, res)
class MyExc(Exception): class MyExc(Exception):
pass pass
def _my_checker(x, y): def _my_checker(x, y):
if x.data != y.data: if x[0] != y[0]:
raise MyExc("Output mismatch.", {'performlinker': x.data, 'clinker': y.data}) raise MyExc("Output mismatch.", {'performlinker': x[0], 'clinker': y[0]})
class _test_DualLinker(unittest.TestCase): class _test_DualLinker(unittest.TestCase):
...@@ -211,7 +210,7 @@ class _test_DualLinker(unittest.TestCase): ...@@ -211,7 +210,7 @@ class _test_DualLinker(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python
lnk = DualLinker(env([x, y, z], [e]), checker = _my_checker) lnk = DualLinker(Env([x, y, z], [e]), checker = _my_checker)
fn = lnk.make_function() fn = lnk.make_function()
res = fn(7.2, 1.5, 3.0) res = fn(7.2, 1.5, 3.0)
self.failUnless(res == 15.3, res) self.failUnless(res == 15.3, res)
...@@ -219,7 +218,7 @@ class _test_DualLinker(unittest.TestCase): ...@@ -219,7 +218,7 @@ class _test_DualLinker(unittest.TestCase):
def test_mismatch(self): def test_mismatch(self):
x, y, z = inputs() x, y, z = inputs()
e = sub(mul(x, y), mul(y, z)) # sub is correct in C but erroneous in Python e = sub(mul(x, y), mul(y, z)) # sub is correct in C but erroneous in Python
g = env([x, y, z], [e]) g = Env([x, y, z], [e])
lnk = DualLinker(g, checker = _my_checker) lnk = DualLinker(g, checker = _my_checker)
fn = lnk.make_function() fn = lnk.make_function()
...@@ -238,14 +237,14 @@ class _test_DualLinker(unittest.TestCase): ...@@ -238,14 +237,14 @@ class _test_DualLinker(unittest.TestCase):
else: else:
self.fail() self.fail()
def test_orphan(self): # def test_orphan(self):
x, y, z = inputs() # x, y, z = inputs()
x.data = 7.2 # x = Constant(tdouble, 7.2, name = 'x')
e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python # e = add(mul(x, y), mul(y, z)) # add and mul are correct in C and in Python
lnk = DualLinker(env([y, z], [e]), checker = _my_checker) # lnk = DualLinker(Env([y, z], [e]), checker = _my_checker)
fn = lnk.make_function() # fn = lnk.make_function()
res = fn(1.5, 3.0) # res = fn(1.5, 3.0)
self.failUnless(res == 15.3, res) # self.failUnless(res == 15.3, res)
......
import unittest import unittest
from result import Result from type import Type
from graph import Result, as_result, Apply
from op import Op from op import Op
from opt import PatternOptimizer, OpSubOptimizer from opt import PatternOptimizer, OpSubOptimizer
...@@ -9,62 +10,64 @@ from ext import * ...@@ -9,62 +10,64 @@ from ext import *
from env import Env, InconsistencyError from env import Env, InconsistencyError
from toolbox import EquivTool from toolbox import EquivTool
from _test_result import MyResult from copy import copy
class MyOp(Op): #from _test_result import MyResult
nin = -1
def __init__(self, *inputs):
assert len(inputs) == self.nin
for input in inputs:
if not isinstance(input, MyResult):
raise Exception("Error 1")
self.inputs = inputs
self.outputs = [MyResult(self.__class__.__name__ + "_R")]
class Sigmoid(MyOp): class MyType(Type):
nin = 1
class TransposeView(MyOp, Viewer): def filter(self, data):
nin = 1 return data
def view_map(self):
return {self.outputs[0]: [self.inputs[0]]}
class Add(MyOp): def __eq__(self, other):
nin = 2 return isinstance(other, MyType)
class AddInPlace(MyOp, Destroyer):
nin = 2
def destroyed_inputs(self):
return self.inputs[:1]
class Dot(MyOp): def MyResult(name):
nin = 2 return Result(MyType(), None, None, name = name)
# dtv_elim = PatternOptimizer((TransposeView, (TransposeView, 'x')), 'x') class MyOp(Op):
# AddCls = Add def __init__(self, nin, name, vmap = {}, dmap = {}):
# AddInPlaceCls = AddInPlace self.nin = nin
self.name = name
self.destroy_map = dmap
self.view_map = vmap
def make_node(self, *inputs):
assert len(inputs) == self.nin
inputs = map(as_result, inputs)
for input in inputs:
if not isinstance(input.type, MyType):
raise Exception("Error 1")
outputs = [MyResult(self.name + "_R")]
return Apply(self, inputs, outputs)
# a2i = OpSubOptimizer(Add, AddInPlace) def __str__(self):
# i2a = OpSubOptimizer(AddInPlace, Add) return self.name
# t2s = OpSubOptimizer(TransposeView, Sigmoid)
# s2t = OpSubOptimizer(Sigmoid, TransposeView)
sigmoid = MyOp(1, 'Sigmoid')
transpose_view = MyOp(1, 'TransposeView', vmap = {0: [0]})
add = MyOp(2, 'Add')
add_in_place = MyOp(2, 'AddInPlace', dmap = {0: [0]})
dot = MyOp(2, 'Dot')
import modes
modes.make_constructors(globals()) #, name_filter = lambda x:x)
def inputs(): def inputs():
x = modes.build(MyResult('x')) x = MyResult('x')
y = modes.build(MyResult('y')) y = MyResult('y')
z = modes.build(MyResult('z')) z = MyResult('z')
return x, y, z return x, y, z
def env(inputs, outputs, validate = True): _Env = Env
return Env(inputs, outputs, features = [EquivTool], consistency_check = validate) def Env(inputs, outputs, validate = True):
e = _Env(inputs, outputs)
e.extend(EquivTool(e))
e.extend(DestroyHandler(e), validate = validate)
return e
class FailureWatch: class FailureWatch:
...@@ -82,7 +85,7 @@ class _test_all(unittest.TestCase): ...@@ -82,7 +85,7 @@ class _test_all(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
e = add(add_in_place(x, y), add_in_place(x, y)) e = add(add_in_place(x, y), add_in_place(x, y))
try: try:
g = env([x,y,z], [e]) g = Env([x,y,z], [e])
self.fail() self.fail()
except InconsistencyError, e: except InconsistencyError, e:
pass pass
...@@ -90,10 +93,10 @@ class _test_all(unittest.TestCase): ...@@ -90,10 +93,10 @@ class _test_all(unittest.TestCase):
def test_multi_destroyers_through_views(self): def test_multi_destroyers_through_views(self):
x, y, z = inputs() x, y, z = inputs()
e = dot(add(transpose_view(z), y), add(z, x)) e = dot(add(transpose_view(z), y), add(z, x))
g = env([x,y,z], [e]) g = Env([x,y,z], [e])
assert g.consistent() assert g.consistent()
fail = FailureWatch() fail = FailureWatch()
OpSubOptimizer(Add, AddInPlace, fail).optimize(g) OpSubOptimizer(add, add_in_place, fail).optimize(g)
assert g.consistent() assert g.consistent()
assert fail.failures == 1 # should have succeeded once and failed once assert fail.failures == 1 # should have succeeded once and failed once
...@@ -102,7 +105,7 @@ class _test_all(unittest.TestCase): ...@@ -102,7 +105,7 @@ class _test_all(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
e1 = add(x, y) e1 = add(x, y)
e2 = add(y, x) e2 = add(y, x)
g = env([x,y,z], [e1, e2]) g = Env([x,y,z], [e1, e2])
chk = g.checkpoint() chk = g.checkpoint()
assert g.consistent() assert g.consistent()
g.replace(e1, add_in_place(x, y)) g.replace(e1, add_in_place(x, y))
...@@ -126,30 +129,30 @@ class _test_all(unittest.TestCase): ...@@ -126,30 +129,30 @@ class _test_all(unittest.TestCase):
def test_long_destroyers_loop(self): def test_long_destroyers_loop(self):
x, y, z = inputs() x, y, z = inputs()
e = dot(dot(add_in_place(x,y), add_in_place(y,z)), add(z,x)) e = dot(dot(add_in_place(x,y), add_in_place(y,z)), add(z,x))
g = env([x,y,z], [e]) g = Env([x,y,z], [e])
assert g.consistent() assert g.consistent()
OpSubOptimizer(Add, AddInPlace).optimize(g) OpSubOptimizer(add, add_in_place).optimize(g)
assert g.consistent() assert g.consistent()
assert str(g) != "[Dot(Dot(AddInPlace(x, y), AddInPlace(y, z)), AddInPlace(z, x))]" # we don't want to see that! assert str(g) != "[Dot(Dot(AddInPlace(x, y), AddInPlace(y, z)), AddInPlace(z, x))]" # we don't want to see that!
e2 = dot(dot(add_in_place(x,y), add_in_place(y,z)), add_in_place(z,x)) e2 = dot(dot(add_in_place(x,y), add_in_place(y,z)), add_in_place(z,x))
try: try:
g2 = env([x,y,z], [e2]) g2 = Env([x,y,z], [e2])
self.fail() self.fail()
except InconsistencyError: except InconsistencyError:
pass pass
def test_usage_loop(self): def test_usage_loop(self):
x, y, z = inputs() x, y, z = inputs()
g = env([x,y,z], [dot(add_in_place(x, z), x)], False) g = Env([x,y,z], [dot(add_in_place(x, z), x)], False)
assert not g.consistent() assert not g.consistent()
OpSubOptimizer(AddInPlace, Add).optimize(g) # replace AddInPlace with Add OpSubOptimizer(add_in_place, add).optimize(g) # replace add_in_place with add
assert g.consistent() assert g.consistent()
def test_usage_loop_through_views(self): def test_usage_loop_through_views(self):
x, y, z = inputs() x, y, z = inputs()
aip = add_in_place(x, y) aip = add_in_place(x, y)
e = dot(aip, transpose_view(x)) e = dot(aip, transpose_view(x))
g = env([x,y,z], [e], False) g = Env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
g.replace(aip, add(x, z)) g.replace(aip, add(x, z))
assert g.consistent() assert g.consistent()
...@@ -158,7 +161,7 @@ class _test_all(unittest.TestCase): ...@@ -158,7 +161,7 @@ class _test_all(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
e0 = transpose_view(transpose_view(transpose_view(sigmoid(x)))) e0 = transpose_view(transpose_view(transpose_view(sigmoid(x))))
e = dot(add_in_place(x,y), transpose_view(e0)) e = dot(add_in_place(x,y), transpose_view(e0))
g = env([x,y,z], [e]) g = Env([x,y,z], [e])
assert g.consistent() # because sigmoid can do the copy assert g.consistent() # because sigmoid can do the copy
g.replace(e0, x, False) g.replace(e0, x, False)
assert not g.consistent() # we cut off the path to the sigmoid assert not g.consistent() # we cut off the path to the sigmoid
...@@ -166,22 +169,23 @@ class _test_all(unittest.TestCase): ...@@ -166,22 +169,23 @@ class _test_all(unittest.TestCase):
def test_usage_loop_insert_views(self): def test_usage_loop_insert_views(self):
x, y, z = inputs() x, y, z = inputs()
e = dot(add_in_place(x, add(y, z)), sigmoid(sigmoid(sigmoid(sigmoid(sigmoid(x)))))) e = dot(add_in_place(x, add(y, z)), sigmoid(sigmoid(sigmoid(sigmoid(sigmoid(x))))))
g = env([x,y,z], [e]) g = Env([x,y,z], [e])
assert g.consistent() assert g.consistent()
fail = FailureWatch() fail = FailureWatch()
OpSubOptimizer(Sigmoid, TransposeView, fail).optimize(g) OpSubOptimizer(sigmoid, transpose_view, fail).optimize(g)
assert g.consistent() assert g.consistent()
assert fail.failures == 1 # it must keep one sigmoid in the long sigmoid chain assert fail.failures == 1 # it must keep one sigmoid in the long sigmoid chain
def test_misc(self): def test_misc(self):
x, y, z = inputs() x, y, z = inputs()
e = transpose_view(transpose_view(transpose_view(transpose_view(x)))) e = transpose_view(transpose_view(transpose_view(transpose_view(x))))
g = env([x,y,z], [e]) g = Env([x,y,z], [e])
assert g.consistent() assert g.consistent()
chk = g.checkpoint() chk = g.checkpoint()
PatternOptimizer((TransposeView, (TransposeView, 'x')), 'x').optimize(g) PatternOptimizer((transpose_view, (transpose_view, 'x')), 'x').optimize(g)
assert str(g) == "[x]" assert str(g) == "[x]"
g.replace(g.equiv(e), add(x,y)) g.replace(g.equiv(e), add(x,y))
print g
assert str(g) == "[Add(x, y)]" assert str(g) == "[Add(x, y)]"
g.replace(g.equiv(e), dot(add_in_place(x,y), transpose_view(x)), False) g.replace(g.equiv(e), dot(add_in_place(x,y), transpose_view(x)), False)
assert str(g) == "[Dot(AddInPlace(x, y), TransposeView(x))]" assert str(g) == "[Dot(AddInPlace(x, y), TransposeView(x))]"
...@@ -193,8 +197,10 @@ class _test_all(unittest.TestCase): ...@@ -193,8 +197,10 @@ class _test_all(unittest.TestCase):
def test_indestructible(self): def test_indestructible(self):
x, y, z = inputs() x, y, z = inputs()
x.indestructible = True x.indestructible = True
x = copy(x)
assert x.indestructible # checking if indestructible survives the copy!
e = add_in_place(x, y) e = add_in_place(x, y)
g = env([x,y,z], [e], False) g = Env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
g.replace(e, add(x, y)) g.replace(e, add(x, y))
assert g.consistent() assert g.consistent()
...@@ -204,7 +210,7 @@ class _test_all(unittest.TestCase): ...@@ -204,7 +210,7 @@ class _test_all(unittest.TestCase):
x.indestructible = True x.indestructible = True
tv = transpose_view(x) tv = transpose_view(x)
e = add_in_place(tv, y) e = add_in_place(tv, y)
g = env([x,y,z], [e], False) g = Env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
g.replace(tv, sigmoid(x)) g.replace(tv, sigmoid(x))
assert g.consistent() assert g.consistent()
...@@ -215,7 +221,7 @@ class _test_all(unittest.TestCase): ...@@ -215,7 +221,7 @@ class _test_all(unittest.TestCase):
e2 = transpose_view(transpose_view(e1)) e2 = transpose_view(transpose_view(e1))
e3 = add_in_place(e2, y) e3 = add_in_place(e2, y)
e4 = add_in_place(e1, z) e4 = add_in_place(e1, z)
g = env([x,y,z], [e3, e4], False) g = Env([x,y,z], [e3, e4], False)
assert not g.consistent() assert not g.consistent()
g.replace(e2, transpose_view(x), False) g.replace(e2, transpose_view(x), False)
assert not g.consistent() assert not g.consistent()
...@@ -224,7 +230,7 @@ class _test_all(unittest.TestCase): ...@@ -224,7 +230,7 @@ class _test_all(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
e0 = add_in_place(x, y) e0 = add_in_place(x, y)
e = dot(sigmoid(e0), transpose_view(x)) e = dot(sigmoid(e0), transpose_view(x))
g = env([x,y,z], [e], False) g = Env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
new_e0 = add(x, y) new_e0 = add(x, y)
g.replace(e0, new_e0, False) g.replace(e0, new_e0, False)
...@@ -236,7 +242,7 @@ class _test_all(unittest.TestCase): ...@@ -236,7 +242,7 @@ class _test_all(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
e0 = transpose_view(x) e0 = transpose_view(x)
e = dot(sigmoid(add_in_place(x, y)), e0) e = dot(sigmoid(add_in_place(x, y)), e0)
g = env([x,y,z], [e], False) g = Env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
new_e0 = add(e0, y) new_e0 = add(e0, y)
g.replace(e0, new_e0, False) g.replace(e0, new_e0, False)
......
...@@ -4,21 +4,18 @@ import unittest ...@@ -4,21 +4,18 @@ import unittest
from graph import * from graph import *
from op import Op from op import Op
from result import Result from type import Type
from graph import Result
class MyResult(Result):
class MyType(Type):
def __init__(self, thingy): def __init__(self, thingy):
self.thingy = thingy self.thingy = thingy
Result.__init__(self, role = None )
self.data = [self.thingy]
def __eq__(self, other): def __eq__(self, other):
return self.same_properties(other) return isinstance(other, MyType) and other.thingy == self.thingy
def same_properties(self, other):
return isinstance(other, MyResult) and other.thingy == self.thingy
def __str__(self): def __str__(self):
return str(self.thingy) return str(self.thingy)
...@@ -26,37 +23,78 @@ class MyResult(Result): ...@@ -26,37 +23,78 @@ class MyResult(Result):
def __repr__(self): def __repr__(self):
return str(self.thingy) return str(self.thingy)
def MyResult(thingy):
return Result(MyType(thingy), None, None)
class MyOp(Op): class MyOp(Op):
def __init__(self, *inputs): def make_node(self, *inputs):
inputs = map(as_result, inputs)
for input in inputs: for input in inputs:
if not isinstance(input, MyResult): if not isinstance(input.type, MyType):
print input, input.type, type(input), type(input.type)
raise Exception("Error 1") raise Exception("Error 1")
self.inputs = inputs outputs = [MyResult(sum([input.type.thingy for input in inputs]))]
self.outputs = [MyResult(sum([input.thingy for input in inputs]))] return Apply(self, inputs, outputs)
def __str__(self):
return self.__class__.__name__
MyOp = MyOp()
# class MyResult(Result):
# def __init__(self, thingy):
# self.thingy = thingy
# Result.__init__(self, role = None )
# self.data = [self.thingy]
# def __eq__(self, other):
# return self.same_properties(other)
# def same_properties(self, other):
# return isinstance(other, MyResult) and other.thingy == self.thingy
# def __str__(self):
# return str(self.thingy)
# def __repr__(self):
# return str(self.thingy)
# class MyOp(Op):
# def __init__(self, *inputs):
# for input in inputs:
# if not isinstance(input, MyResult):
# raise Exception("Error 1")
# self.inputs = inputs
# self.outputs = [MyResult(sum([input.thingy for input in inputs]))]
class _test_inputs(unittest.TestCase): class _test_inputs(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) node = MyOp.make_node(r1, r2)
assert inputs(op.outputs) == set([r1, r2]) assert inputs(node.outputs) == set([r1, r2])
def test_deep(self): def test_deep(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) node = MyOp.make_node(r1, r2)
op2 = MyOp(op.outputs[0], r5) node2 = MyOp.make_node(node.outputs[0], r5)
assert inputs(op2.outputs) == set([r1, r2, r5]) assert inputs(node2.outputs) == set([r1, r2, r5])
def test_unreached_inputs(self): def test_unreached_inputs(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) node = MyOp.make_node(r1, r2)
op2 = MyOp(op.outputs[0], r5) node2 = MyOp.make_node(node.outputs[0], r5)
try: try:
# function doesn't raise if we put False instead of True # function doesn't raise if we put False instead of True
ro = results_and_orphans([r1, r2, op2.outputs[0]], op.outputs, True) ro = results_and_orphans([r1, r2, node2.outputs[0]], node.outputs, True)
self.fail() self.fail()
except Exception, e: except Exception, e:
if e[0] is results_and_orphans.E_unreached: if e[0] is results_and_orphans.E_unreached:
...@@ -68,70 +106,84 @@ class _test_orphans(unittest.TestCase): ...@@ -68,70 +106,84 @@ class _test_orphans(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) node = MyOp.make_node(r1, r2)
op2 = MyOp(op.outputs[0], r5) node2 = MyOp.make_node(node.outputs[0], r5)
assert orphans([r1, r2], op2.outputs) == set([r5]) assert orphans([r1, r2], node2.outputs) == set([r5])
class _test_as_string(unittest.TestCase): class _test_as_string(unittest.TestCase):
leaf_formatter = str leaf_formatter = lambda self, leaf: str(leaf.type)
node_formatter = lambda op, argstrings: "%s(%s)" % (op.__class__.__name__, node_formatter = lambda self, node, argstrings: "%s(%s)" % (node.op,
", ".join(argstrings)) ", ".join(argstrings))
def str(self, inputs, outputs):
return as_string(inputs, outputs,
leaf_formatter = self.leaf_formatter,
node_formatter = self.node_formatter)
def test_straightforward(self): def test_straightforward(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) node = MyOp.make_node(r1, r2)
assert as_string([r1, r2], op.outputs) == ["MyOp(1, 2)"] assert self.str([r1, r2], node.outputs) == ["MyOp(1, 2)"]
def test_deep(self): def test_deep(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) node = MyOp.make_node(r1, r2)
op2 = MyOp(op.outputs[0], r5) node2 = MyOp.make_node(node.outputs[0], r5)
assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(MyOp(1, 2), 5)"] assert self.str([r1, r2, r5], node2.outputs) == ["MyOp(MyOp(1, 2), 5)"]
def test_multiple_references(self): def test_multiple_references(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) node = MyOp.make_node(r1, r2)
op2 = MyOp(op.outputs[0], op.outputs[0]) node2 = MyOp.make_node(node.outputs[0], node.outputs[0])
assert as_string([r1, r2, r5], op2.outputs) == ["MyOp(*1 -> MyOp(1, 2), *1)"] assert self.str([r1, r2, r5], node2.outputs) == ["MyOp(*1 -> MyOp(1, 2), *1)"]
def test_cutoff(self): def test_cutoff(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) node = MyOp.make_node(r1, r2)
op2 = MyOp(op.outputs[0], op.outputs[0]) node2 = MyOp.make_node(node.outputs[0], node.outputs[0])
assert as_string(op.outputs, op2.outputs) == ["MyOp(3, 3)"] assert self.str(node.outputs, node2.outputs) == ["MyOp(3, 3)"]
assert as_string(op2.inputs, op2.outputs) == ["MyOp(3, 3)"] assert self.str(node2.inputs, node2.outputs) == ["MyOp(3, 3)"]
class _test_clone(unittest.TestCase): class _test_clone(unittest.TestCase):
leaf_formatter = lambda self, leaf: str(leaf.type)
node_formatter = lambda self, node, argstrings: "%s(%s)" % (node.op,
", ".join(argstrings))
def str(self, inputs, outputs):
return as_string(inputs, outputs,
leaf_formatter = self.leaf_formatter,
node_formatter = self.node_formatter)
def test_accurate(self): def test_accurate(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyResult(1), MyResult(2)
op = MyOp(r1, r2) node = MyOp.make_node(r1, r2)
new = clone([r1, r2], op.outputs) new = clone([r1, r2], node.outputs)
assert as_string([r1, r2], new) == ["MyOp(1, 2)"] assert self.str([r1, r2], new) == ["MyOp(1, 2)"]
def test_copy(self): def test_copy(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2) node = MyOp.make_node(r1, r2)
op2 = MyOp(op.outputs[0], r5) node2 = MyOp.make_node(node.outputs[0], r5)
new = clone([r1, r2, r5], op2.outputs) new = clone([r1, r2, r5], node2.outputs)
assert op2.outputs[0] == new[0] and op2.outputs[0] is not new[0] # the new output is like the old one but not the same object assert node2.outputs[0].type == new[0].type and node2.outputs[0] is not new[0] # the new output is like the old one but not the same object
assert op2 is not new[0].owner # the new output has a new owner assert node2 is not new[0].owner # the new output has a new owner
assert new[0].owner.inputs[1] is r5 # the inputs are not copied assert new[0].owner.inputs[1] is r5 # the inputs are not copied
assert new[0].owner.inputs[0] == op.outputs[0] and new[0].owner.inputs[0] is not op.outputs[0] # check that we copied deeper too assert new[0].owner.inputs[0].type == node.outputs[0].type and new[0].owner.inputs[0] is not node.outputs[0] # check that we copied deeper too
def test_not_destructive(self): def test_not_destructive(self):
# Checks that manipulating a cloned graph leaves the original unchanged. # Checks that manipulating a cloned graph leaves the original unchanged.
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5) r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(MyOp(r1, r2).outputs[0], r5) node = MyOp.make_node(MyOp.make_node(r1, r2).outputs[0], r5)
new = clone([r1, r2, r5], op.outputs) new = clone([r1, r2, r5], node.outputs)
new_op = new[0].owner new_node = new[0].owner
new_op.inputs = MyResult(7), MyResult(8) new_node.inputs = MyResult(7), MyResult(8)
assert as_string(inputs(new_op.outputs), new_op.outputs) == ["MyOp(7, 8)"] assert self.str(inputs(new_node.outputs), new_node.outputs) == ["MyOp(7, 8)"]
assert as_string(inputs(op.outputs), op.outputs) == ["MyOp(MyOp(1, 2), 5)"] assert self.str(inputs(node.outputs), node.outputs) == ["MyOp(MyOp(1, 2), 5)"]
......
...@@ -2,70 +2,67 @@ ...@@ -2,70 +2,67 @@
import unittest import unittest
from result import Result from graph import Result, as_result, Apply
from type import Type
from op import Op from op import Op
from env import Env from env import Env
from link import * from link import *
from _test_result import Double #from _test_result import Double
class MyOp(Op):
nin = -1
def __init__(self, *inputs): class TDouble(Type):
assert len(inputs) == self.nin def filter(self, data):
for input in inputs: return float(data)
if not isinstance(input, Double):
raise Exception("Error 1")
self.inputs = inputs
self.outputs = [Double(0.0, self.__class__.__name__ + "_R")]
def perform(self): tdouble = TDouble()
self.outputs[0].data = self.impl(*[input.data for input in self.inputs])
def double(name):
return Result(tdouble, None, None, name = name)
class Unary(MyOp):
nin = 1
class Binary(MyOp): class MyOp(Op):
nin = 2
def __init__(self, nin, name, impl = None):
self.nin = nin
self.name = name
if impl:
self.impl = impl
class Add(Binary): def make_node(self, *inputs):
def impl(self, x, y): assert len(inputs) == self.nin
return x + y inputs = map(as_result, inputs)
for input in inputs:
if input.type is not tdouble:
raise Exception("Error 1")
outputs = [double(self.name + "_R")]
return Apply(self, inputs, outputs)
class Sub(Binary): def __str__(self):
def impl(self, x, y): return self.name
return x - y
class Mul(Binary): def perform(self, node, inputs, (out, )):
def impl(self, x, y): out[0] = self.impl(*inputs)
return x * y
class Div(Binary): add = MyOp(2, 'Add', lambda x, y: x + y)
def impl(self, x, y): sub = MyOp(2, 'Sub', lambda x, y: x - y)
return x / y mul = MyOp(2, 'Mul', lambda x, y: x * y)
div = MyOp(2, 'Div', lambda x, y: x / y)
class RaiseErr(Unary): def notimpl(self, x):
def impl(self, x):
raise NotImplementedError() raise NotImplementedError()
raise_err = MyOp(1, 'RaiseErr', notimpl)
import modes
modes.make_constructors(globals())
def inputs(): def inputs():
x = modes.build(Double(1.0, 'x')) x = double('x')
y = modes.build(Double(2.0, 'y')) y = double('y')
z = modes.build(Double(3.0, 'z')) z = double('z')
return x, y, z return x, y, z
def env(inputs, outputs, validate = True, features = []):
return Env(inputs, outputs, features = features, consistency_check = validate)
def perform_linker(env): def perform_linker(env):
lnk = PerformLinker(env) lnk = PerformLinker(env)
return lnk return lnk
...@@ -73,58 +70,79 @@ def perform_linker(env): ...@@ -73,58 +70,79 @@ def perform_linker(env):
class _test_PerformLinker(unittest.TestCase): class _test_PerformLinker(unittest.TestCase):
def test_thunk_inplace(self): def test_thunk(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
fn, i, o = perform_linker(env([x, y, z], [e])).make_thunk(True) fn, i, o = perform_linker(Env([x, y, z], [e])).make_thunk()
fn() i[0].data = 1
assert e.data == 1.5 i[1].data = 2
def test_thunk_not_inplace(self):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
fn, i, o = perform_linker(env([x, y, z], [e])).make_thunk(False)
fn() fn()
assert o[0].data == 1.5 assert o[0].data == 1.5
assert e.data != 1.5
def test_function(self): def test_function(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
fn = perform_linker(env([x, y, z], [e])).make_function() fn = perform_linker(Env([x, y, z], [e])).make_function()
assert fn(1.0, 2.0, 3.0) == 1.5 assert fn(1.0, 2.0, 3.0) == 1.5
assert e.data != 1.5 # not inplace
def test_constant(self):
x, y, z = inputs()
y = Constant(tdouble, 2.0)
e = mul(add(x, y), div(x, y))
fn = perform_linker(Env([x], [e])).make_function()
assert fn(1.0) == 1.5
def test_input_output_same(self): def test_input_output_same(self):
x, y, z = inputs() x, y, z = inputs()
a,d = add(x,y), div(x,y) a,d = add(x,y), div(x,y)
e = mul(a,d) e = mul(a,d)
fn = perform_linker(env([e], [e])).make_function() fn = perform_linker(Env([e], [e])).make_function()
self.failUnless(1 is fn(1)) self.failUnless(1.0 is fn(1.0))
def test_input_dependency0(self): def test_input_dependency0(self):
x, y, z = inputs() x, y, z = inputs()
a,d = add(x,y), div(x,y) a,d = add(x,y), div(x,y)
e = mul(a,d) e = mul(a,d)
fn = perform_linker(env([x, a], [e])).make_function() fn = perform_linker(Env([x, y, a], [e])).make_function()
self.failUnless(fn(1.0,9.0) == 4.5) self.failUnless(fn(1.0,2.0,9.0) == 4.5)
def test_skiphole(self): def test_skiphole(self):
x,y,z = inputs() x,y,z = inputs()
a = add(x,y) a = add(x,y)
r = RaiseErr(a).out r = raise_err(a)
e = add(r,a) e = add(r,a)
fn = perform_linker(env([x, y,r], [e])).make_function() fn = perform_linker(Env([x, y,r], [e])).make_function()
self.failUnless(fn(1.0,2.0,4.5) == 7.5) self.failUnless(fn(1.0,2.0,4.5) == 7.5)
def test_disconnected_input_output(self): # def test_disconnected_input_output(self):
x,y,z = inputs() # x,y,z = inputs()
a = add(x,y) # a = add(x,y)
a.data = 3.0 # simulate orphan calculation # a.data = 3.0 # simulate orphan calculation
fn = perform_linker(env([z], [a])).make_function(inplace=True) # fn = perform_linker(env([z], [a])).make_function(inplace=True)
self.failUnless(fn(1.0) == 3.0) # self.failUnless(fn(1.0) == 3.0)
self.failUnless(fn(2.0) == 3.0) # self.failUnless(fn(2.0) == 3.0)
# def test_thunk_inplace(self):
# x, y, z = inputs()
# e = mul(add(x, y), div(x, y))
# fn, i, o = perform_linker(Env([x, y, z], [e])).make_thunk(True)
# fn()
# assert e.data == 1.5
# def test_thunk_not_inplace(self):
# x, y, z = inputs()
# e = mul(add(x, y), div(x, y))
# fn, i, o = perform_linker(env([x, y, z], [e])).make_thunk(False)
# fn()
# assert o[0].data == 1.5
# assert e.data != 1.5
# def test_function(self):
# x, y, z = inputs()
# e = mul(add(x, y), div(x, y))
# fn = perform_linker(env([x, y, z], [e])).make_function()
# assert fn(1.0, 2.0, 3.0) == 1.5
# assert e.data != 1.5 # not inplace
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -4,7 +4,7 @@ import unittest ...@@ -4,7 +4,7 @@ import unittest
from modes import * from modes import *
from result import Result from graph import Result
from op import Op from op import Op
from env import Env from env import Env
...@@ -86,53 +86,53 @@ def env(inputs, outputs, validate = True): ...@@ -86,53 +86,53 @@ def env(inputs, outputs, validate = True):
return Env(inputs, outputs, features = [], consistency_check = validate) return Env(inputs, outputs, features = [], consistency_check = validate)
class _test_Modes(unittest.TestCase): # class _test_Modes(unittest.TestCase):
def test_0(self): # def test_0(self):
x, y, z = inputs(build) # x, y, z = inputs(build)
e = add(add(x, y), z) # e = add(add(x, y), z)
g = env([x, y, z], [e]) # g = env([x, y, z], [e])
assert str(g) == "[Add(Add(x, y), z)]" # assert str(g) == "[Add(Add(x, y), z)]"
assert e.data == 0.0 # assert e.data == 0.0
def test_1(self): # def test_1(self):
x, y, z = inputs(build_eval) # x, y, z = inputs(build_eval)
e = add(add(x, y), z) # e = add(add(x, y), z)
g = env([x, y, z], [e]) # g = env([x, y, z], [e])
assert str(g) == "[Add(Add(x, y), z)]" # assert str(g) == "[Add(Add(x, y), z)]"
assert e.data == 6.0 # assert e.data == 6.0
def test_2(self): # def test_2(self):
x, y, z = inputs(eval) # x, y, z = inputs(eval)
e = add(add(x, y), z) # e = add(add(x, y), z)
g = env([x, y, z], [e]) # g = env([x, y, z], [e])
assert str(g) == "[Add_R]" # assert str(g) == "[Add_R]"
assert e.data == 6.0 # assert e.data == 6.0
def test_3(self): # def test_3(self):
x, y, z = inputs(build) # x, y, z = inputs(build)
e = x + y + z # e = x + y + z
g = env([x, y, z], [e]) # g = env([x, y, z], [e])
assert str(g) == "[Add(Add(x, y), z)]" # assert str(g) == "[Add(Add(x, y), z)]"
assert e.data == 0.0 # assert e.data == 0.0
def test_4(self): # def test_4(self):
x, y, z = inputs(build_eval) # x, y, z = inputs(build_eval)
e = x + 34.0 # e = x + 34.0
g = env([x, y, z], [e]) # g = env([x, y, z], [e])
assert str(g) == "[Add(x, oignon)]" # assert str(g) == "[Add(x, oignon)]"
assert e.data == 35.0 # assert e.data == 35.0
def test_5(self): # def test_5(self):
xb, yb, zb = inputs(build) # xb, yb, zb = inputs(build)
xe, ye, ze = inputs(eval) # xe, ye, ze = inputs(eval)
try: # try:
e = xb + ye # e = xb + ye
except TypeError: # except TypeError:
# Trying to add inputs from different modes is forbidden # # Trying to add inputs from different modes is forbidden
pass # pass
else: # else:
raise Exception("Expected an error.") # raise Exception("Expected an error.")
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -2,21 +2,19 @@ ...@@ -2,21 +2,19 @@
import unittest import unittest
from copy import copy from copy import copy
from op import * from op import *
from result import Result from type import Type, Generic
from graph import Apply, as_result
#from result import Result
class MyResult(Result):
class MyType(Type):
def __init__(self, thingy): def __init__(self, thingy):
self.thingy = thingy self.thingy = thingy
Result.__init__(self, role = None)
self.data = [self.thingy]
def __eq__(self, other): def __eq__(self, other):
return self.same_properties(other) return type(other) == type(self) and other.thingy == self.thingy
def same_properties(self, other):
return isinstance(other, MyResult) and other.thingy == self.thingy
def __str__(self): def __str__(self):
return str(self.thingy) return str(self.thingy)
...@@ -27,32 +25,35 @@ class MyResult(Result): ...@@ -27,32 +25,35 @@ class MyResult(Result):
class MyOp(Op): class MyOp(Op):
def __init__(self, *inputs): def make_node(self, *inputs):
inputs = map(as_result, inputs)
for input in inputs: for input in inputs:
if not isinstance(input, MyResult): if not isinstance(input.type, MyType):
raise Exception("Error 1") raise Exception("Error 1")
self.inputs = inputs outputs = [MyType(sum([input.type.thingy for input in inputs]))()]
self.outputs = [MyResult(sum([input.thingy for input in inputs]))] return Apply(self, inputs, outputs)
MyOp = MyOp()
class _test_Op(unittest.TestCase): class _test_Op(unittest.TestCase):
# Sanity tests # Sanity tests
def test_sanity_0(self): def test_sanity_0(self):
r1, r2 = MyResult(1), MyResult(2) r1, r2 = MyType(1)(), MyType(2)()
op = MyOp(r1, r2) node = MyOp.make_node(r1, r2)
assert op.inputs == [r1, r2] # Are the inputs what I provided? assert [x for x in node.inputs] == [r1, r2] # Are the inputs what I provided?
assert op.outputs == [MyResult(3)] # Are the outputs what I expect? assert [x.type for x in node.outputs] == [MyType(3)] # Are the outputs what I expect?
assert op.outputs[0].owner is op and op.outputs[0].index == 0 assert node.outputs[0].owner is node and node.outputs[0].index == 0
# validate_update # validate
def test_validate_update(self): def test_validate(self):
try: try:
MyOp(Result(), MyResult(1)) # MyOp requires MyResult instances MyOp(Generic()(), MyType(1)()) # MyOp requires MyType instances
except Exception, e:
assert str(e) == "Error 1"
else:
raise Exception("Expected an exception") raise Exception("Expected an exception")
except Exception, e:
if str(e) != "Error 1":
raise
......
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论