提交 29b1ff7b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor updates-related code in aesara.compile.function.types

上级 a22da800
......@@ -183,12 +183,14 @@ def std_fgraph(
update_mapping[out_idx] = idx
out_idx += 1
if fgraph:
if fgraph.update_mapping is None:
fgraph.update_mapping = update_mapping
for update in updates:
fgraph.add_output(update, reason="std_fgraph")
else:
found_updates = []
if fgraph and fgraph.update_mapping is None:
fgraph.update_mapping = update_mapping
for update in updates:
fgraph.add_output(update, reason="std_fgraph")
found_updates.extend(map(SymbolicOutput, updates))
elif fgraph is None:
input_vars = []
# If one of the inputs is non-atomic (i.e. has a non-`None` `Variable.owner`),
......@@ -209,7 +211,7 @@ def std_fgraph(
clone=clone,
)
additional_outputs = list(map(SymbolicOutput, updates))
found_updates.extend(map(SymbolicOutput, updates))
for node in fgraph.apply_nodes:
if node.op.destroy_map:
......@@ -235,7 +237,7 @@ def std_fgraph(
for feature in features:
fgraph.attach_feature(feature())
return fgraph, additional_outputs
return fgraph, found_updates
class AliasedMemoryError(Exception):
......@@ -675,7 +677,11 @@ class Function:
# Re initialize Outs and swap update and variable in Ins
# By doing this, we can pass FunctionMaker.check_unused_inputs()
outs = list(map(SymbolicOutput, fg_cpy.outputs[: len(maker.outputs)]))
if delete_updates:
outs = list(map(SymbolicOutput, fg_cpy.outputs[: len(maker.outputs)]))
else:
outs = list(map(SymbolicOutput, fg_cpy.outputs))
for out_ori, out_cpy in zip(maker.outputs, outs):
out_cpy.borrow = out_ori.borrow
......@@ -712,14 +718,17 @@ class Function:
fg_cpy.replace(in_v, swap_sv, reason="Swap SV")
# Delete update if needed
update_i = len(outs)
for i, in_var in zip(ins, fg_cpy.inputs):
i.variable = in_var
if not delete_updates and i.update is not None:
i.update = fg_cpy.outputs[update_i]
update_i += 1
rev_update_mapping = {v: k for k, v in fg_cpy.update_mapping.items()}
for n, (inp, in_var) in enumerate(zip(ins, fg_cpy.inputs)):
inp.variable = in_var
if not delete_updates and inp.update is not None:
out_idx = rev_update_mapping[n]
inp.update = fg_cpy.outputs[out_idx]
else:
i.update = None
inp.update = None
if delete_updates:
fg_cpy.update_mapping = {}
# Construct new storage_map that map new variable to old storage,
# so that the ensuing function shares storage with the original one
......@@ -1518,7 +1527,7 @@ class FunctionMaker:
indices = [[input, None, [input]] for input in inputs]
fgraph, additional_outputs = std_fgraph(
fgraph, found_updates = std_fgraph(
inputs, outputs, accept_inplace, fgraph=fgraph
)
......@@ -1531,16 +1540,16 @@ class FunctionMaker:
if not no_fgraph_prep:
self.prepare_fgraph(
inputs, outputs, additional_outputs, fgraph, optimizer, linker, profile
inputs, outputs, found_updates, fgraph, optimizer, linker, profile
)
# the 'no_borrow' outputs are the ones for which that we can't
# return the internal storage pointer.
assert len(fgraph.outputs) == len(outputs + additional_outputs)
assert len(fgraph.outputs) == len(outputs + found_updates)
# The 'no_borrow' outputs are the ones for which that we can't
# return the internal storage pointer.
no_borrow = [
output
for output, spec in zip(fgraph.outputs, outputs + additional_outputs)
for output, spec in zip(fgraph.outputs, outputs + found_updates)
if not spec.borrow
]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论