提交 1fc4a77b authored 作者: James Bergstra's avatar James Bergstra

added OutputGuard sanity check to AddDestroyHandler to ensure that outputs are…

added OutputGuard sanity check to AddDestroyHandler to ensure that outputs are protected from destructive operations
上级 be10831b
...@@ -74,9 +74,39 @@ def register_optimizer(name, opt): ...@@ -74,9 +74,39 @@ def register_optimizer(name, opt):
raise ValueError('Optimizer name already taken: %s' % name) raise ValueError('Optimizer name already taken: %s' % name)
predefined_optimizers[name] = opt predefined_optimizers[name] = opt
class OutputGuard(gof.Op):
destroy_map = {0:[0]}
view_map = {0:[0]}
def make_node(self, x):
return gof.Apply(self, [x], [x.type()])
def perform(self, node, (x,), (z,)):
z[0] = x
def __str__(self):
return '%s' % self.__class__.__name__
class AddDestroyHandler(gof.Optimizer): class AddDestroyHandler(gof.Optimizer):
"""This optimizer performs two important functions:
1) it has a 'requirement' of the destroyhandler. This means that the env will include it
as a feature for this optimization, and keep this feature enabled for subsequent
optimizations. All optimizations that work inplace on any of their inputs must run *after*
this optimization to ensure that the DestroyHandler has been included in the env.
2) It tries to replace each output with an Op that purports to destroy it (but it won't I
promise). If this replacement succeeds it means that there is a bug in theano. It should
not be possible to destroy outputs.
"""
def apply(self, env): def apply(self, env):
pass output_guard = OutputGuard()
for o in env.outputs:
try:
env.replace_validate(o, output_guard(o), reason='output_guard')
warning("Output variable %s required output_guard,"
" how was this output left unprotected against destructive operations?"
% o)
except gof.InconsistencyError:
#this output is already impossible to destroy. no guard necessary
pass
def add_requirements(self, env): def add_requirements(self, env):
super(AddDestroyHandler, self).add_requirements(env) super(AddDestroyHandler, self).add_requirements(env)
env.extend(gof.DestroyHandler()) env.extend(gof.DestroyHandler())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论