提交 5960a4ea authored 作者: James Bergstra's avatar James Bergstra

function_module - extended the logic that un-aliases input/output variables

上级 36b147ea
......@@ -751,13 +751,19 @@ class FunctionMaker(object):
if not isinstance(inputs, (list, tuple)):
inputs = [inputs]
# Wrap them in In or Out instances if needed.
#import pudb; pudb.set_trace()
inputs, outputs = map(self.wrap_in, inputs), map(self.wrap_out, outputs)
_inputs = gof.graph.inputs([o.variable for o in outputs] + [i.update for i in inputs if getattr(i, 'update', False)])
_inputs = gof.graph.inputs([o.variable for o in outputs] + [i.update
for i in inputs if getattr(i, 'update', False)])
#TODO: REMOVE THIS CRUFT - it's complicated for SymbolicInputKits
indices = [[input] + self.expand_in(input, _inputs) for input in inputs]
expanded_inputs = reduce(list.__add__, [list(z) for x, y, z in indices], [])
assert expanded_inputs == inputs #JB - I added this to make sure we could delete above
# make the env
# make the env (copies the graph, creates NEW INPUT AND OUTPUT VARIABLES)
env, additional_outputs = std_env(expanded_inputs, outputs, accept_inplace)
self.env = env
......@@ -774,12 +780,34 @@ class FunctionMaker(object):
# but some of the outputs can be shared variables, and is not good for shared
# variables to be aliased. It might be possible to optimize this by making sure
# there is no aliasing only between shared variables.
assert len(inputs) == len(env.inputs)
updated_env_inputs = [env_i for i, env_i in zip(inputs, env.inputs) if getattr(i, 'update', False)]
for i in xrange(len(env.outputs)):
views = set()
view_tree_set(alias_root(env.outputs[i]), views)
views_of_output_i = set()
view_tree_set(alias_root(env.outputs[i]), views_of_output_i)
copied = False
# do not allow outputs to be aliased
for j in xrange(i+1, len(env.outputs)):
if env.outputs[j] in views:
env.change_input('output', j, deep_copy_op(env.outputs[j]))
if env.outputs[j] in views_of_output_i:
env.change_input('output', i, deep_copy_op(env.outputs[i]))
copied = True
break
if not copied:
for input_j in env.inputs:
# do not allow outputs to be aliased to an inputs (j), unless
# a) that j'th input has been 'destroyed' by e.g. in-place computations
# b) that j'th input is a shared variable that is also being updated
if hasattr(env,'get_destroyers_of') and env.get_destroyers_of(input_j):
continue
if input_j in updated_env_inputs:
continue
if input_j in views_of_output_i:
env.change_input('output', i, deep_copy_op(env.outputs[i]))
break
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论