提交 1cfd8fda authored 作者: Olivier Breuleux's avatar Olivier Breuleux

added strict option to function and gof.link.Filter, disown_inputs option to std_env and function

上级 9af1c094
......@@ -67,12 +67,18 @@ def cloned_env(inputs, outputs):
env = gof.env.Env(inputs, outputs)
return env
def std_env(inputs, outputs):
def std_env(inputs, outputs, disown_inputs = False):
inputs, outputs = gof.graph.clone(inputs, outputs)
_mark_indestructible(outputs)
env = gof.env.Env(inputs, outputs)
env.extend(gof.DestroyHandler())
env.extend(gof.ReplaceValidate())
env.validate()
for input in inputs:
input.destroyed_by_user = len(env.destroyers(input)) != 0
if not input.destroyed_by_user and not disown_inputs:
# prevent optimizations from destroying the inputs
input.indestructible = True
return env
def std_opt(env):
......@@ -88,29 +94,38 @@ predefined_linkers = {
class FunctionFactory:
def __init__(self, inputs, outputs, linker = 'py', optimizer = std_opt, borrow_outputs = False):
def __init__(self, inputs, outputs, linker = 'py', optimizer = std_opt, borrow_outputs = False, disown_inputs = False):
if len(inputs) != len(set(inputs)):
print >>sys.stderr, "Warning: duplicate inputs"
for r in list(inputs) + list(outputs):
if not isinstance(r, gof.Result):
raise TypeError("All inputs and outputs to FunctionFactory should be Result instances. Received:", type(r), r)
env = std_env(inputs, outputs)
env = std_env(inputs, outputs, disown_inputs = disown_inputs)
if None is not optimizer:
optimizer(env)
env.validate()
self.env = env
linker = predefined_linkers.get(linker, linker)
if not callable(linker):
raise ValueError("'linker' parameter of FunctionFactory should be a callable that takes an env as argument " \
"or one of ['py', 'c', 'c|py', 'c&py']")
if borrow_outputs:
self.linker = linker(env)
else:
self.linker = linker(env, no_recycling = infer_reuse_pattern(env, env.outputs))
def create(self, profiler = None, unpack_single = True):
def create(self, profiler = None, unpack_single = True, strict = 'if_destroyed'):
if strict not in [True, False, 'if_destroyed']:
raise ValueError("'strict' parameter of create should be one of [True, False, 'if_destroyed']")
if profiler is None:
fn = self.linker.make_function(unpack_single=unpack_single)
else:
fn = self.linker.make_function(unpack_single=unpack_single,
profiler=profiler)
for env_input, fn_input in zip(self.env.inputs, fn.inputs):
if strict is True or (env_input.destroyed_by_user and strict == 'if_destroyed'):
fn_input.strict = True
return fn
def partial(self, *first, **kwargs):
......@@ -123,15 +138,19 @@ def function(inputs,
linker = 'py',
optimizer = std_opt,
borrow_outputs = False,
disown_inputs = False,
profiler = None,
unpack_single = True):
unpack_single = True,
strict = 'if_destroyed'):
ff = FunctionFactory(inputs,
outputs,
linker = linker,
optimizer = optimizer,
borrow_outputs = borrow_outputs,)
borrow_outputs = borrow_outputs,
disown_inputs = disown_inputs)
return ff.create(profiler = profiler,
unpack_single = unpack_single)
unpack_single = unpack_single,
strict = strict)
def eval_outputs(outputs, **kwargs):
......
......@@ -260,9 +260,9 @@ class Env(object): #(graph.Graph):
self._features.remove(feature)
except:
return
deattach = getattr(feature, 'on_deattach', None)
if deattach is not None:
deattach(self)
detach = getattr(feature, 'on_detach', None)
if detach is not None:
detach(self)
### callback utils ###
......
......@@ -110,15 +110,19 @@ class Linker:
class Filter(object):
def __init__(self, type, storage, readonly = False):
def __init__(self, type, storage, readonly = False, strict = False):
self.type = type
self.storage = storage
self.readonly = readonly
self.strict = strict
def __get(self):
return self.storage[0]
def __set(self, value):
if self.readonly:
raise Exception("Cannot set readonly storage.")
if self.strict:
self.storage[0] = self.type.filter(value, strict = True)
else:
self.storage[0] = self.type.filter(value)
data = property(__get, __set)
def __str__(self):
......
......@@ -11,7 +11,7 @@ class Bookkeeper:
for node in graph.io_toposort(env.inputs, env.outputs):
self.on_import(env, node)
def on_deattach(self, env):
def on_detach(self, env):
for node in graph.io_toposort(env.inputs, env.outputs):
self.on_prune(env, node)
......@@ -22,7 +22,7 @@ class Toposorter:
raise Exception("Toposorter feature is already present or in conflict with another plugin.")
env.toposort = partial(self.toposort, env)
def on_deattach(self, env):
def on_detach(self, env):
del env.toposort
def toposort(self, env):
......@@ -73,7 +73,7 @@ class History:
env.checkpoint = lambda: len(self.history[env])
env.revert = partial(self.revert, env)
def on_deattach(self, env):
def on_detach(self, env):
del env.checkpoint
del env.revert
del self.history[env]
......@@ -112,7 +112,7 @@ class Validator:
return False
env.consistent = consistent
def on_deattach(self, env):
def on_detach(self, env):
del env.validate
del env.consistent
......@@ -128,9 +128,9 @@ class ReplaceValidate(History, Validator):
env.replace_validate = partial(self.replace_validate, env)
env.replace_all_validate = partial(self.replace_all_validate, env)
def on_deattach(self, env):
History.on_deattach(self, env)
Validator.on_deattach(self, env)
def on_detach(self, env):
History.on_detach(self, env)
Validator.on_detach(self, env)
del env.replace_validate
del env.replace_all_validate
......@@ -162,12 +162,12 @@ class NodeFinder(dict, Bookkeeper):
env.get_nodes = partial(self.query, env)
Bookkeeper.on_attach(self, env)
def on_deattach(self, env):
def on_detach(self, env):
if self.env is not env:
raise Exception("This NodeFinder instance was not attached to the provided env.")
self.env = None
del env.get_nodes
Bookkeeper.on_deattach(self, env)
Bookkeeper.on_detach(self, env)
def on_import(self, env, node):
try:
......@@ -205,9 +205,9 @@ class PrintListener(object):
if self.active:
print "-- attaching to: ", env
def on_deattach(self, env):
def on_detach(self, env):
if self.active:
print "-- deattaching from: ", env
print "-- detaching from: ", env
def on_import(self, env, node):
if self.active:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论