提交 705dfaaf authored 作者: Olivier Breuleux's avatar Olivier Breuleux

use_destroy_handler arg to function

上级 f4617baf
...@@ -70,15 +70,17 @@ def cloned_env(inputs, outputs): ...@@ -70,15 +70,17 @@ def cloned_env(inputs, outputs):
env = gof.env.Env(inputs, outputs) env = gof.env.Env(inputs, outputs)
return env return env
def std_env(inputs, outputs, disown_inputs = False): def std_env(inputs, outputs, disown_inputs = False,
use_destroy_handler = True):
inputs, outputs = gof.graph.clone(inputs, outputs) inputs, outputs = gof.graph.clone(inputs, outputs)
_mark_indestructible(outputs) _mark_indestructible(outputs)
env = gof.env.Env(inputs, outputs) env = gof.env.Env(inputs, outputs)
env.extend(gof.DestroyHandler()) if use_destroy_handler:
env.extend(gof.DestroyHandler())
env.extend(gof.ReplaceValidate()) env.extend(gof.ReplaceValidate())
env.validate() env.validate()
for input in inputs: for input in inputs:
input.destroyed_by_user = len(env.destroyers(input)) != 0 input.destroyed_by_user = use_destroy_handler and len(env.destroyers(input)) != 0
if not input.destroyed_by_user and not disown_inputs: if not input.destroyed_by_user and not disown_inputs:
# prevent optimizations from destroying the inputs # prevent optimizations from destroying the inputs
input.tag.indestructible = True input.tag.indestructible = True
...@@ -97,13 +99,15 @@ predefined_linkers = { ...@@ -97,13 +99,15 @@ predefined_linkers = {
class FunctionFactory: class FunctionFactory:
def __init__(self, inputs, outputs, linker = 'py', optimizer = std_opt, borrow_outputs = False, disown_inputs = False): def __init__(self, inputs, outputs, linker = 'py', optimizer = std_opt, borrow_outputs = False, disown_inputs = False,
use_destroy_handler = True):
if len(inputs) != len(set(inputs)): if len(inputs) != len(set(inputs)):
print >>sys.stderr, "Warning: duplicate inputs" print >>sys.stderr, "Warning: duplicate inputs"
for r in list(inputs) + list(outputs): for r in list(inputs) + list(outputs):
if not isinstance(r, gof.Result): if not isinstance(r, gof.Result):
raise TypeError("All inputs and outputs to FunctionFactory should be Result instances. Received:", type(r), r) raise TypeError("All inputs and outputs to FunctionFactory should be Result instances. Received:", type(r), r)
env = std_env(inputs, outputs, disown_inputs = disown_inputs) env = std_env(inputs, outputs, disown_inputs = disown_inputs,
use_destroy_handler = use_destroy_handler)
if None is not optimizer: if None is not optimizer:
optimizer(env) optimizer(env)
env.validate() env.validate()
...@@ -144,13 +148,15 @@ def function(inputs, ...@@ -144,13 +148,15 @@ def function(inputs,
disown_inputs = False, disown_inputs = False,
profiler = None, profiler = None,
unpack_single = True, unpack_single = True,
strict = 'if_destroyed'): strict = 'if_destroyed',
use_destroy_handler = True):
ff = FunctionFactory(inputs, ff = FunctionFactory(inputs,
outputs, outputs,
linker = linker, linker = linker,
optimizer = optimizer, optimizer = optimizer,
borrow_outputs = borrow_outputs, borrow_outputs = borrow_outputs,
disown_inputs = disown_inputs) disown_inputs = disown_inputs,
use_destroy_handler = use_destroy_handler)
return ff.create(profiler = profiler, return ff.create(profiler = profiler,
unpack_single = unpack_single, unpack_single = unpack_single,
strict = strict) strict = strict)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论