提交 1cfd8fda authored 作者: Olivier Breuleux's avatar Olivier Breuleux

added strict option to function and gof.link.Filter, disown_inputs option to std_env and function

上级 9af1c094
...@@ -67,12 +67,18 @@ def cloned_env(inputs, outputs): ...@@ -67,12 +67,18 @@ def cloned_env(inputs, outputs):
env = gof.env.Env(inputs, outputs) env = gof.env.Env(inputs, outputs)
return env return env
def std_env(inputs, outputs): def std_env(inputs, outputs, disown_inputs = False):
inputs, outputs = gof.graph.clone(inputs, outputs) inputs, outputs = gof.graph.clone(inputs, outputs)
_mark_indestructible(outputs) _mark_indestructible(outputs)
env = gof.env.Env(inputs, outputs) env = gof.env.Env(inputs, outputs)
env.extend(gof.DestroyHandler()) env.extend(gof.DestroyHandler())
env.extend(gof.ReplaceValidate()) env.extend(gof.ReplaceValidate())
env.validate()
for input in inputs:
input.destroyed_by_user = len(env.destroyers(input)) != 0
if not input.destroyed_by_user and not disown_inputs:
# prevent optimizations from destroying the inputs
input.indestructible = True
return env return env
def std_opt(env): def std_opt(env):
...@@ -88,29 +94,38 @@ predefined_linkers = { ...@@ -88,29 +94,38 @@ predefined_linkers = {
class FunctionFactory: class FunctionFactory:
def __init__(self, inputs, outputs, linker = 'py', optimizer = std_opt, borrow_outputs = False): def __init__(self, inputs, outputs, linker = 'py', optimizer = std_opt, borrow_outputs = False, disown_inputs = False):
if len(inputs) != len(set(inputs)): if len(inputs) != len(set(inputs)):
print >>sys.stderr, "Warning: duplicate inputs" print >>sys.stderr, "Warning: duplicate inputs"
for r in list(inputs) + list(outputs): for r in list(inputs) + list(outputs):
if not isinstance(r, gof.Result): if not isinstance(r, gof.Result):
raise TypeError("All inputs and outputs to FunctionFactory should be Result instances. Received:", type(r), r) raise TypeError("All inputs and outputs to FunctionFactory should be Result instances. Received:", type(r), r)
env = std_env(inputs, outputs) env = std_env(inputs, outputs, disown_inputs = disown_inputs)
if None is not optimizer: if None is not optimizer:
optimizer(env) optimizer(env)
env.validate() env.validate()
self.env = env self.env = env
linker = predefined_linkers.get(linker, linker) linker = predefined_linkers.get(linker, linker)
if not callable(linker):
raise ValueError("'linker' parameter of FunctionFactory should be a callable that takes an env as argument " \
"or one of ['py', 'c', 'c|py', 'c&py']")
if borrow_outputs: if borrow_outputs:
self.linker = linker(env) self.linker = linker(env)
else: else:
self.linker = linker(env, no_recycling = infer_reuse_pattern(env, env.outputs)) self.linker = linker(env, no_recycling = infer_reuse_pattern(env, env.outputs))
def create(self, profiler = None, unpack_single = True):
def create(self, profiler = None, unpack_single = True, strict = 'if_destroyed'):
if strict not in [True, False, 'if_destroyed']:
raise ValueError("'strict' parameter of create should be one of [True, False, 'if_destroyed']")
if profiler is None: if profiler is None:
fn = self.linker.make_function(unpack_single=unpack_single) fn = self.linker.make_function(unpack_single=unpack_single)
else: else:
fn = self.linker.make_function(unpack_single=unpack_single, fn = self.linker.make_function(unpack_single=unpack_single,
profiler=profiler) profiler=profiler)
for env_input, fn_input in zip(self.env.inputs, fn.inputs):
if strict is True or (env_input.destroyed_by_user and strict == 'if_destroyed'):
fn_input.strict = True
return fn return fn
def partial(self, *first, **kwargs): def partial(self, *first, **kwargs):
...@@ -123,15 +138,19 @@ def function(inputs, ...@@ -123,15 +138,19 @@ def function(inputs,
linker = 'py', linker = 'py',
optimizer = std_opt, optimizer = std_opt,
borrow_outputs = False, borrow_outputs = False,
disown_inputs = False,
profiler = None, profiler = None,
unpack_single = True): unpack_single = True,
strict = 'if_destroyed'):
ff = FunctionFactory(inputs, ff = FunctionFactory(inputs,
outputs, outputs,
linker = linker, linker = linker,
optimizer = optimizer, optimizer = optimizer,
borrow_outputs = borrow_outputs,) borrow_outputs = borrow_outputs,
disown_inputs = disown_inputs)
return ff.create(profiler = profiler, return ff.create(profiler = profiler,
unpack_single = unpack_single) unpack_single = unpack_single,
strict = strict)
def eval_outputs(outputs, **kwargs): def eval_outputs(outputs, **kwargs):
......
...@@ -260,9 +260,9 @@ class Env(object): #(graph.Graph): ...@@ -260,9 +260,9 @@ class Env(object): #(graph.Graph):
self._features.remove(feature) self._features.remove(feature)
except: except:
return return
deattach = getattr(feature, 'on_deattach', None) detach = getattr(feature, 'on_detach', None)
if deattach is not None: if detach is not None:
deattach(self) detach(self)
### callback utils ### ### callback utils ###
......
...@@ -110,16 +110,20 @@ class Linker: ...@@ -110,16 +110,20 @@ class Linker:
class Filter(object): class Filter(object):
def __init__(self, type, storage, readonly = False): def __init__(self, type, storage, readonly = False, strict = False):
self.type = type self.type = type
self.storage = storage self.storage = storage
self.readonly = readonly self.readonly = readonly
self.strict = strict
def __get(self): def __get(self):
return self.storage[0] return self.storage[0]
def __set(self, value): def __set(self, value):
if self.readonly: if self.readonly:
raise Exception("Cannot set readonly storage.") raise Exception("Cannot set readonly storage.")
self.storage[0] = self.type.filter(value) if self.strict:
self.storage[0] = self.type.filter(value, strict = True)
else:
self.storage[0] = self.type.filter(value)
data = property(__get, __set) data = property(__get, __set)
def __str__(self): def __str__(self):
return "<" + str(self.storage[0]) + ">" return "<" + str(self.storage[0]) + ">"
......
...@@ -11,7 +11,7 @@ class Bookkeeper: ...@@ -11,7 +11,7 @@ class Bookkeeper:
for node in graph.io_toposort(env.inputs, env.outputs): for node in graph.io_toposort(env.inputs, env.outputs):
self.on_import(env, node) self.on_import(env, node)
def on_deattach(self, env): def on_detach(self, env):
for node in graph.io_toposort(env.inputs, env.outputs): for node in graph.io_toposort(env.inputs, env.outputs):
self.on_prune(env, node) self.on_prune(env, node)
...@@ -22,7 +22,7 @@ class Toposorter: ...@@ -22,7 +22,7 @@ class Toposorter:
raise Exception("Toposorter feature is already present or in conflict with another plugin.") raise Exception("Toposorter feature is already present or in conflict with another plugin.")
env.toposort = partial(self.toposort, env) env.toposort = partial(self.toposort, env)
def on_deattach(self, env): def on_detach(self, env):
del env.toposort del env.toposort
def toposort(self, env): def toposort(self, env):
...@@ -73,7 +73,7 @@ class History: ...@@ -73,7 +73,7 @@ class History:
env.checkpoint = lambda: len(self.history[env]) env.checkpoint = lambda: len(self.history[env])
env.revert = partial(self.revert, env) env.revert = partial(self.revert, env)
def on_deattach(self, env): def on_detach(self, env):
del env.checkpoint del env.checkpoint
del env.revert del env.revert
del self.history[env] del self.history[env]
...@@ -112,7 +112,7 @@ class Validator: ...@@ -112,7 +112,7 @@ class Validator:
return False return False
env.consistent = consistent env.consistent = consistent
def on_deattach(self, env): def on_detach(self, env):
del env.validate del env.validate
del env.consistent del env.consistent
...@@ -128,9 +128,9 @@ class ReplaceValidate(History, Validator): ...@@ -128,9 +128,9 @@ class ReplaceValidate(History, Validator):
env.replace_validate = partial(self.replace_validate, env) env.replace_validate = partial(self.replace_validate, env)
env.replace_all_validate = partial(self.replace_all_validate, env) env.replace_all_validate = partial(self.replace_all_validate, env)
def on_deattach(self, env): def on_detach(self, env):
History.on_deattach(self, env) History.on_detach(self, env)
Validator.on_deattach(self, env) Validator.on_detach(self, env)
del env.replace_validate del env.replace_validate
del env.replace_all_validate del env.replace_all_validate
...@@ -162,12 +162,12 @@ class NodeFinder(dict, Bookkeeper): ...@@ -162,12 +162,12 @@ class NodeFinder(dict, Bookkeeper):
env.get_nodes = partial(self.query, env) env.get_nodes = partial(self.query, env)
Bookkeeper.on_attach(self, env) Bookkeeper.on_attach(self, env)
def on_deattach(self, env): def on_detach(self, env):
if self.env is not env: if self.env is not env:
raise Exception("This NodeFinder instance was not attached to the provided env.") raise Exception("This NodeFinder instance was not attached to the provided env.")
self.env = None self.env = None
del env.get_nodes del env.get_nodes
Bookkeeper.on_deattach(self, env) Bookkeeper.on_detach(self, env)
def on_import(self, env, node): def on_import(self, env, node):
try: try:
...@@ -205,9 +205,9 @@ class PrintListener(object): ...@@ -205,9 +205,9 @@ class PrintListener(object):
if self.active: if self.active:
print "-- attaching to: ", env print "-- attaching to: ", env
def on_deattach(self, env): def on_detach(self, env):
if self.active: if self.active:
print "-- deattaching from: ", env print "-- detaching from: ", env
def on_import(self, env, node): def on_import(self, env, node):
if self.active: if self.active:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论