提交 2c813862 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

modified compile to use the new linker interface

上级 61d10d72
...@@ -3,17 +3,18 @@ ...@@ -3,17 +3,18 @@
import numpy import numpy
import gof import gof
import sys import sys
from copy import copy
#TODO: put together some default optimizations (TRAC #67) #TODO: put together some default optimizations (TRAC #67)
def exec_py_opt(inputs, outputs, features=[]): def exec_py_opt(inputs, outputs, features=[]):
"""Return an optimized graph running purely python implementations""" """Return an optimized graph running purely python implementations"""
return Function(intputs, outputs, features, exec_py_opt.optimizer, gof.link.PerformLinker, False) return Function(intputs, outputs, features, exec_py_opt.optimizer, gof.link.PerformLinker(), False)
exec_py_opt.optimizer = None exec_py_opt.optimizer = None
def exec_opt(inputs, outputs, features=[]): def exec_opt(inputs, outputs, features=[]):
"""Return a fast implementation""" """Return a fast implementation"""
return Function(intputs, outputs, features, exec_opt.optimizer, gof.link.PerformLinker, False) return Function(intputs, outputs, features, exec_opt.optimizer, gof.link.PerformLinker(), False)
exec_opt.optimizer = None exec_opt.optimizer = None
class _DefaultOptimizer(object): class _DefaultOptimizer(object):
...@@ -28,18 +29,20 @@ def _mark_indestructible(results): ...@@ -28,18 +29,20 @@ def _mark_indestructible(results):
for r in results: for r in results:
r.tag.indestructible = True r.tag.indestructible = True
def linker_cls_python_and_c(env, **kwargs): # def linker_cls_python_and_c(env, **kwargs):
"""Use this as the linker_cls argument to Function.__init__ to compare # """Use this as the linker_cls argument to Function.__init__ to compare
python and C implementations""" # python and C implementations"""
def checker(x, y):
x, y = x[0], y[0] def check_equal(x, y):
if isinstance(x, numpy.ndarray) or isinstance(y, numpy.ndarray): x, y = x[0], y[0]
if x.dtype != y.dtype or x.shape != y.shape or numpy.any(abs(x - y) > 1e-10): if isinstance(x, numpy.ndarray) or isinstance(y, numpy.ndarray):
raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y}) if x.dtype != y.dtype or x.shape != y.shape or numpy.any(abs(x - y) > 1e-10):
else: raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y})
if x != y: else:
if x != y:
raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y}) raise Exception("Output mismatch.", {'performlinker': x, 'clinker': y})
return gof.DualLinker(env, checker, **kwargs)
# return gof.DualLinker(checker, **kwargs).accept(env)
def infer_reuse_pattern(env, outputs_to_disown): def infer_reuse_pattern(env, outputs_to_disown):
...@@ -86,10 +89,10 @@ def std_opt(env): ...@@ -86,10 +89,10 @@ def std_opt(env):
predefined_linkers = { predefined_linkers = {
'py' : gof.link.PerformLinker, 'py' : gof.PerformLinker(),
'c' : gof.cc.CLinker, 'c' : gof.CLinker(),
'c|py' : gof.cc.OpWiseCLinker, 'c|py' : gof.OpWiseCLinker(),
'c&py' : linker_cls_python_and_c 'c&py' : gof.DualLinker(checker = check_equal)
} }
class FunctionFactory: class FunctionFactory:
...@@ -105,14 +108,14 @@ class FunctionFactory: ...@@ -105,14 +108,14 @@ class FunctionFactory:
optimizer(env) optimizer(env)
env.validate() env.validate()
self.env = env self.env = env
linker = predefined_linkers.get(linker, linker) linker = copy(predefined_linkers.get(linker, linker))
if not callable(linker): if not hasattr(linker, 'accept'):
raise ValueError("'linker' parameter of FunctionFactory should be a callable that takes an env as argument " \ raise ValueError("'linker' parameter of FunctionFactory should be a Linker with an accept method " \
"or one of ['py', 'c', 'c|py', 'c&py']") "or one of ['py', 'c', 'c|py', 'c&py']")
if borrow_outputs: if borrow_outputs:
self.linker = linker(env) self.linker = linker.accept(env)
else: else:
self.linker = linker(env, no_recycling = infer_reuse_pattern(env, env.outputs)) self.linker = linker.accept(env, no_recycling = infer_reuse_pattern(env, env.outputs))
def create(self, profiler = None, unpack_single = True, strict = 'if_destroyed'): def create(self, profiler = None, unpack_single = True, strict = 'if_destroyed'):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论