提交 29b1ff7b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor updates-related code in aesara.compile.function.types

上级 a22da800
...@@ -183,12 +183,14 @@ def std_fgraph( ...@@ -183,12 +183,14 @@ def std_fgraph(
update_mapping[out_idx] = idx update_mapping[out_idx] = idx
out_idx += 1 out_idx += 1
if fgraph: found_updates = []
if fgraph.update_mapping is None: if fgraph and fgraph.update_mapping is None:
fgraph.update_mapping = update_mapping fgraph.update_mapping = update_mapping
for update in updates: for update in updates:
fgraph.add_output(update, reason="std_fgraph") fgraph.add_output(update, reason="std_fgraph")
else:
found_updates.extend(map(SymbolicOutput, updates))
elif fgraph is None:
input_vars = [] input_vars = []
# If one of the inputs is non-atomic (i.e. has a non-`None` `Variable.owner`), # If one of the inputs is non-atomic (i.e. has a non-`None` `Variable.owner`),
...@@ -209,7 +211,7 @@ def std_fgraph( ...@@ -209,7 +211,7 @@ def std_fgraph(
clone=clone, clone=clone,
) )
additional_outputs = list(map(SymbolicOutput, updates)) found_updates.extend(map(SymbolicOutput, updates))
for node in fgraph.apply_nodes: for node in fgraph.apply_nodes:
if node.op.destroy_map: if node.op.destroy_map:
...@@ -235,7 +237,7 @@ def std_fgraph( ...@@ -235,7 +237,7 @@ def std_fgraph(
for feature in features: for feature in features:
fgraph.attach_feature(feature()) fgraph.attach_feature(feature())
return fgraph, additional_outputs return fgraph, found_updates
class AliasedMemoryError(Exception): class AliasedMemoryError(Exception):
...@@ -675,7 +677,11 @@ class Function: ...@@ -675,7 +677,11 @@ class Function:
# Re initialize Outs and swap update and variable in Ins # Re initialize Outs and swap update and variable in Ins
# By doing this, we can pass FunctionMaker.check_unused_inputs() # By doing this, we can pass FunctionMaker.check_unused_inputs()
outs = list(map(SymbolicOutput, fg_cpy.outputs[: len(maker.outputs)])) if delete_updates:
outs = list(map(SymbolicOutput, fg_cpy.outputs[: len(maker.outputs)]))
else:
outs = list(map(SymbolicOutput, fg_cpy.outputs))
for out_ori, out_cpy in zip(maker.outputs, outs): for out_ori, out_cpy in zip(maker.outputs, outs):
out_cpy.borrow = out_ori.borrow out_cpy.borrow = out_ori.borrow
...@@ -712,14 +718,17 @@ class Function: ...@@ -712,14 +718,17 @@ class Function:
fg_cpy.replace(in_v, swap_sv, reason="Swap SV") fg_cpy.replace(in_v, swap_sv, reason="Swap SV")
# Delete update if needed # Delete update if needed
update_i = len(outs) rev_update_mapping = {v: k for k, v in fg_cpy.update_mapping.items()}
for i, in_var in zip(ins, fg_cpy.inputs): for n, (inp, in_var) in enumerate(zip(ins, fg_cpy.inputs)):
i.variable = in_var inp.variable = in_var
if not delete_updates and i.update is not None: if not delete_updates and inp.update is not None:
i.update = fg_cpy.outputs[update_i] out_idx = rev_update_mapping[n]
update_i += 1 inp.update = fg_cpy.outputs[out_idx]
else: else:
i.update = None inp.update = None
if delete_updates:
fg_cpy.update_mapping = {}
# Construct new storage_map that map new variable to old storage, # Construct new storage_map that map new variable to old storage,
# so that the ensuing function shares storage with the original one # so that the ensuing function shares storage with the original one
...@@ -1518,7 +1527,7 @@ class FunctionMaker: ...@@ -1518,7 +1527,7 @@ class FunctionMaker:
indices = [[input, None, [input]] for input in inputs] indices = [[input, None, [input]] for input in inputs]
fgraph, additional_outputs = std_fgraph( fgraph, found_updates = std_fgraph(
inputs, outputs, accept_inplace, fgraph=fgraph inputs, outputs, accept_inplace, fgraph=fgraph
) )
...@@ -1531,16 +1540,16 @@ class FunctionMaker: ...@@ -1531,16 +1540,16 @@ class FunctionMaker:
if not no_fgraph_prep: if not no_fgraph_prep:
self.prepare_fgraph( self.prepare_fgraph(
inputs, outputs, additional_outputs, fgraph, optimizer, linker, profile inputs, outputs, found_updates, fgraph, optimizer, linker, profile
) )
# the 'no_borrow' outputs are the ones for which that we can't assert len(fgraph.outputs) == len(outputs + found_updates)
# return the internal storage pointer.
assert len(fgraph.outputs) == len(outputs + additional_outputs)
# The 'no_borrow' outputs are the ones for which that we can't
# return the internal storage pointer.
no_borrow = [ no_borrow = [
output output
for output, spec in zip(fgraph.outputs, outputs + additional_outputs) for output, spec in zip(fgraph.outputs, outputs + found_updates)
if not spec.borrow if not spec.borrow
] ]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论