orphans are handled correctly

上级 66e6a25a
......@@ -7,7 +7,8 @@ class Double(gof.result.ResultBase):
def __init__(self, data, name = "oignon"):
assert isinstance(data, float)
gof.result.ResultBase.__init__(self, role = None, data = data, name = name)
gof.result.ResultBase.__init__(self, role = None, name = name)
self.data = data
def __str__(self):
return self.name
......@@ -15,6 +16,9 @@ class Double(gof.result.ResultBase):
def __repr__(self):
return self.name
def __copy__(self):
return self.__class__(self.data, self.name)
class MyOp(gof.op.Op):
nin = -1
......@@ -61,31 +65,32 @@ def perform_linker(env):
lnk = gof.link.PerformLinker(env)
return lnk
def graph1():
def graph1(): # (x+y) * (x/z)
x = gof.modes.build(Double(1.0, 'x'))
y = gof.modes.build(Double(2.0, 'y'))
z = gof.modes.build(Double(3.0, 'z'))
y = gof.modes.build(Double(3.0, 'y'))
z = gof.modes.build(Double(4.0, 'z'))
o = Mul(Add(x, y).out, Div(x, y).out).out
o = Mul(Add(x, y).out, Div(x, z).out).out
return [x,y,z], [o]
def graph2():
x = gof.modes.build(Double(1.0, 'x'))
y = gof.modes.build(Double(2.0, 'y'))
z = gof.modes.build(Double(3.0, 'z'))
o = Mul(Add(x, y).out, Div(x, y).out).out
return [x,y,z], [o, o, o.owner.inputs[1]]
class T_what:
def test_nothing(self):
pass
class T_Function(unittest.TestCase):
def test_noopt(self):
gi, go = graph1()
p = Function(gi,go)
self.failUnless(p(1.0,3.0,4.0) == 1.0)
class _test_compile(unittest.TestCase):
def test_link_noopt(self):
gi, go = graph1()
fn, i, o = perform_linker(env(gi, go)).make_thunk(True)
fn()
self.failUnless(go[0].data == 1.5)
self.failUnless(go[0].data == 1.0)
def test_link_opt(self):
opt = gof.opt.PatternOptimizer((Div, '1', '2'), (Div, '2', '1'))
......@@ -94,28 +99,66 @@ class _test_compile(unittest.TestCase):
opt.optimize(e)
fn, i, o = perform_linker(e).make_thunk(True)
fn()
self.failUnless(go[0].data == 6.0)
def test_noopt(self):
gi, go = graph1()
p = Function(gi,go)
self.failUnless(p() == 1.5)
self.failUnless(go[0].data == 16.0)
def test_opt(self):
opt = gof.opt.PatternOptimizer((Div, '1', '2'), (Div, '2', '1'))
gi, go = graph1()
p = Function(gi,go, optimizer=opt)
self.failUnless(p() == 6.0)
self.failUnless(p(1.,3.,4.) == 16.0)
def test_multiout(self):
def graph2():
x = gof.modes.build(Double(1.0, 'x'))
y = gof.modes.build(Double(3.0, 'y'))
z = gof.modes.build(Double(4.0, 'z'))
o = Mul(Add(x, y).out, Div(x, z).out).out
return [x,y,z], [o, o.owner.inputs[1]]
opt = gof.opt.PatternOptimizer((Div, '1', '2'), (Div, '2', '1'))
gi, go = graph2()
p = Function(gi,go, optimizer=opt)
a,b,c = p()
self.failUnless(a == 6.0)
self.failUnless(b == 6.0)
self.failUnless(a is b)
self.failUnless(c == 2.0)
a,b = p(1.,3.,4.)
self.failUnless(a == 16.0)
self.failUnless(b == 4.0)
def test_orphans(self):
gi, go = graph1()
opt = None
p0 = Function(gi[0:0], go, optimizer=opt)
self.failUnless(p0() == 1.0)
p3 = Function(gi,go, optimizer=opt)
p2 = Function(gi[0:2], go, optimizer=opt)
p1 = Function(gi[0:1], go, optimizer=opt)
try:
self.failUnless(p3() == 6.0)
self.fail()
except TypeError, e:
self.failUnless(e[0].split()[0:3] == ['Function','call', 'takes'])
self.failUnless(p3(1.,3.,4.) == 1.0)
self.failUnless(p2(1.,3.) == 1.0)
self.failUnless(p1(1.,) == 1.0)
def test_some_constant_outputs(self):
x = gof.modes.build(Double(1.0, 'x'))
y = gof.modes.build(Double(3.0, 'y'))
z = gof.modes.build(Double(4.0, 'z'))
xy = Mul(x,y).out
zz = Mul(z,z).out
p0 = Function([x,y], [xy, zz])
self.failUnless(p0(1.,3.) == [3.0,16.0])
self.failUnless(p0(1.5,4.) == [6.0,16.0])
self.failUnless(x.data == 1.0)
self.failUnless(y.data == 3.0)
self.failUnless(z.data == 4.0)
p1 = Function([z], [xy, zz],unpack_single=False)
self.failUnless(p1(4.) == [3.0,16.0]) #returns 6.0, 16.10
self.failUnless(p1(5.) == [3.0,25.0])
if __name__ == '__main__':
......
......@@ -14,6 +14,10 @@ def exec_opt(inputs, outputs, features=[]):
"""Return a fast implementation"""
return Function(intputs, outputs, features, _optimizations, gof.link.PerformLinker, False)
def _mark_indestructible(results):
for r in results:
r.indestructible = True
class Function:
"""An 'executable' compiled from a graph
......@@ -32,27 +36,96 @@ class Function:
linker - the linker allocated from env
env - The env passed to the linker
"""
def __init__(self,
inputs,
outputs,
features=[],
optimizer=None,
linker_cls=gof.link.PerformLinker,
keep_locals=True):
def __init__(self, inputs, outputs,
features = [],
optimizer = None,
linker_cls = gof.link.PerformLinker,
unpack_single = True,
except_unreachable_input = True,
keep_locals = True):
""" Copy the graph, optimize, and link it.
Parameters:
inputs - a list of results to be this function's inputs
outputs - a list of results to be this function's outputs
features - features to add to the env
optimizer - an optimizer to apply to the copied graph, before linking
linker_cls - a callable that takes an env and returns a Linker
unpack_single - unpack return value lists of length 1
- see Linker.make_function
keep_locals - add the local variables from __init__ to the class
"""
_mark_indestructible(outputs)
if len(inputs) != len(set(inputs)):
raise Exception('duplicate inputs')
if len(outputs) != len(set(outputs)):
raise Exception('duplicate outputs')
#evaluate the orphans, and put these values into the clone of the env
orphans = list(gof.graph.results_and_orphans(inputs, outputs,
except_unreachable_input=except_unreachable_input)[1])
orphan_data = eval_outputs(orphans, unpack_single=False)
#print 'orphans', orphans
#print 'ops', gof.graph.ops(inputs, outputs)
env = gof.env.Env(inputs, outputs, features, consistency_check = True)
#print 'orphans in env', env.orphans()
env = env.clone(clone_inputs=True)
#print 'orphans after clone', env.orphans()
for d, o in zip(orphan_data, env.orphans()):
#print 'assigning orphan value', d
o.data = d
# optimize and link the cloned env
if None is not optimizer:
optimizer.optimize(env)
linker = linker_cls(env)
if keep_locals: # useful flag for debugging
if keep_locals:# useful flag for debugging!
self.__dict__.update(locals())
self.fn = linker.make_function(False)
self.fn = linker.make_function(inplace=True,
unpack_single=unpack_single)
def __call__(self, *args):
return self.fn(*args)
def eval_outputs(outputs,
features = [],
optimizer = None,
linker_cls = gof.link.PerformLinker,
unpack_single = True,
keep_locals = True):
if len(outputs) == 0:
#print 'returning with no inputs'
if unpack_single:
return None
else:
return []
inputs = list(gof.graph.inputs(outputs))
in_data = [i.data for i in inputs if i.data is not None]
#print 'in_data = ', in_data
if len(inputs) != len(in_data):
raise Exception('some input data is unknown')
env = gof.env.Env(inputs, outputs, features, consistency_check = True)
env = env.clone(clone_inputs=True)
_mark_indestructible(env.outputs)
if None is not optimizer:
optimizer.optimize(env)
linker = linker_cls(env)
fn = linker.make_function(inplace=True, unpack_single=unpack_single)
rval = fn(*in_data)
return rval
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论