提交 3928937d authored 作者: Frederic Bastien's avatar Frederic Bastien

add the new logic about not having output of theano fct being aliased to the…

add the new logic about not having output of theano fct being aliased to the inputs to the DebugMode to remove false buildbot error.
上级 f3684046
......@@ -13,12 +13,15 @@ from theano.gof.link import WrapLinkerMany, raise_with_op
from theano.gof.cc import OpWiseCLinker, CLinker
from theano.configparser import config, AddConfigVar, IntParam, BoolParam
from theano.compile.function_module import (FunctionMaker,
alias_root,
deep_copy_op,
Function,
infer_reuse_pattern,
SymbolicInput,
SymbolicInputKit,
SymbolicOutput,
Supervisor)
Supervisor,
view_tree_set)
from theano.compile.mode import Mode, register_mode
AddConfigVar('DebugMode.patience',
......@@ -1323,15 +1326,56 @@ class _Maker(FunctionMaker): #inheritance buys a few helper functions
# Wrap them in In or Out instances if needed.
inputs, outputs = map(self.wrap_in, inputs), map(self.wrap_out, outputs)
_inputs = gof.graph.inputs([o.variable for o in outputs] + [i.update for i in inputs if getattr(i, 'update', False)])
#TODO: REMOVE THIS CRUFT - it's complicated for SymbolicInputKits
indices = [[input] + self.expand_in(input, _inputs) for input in inputs]
expanded_inputs = reduce(list.__add__, [list(z) for x, y, z in indices], [])
assert expanded_inputs == inputs #JB - I added this to make sure we could delete above
# make the env
for i in xrange(mode.stability_patience):
env, additional_outputs, equivalence_tracker = _optcheck_env(expanded_inputs, outputs, accept_inplace)
env.equivalence_tracker = equivalence_tracker
# optimize the env
optimizer(env)
# This loop was inserted to remove aliasing between outputs when they all
# evaluete to the same value. Originally it was OK for outputs to be aliased,
# but some of the outputs can be shared variables, and is not good for shared
# variables to be aliased. It might be possible to optimize this by making sure
# there is no aliasing only between shared variables.
#import pdb;pdb.set_trace()
assert len(inputs) == len(env.inputs)
updated_env_inputs = [env_i for i, env_i in zip(inputs, env.inputs) if getattr(i, 'update', False)]
for i in xrange(len(env.outputs)):
views_of_output_i = set()
view_tree_set(alias_root(env.outputs[i]), views_of_output_i)
copied = False
# do not allow outputs to be aliased
for j in xrange(i+1, len(env.outputs)):
if env.outputs[j] in views_of_output_i:
#import pdb;pdb.set_trace()
env.change_input('output', i, deep_copy_op(env.outputs[i]))
copied = True
break
if not copied:
for input_j in env.inputs:
# do not allow outputs to be aliased to an inputs (j), unless
# a) that j'th input has been 'destroyed' by e.g. in-place computations
# b) that j'th input is a shared variable that is also being updated
if hasattr(env,'get_destroyers_of') and env.get_destroyers_of(input_j):
continue
if input_j in updated_env_inputs:
continue
if input_j in views_of_output_i:
#import pdb;pdb.set_trace()
env.change_input('output', i, deep_copy_op(env.outputs[i]))
break
if i:
li = env.equivalence_tracker.event_list
l0 = env0.equivalence_tracker.event_list
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论