changed optimizer calling style

上级 b06dcd20
......@@ -103,7 +103,7 @@ class T_Function(unittest.TestCase):
def test_opt(self):
opt = gof.opt.PatternOptimizer((Div, '1', '2'), (Div, '2', '1'))
gi, go = graph1()
p = Function(gi,go, optimizer=opt)
p = Function(gi,go, optimizer=opt.optimize)
self.failUnless(p(1.,3.,4.) == 16.0)
def test_multiout(self):
......@@ -116,7 +116,7 @@ class T_Function(unittest.TestCase):
return [x,y,z], [o, o.owner.inputs[1]]
opt = gof.opt.PatternOptimizer((Div, '1', '2'), (Div, '2', '1'))
gi, go = graph2()
p = Function(gi,go, optimizer=opt)
p = Function(gi,go, optimizer=opt.optimize)
a,b = p(1.,3.,4.)
self.failUnless(a == 16.0)
self.failUnless(b == 4.0)
......@@ -124,13 +124,13 @@ class T_Function(unittest.TestCase):
def test_orphans(self):
gi, go = graph1()
opt = None
p0 = Function(gi[0:0], go, optimizer=opt)
p0 = Function(gi[0:0], go)
self.failUnless(p0() == 1.0)
p3 = Function(gi,go, optimizer=opt)
p2 = Function(gi[0:2], go, optimizer=opt)
p1 = Function(gi[0:1], go, optimizer=opt)
p3 = Function(gi,go)
p2 = Function(gi[0:2], go)
p1 = Function(gi[0:1], go)
try:
self.failUnless(p3() == 6.0)
self.fail()
......
......@@ -4,15 +4,15 @@ import gof
#TODO: put together some default optimizations (TRAC #67)
_optimizations = None
def exec_py_opt(inputs, outputs, features=[]):
"""Return an optimized graph running purely python implementations"""
return Function(intputs, outputs, features, _optimizations, gof.link.PerformLinker, False)
return Function(intputs, outputs, features, exec_py_opt.optimizer, gof.link.PerformLinker, False)
exec_py_opt.optimizer = None
def exec_opt(inputs, outputs, features=[]):
"""Return a fast implementation"""
return Function(intputs, outputs, features, _optimizations, gof.link.PerformLinker, False)
return Function(intputs, outputs, features, exec_opt.optimizer, gof.link.PerformLinker, False)
exec_opt.optimizer = None
def _mark_indestructible(results):
for r in results:
......@@ -85,7 +85,7 @@ class Function:
# optimize and link the cloned env
if None is not optimizer:
optimizer.optimize(env)
optimizer(env)
linker = linker_cls(env)
if keep_locals:# useful flag for debugging!
......@@ -122,7 +122,7 @@ def eval_outputs(outputs,
_mark_indestructible(env.outputs)
if None is not optimizer:
optimizer.optimize(env)
optimizer(env)
linker = linker_cls(env)
fn = linker.make_function(inplace=True, unpack_single=unpack_single)
rval = fn(*in_data)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论