提交 2c813862 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

modified compile to use the new linker interface

上级 61d10d72
......@@ -3,17 +3,18 @@
import numpy
import gof
import sys
from copy import copy
#TODO: put together some default optimizations (TRAC #67)
def exec_py_opt(inputs, outputs, features=[]):
"""Return an optimized graph running purely python implementations"""
return Function(intputs, outputs, features, exec_py_opt.optimizer, gof.link.PerformLinker, False)
return Function(intputs, outputs, features, exec_py_opt.optimizer, gof.link.PerformLinker(), False)
exec_py_opt.optimizer = None
def exec_opt(inputs, outputs, features=[]):
"""Return a fast implementation"""
return Function(intputs, outputs, features, exec_opt.optimizer, gof.link.PerformLinker, False)
return Function(intputs, outputs, features, exec_opt.optimizer, gof.link.PerformLinker(), False)
exec_opt.optimizer = None
class _DefaultOptimizer(object):
......@@ -28,18 +29,20 @@ def _mark_indestructible(results):
for r in results:
r.tag.indestructible = True
def linker_cls_python_and_c(env, **kwargs):
"""Use this as the linker_cls argument to Function.__init__ to compare
python and C implementations"""
def checker(x, y):
x, y = x[0], y[0]
if isinstance(x, numpy.ndarray) or isinstance(y, numpy.ndarray):
if x.dtype != y.dtype or x.shape != y.shape or numpy.any(abs(x - y) > 1e-10):
raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y})
else:
if x != y:
# def linker_cls_python_and_c(env, **kwargs):
# """Use this as the linker_cls argument to Function.__init__ to compare
# python and C implementations"""
def check_equal(x, y):
x, y = x[0], y[0]
if isinstance(x, numpy.ndarray) or isinstance(y, numpy.ndarray):
if x.dtype != y.dtype or x.shape != y.shape or numpy.any(abs(x - y) > 1e-10):
raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y})
else:
if x != y:
raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y})
return gof.DualLinker(env, checker, **kwargs)
# return gof.DualLinker(checker, **kwargs).accept(env)
def infer_reuse_pattern(env, outputs_to_disown):
......@@ -86,10 +89,10 @@ def std_opt(env):
predefined_linkers = {
'py' : gof.link.PerformLinker,
'c' : gof.cc.CLinker,
'c|py' : gof.cc.OpWiseCLinker,
'c&py' : linker_cls_python_and_c
'py' : gof.PerformLinker(),
'c' : gof.CLinker(),
'c|py' : gof.OpWiseCLinker(),
'c&py' : gof.DualLinker(checker = check_equal)
}
class FunctionFactory:
......@@ -105,14 +108,14 @@ class FunctionFactory:
optimizer(env)
env.validate()
self.env = env
linker = predefined_linkers.get(linker, linker)
if not callable(linker):
raise ValueError("'linker' parameter of FunctionFactory should be a callable that takes an env as argument " \
linker = copy(predefined_linkers.get(linker, linker))
if not hasattr(linker, 'accept'):
raise ValueError("'linker' parameter of FunctionFactory should be a Linker with an accept method " \
"or one of ['py', 'c', 'c|py', 'c&py']")
if borrow_outputs:
self.linker = linker(env)
self.linker = linker.accept(env)
else:
self.linker = linker(env, no_recycling = infer_reuse_pattern(env, env.outputs))
self.linker = linker.accept(env, no_recycling = infer_reuse_pattern(env, env.outputs))
def create(self, profiler = None, unpack_single = True, strict = 'if_destroyed'):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论