提交 b0a1d33c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add view_map and destroy_map variables and docstrings to Op

上级 cae78759
...@@ -166,7 +166,7 @@ class BadDestroyMap(DebugModeError): ...@@ -166,7 +166,7 @@ class BadDestroyMap(DebugModeError):
print(" node:", self.node, file=sio) print(" node:", self.node, file=sio)
print(" perform:", self.perform, file=sio) print(" perform:", self.perform, file=sio)
print(" node.inputs:", [(str(i), id(i)) for i in self.node.inputs], file=sio) print(" node.inputs:", [(str(i), id(i)) for i in self.node.inputs], file=sio)
print(" destroy_map:", getattr(self.node.op, "destroy_map", {}), file=sio) print(" destroy_map:", self.node.op.destroy_map, file=sio)
print(" changed input idx:", self.idx, file=sio) print(" changed input idx:", self.idx, file=sio)
print(" changed input type:", self.node.inputs[self.idx].type, file=sio) print(" changed input type:", self.node.inputs[self.idx].type, file=sio)
print(" repr (old val):", repr(self.old_val), file=sio) print(" repr (old val):", repr(self.old_val), file=sio)
...@@ -250,8 +250,8 @@ class BadViewMap(DebugModeError): ...@@ -250,8 +250,8 @@ class BadViewMap(DebugModeError):
print(" node:", self.node, file=sio) print(" node:", self.node, file=sio)
print(" node.inputs:", [(str(i), id(i)) for i in self.node.inputs], file=sio) print(" node.inputs:", [(str(i), id(i)) for i in self.node.inputs], file=sio)
print(" node.outputs:", [(str(i), id(i)) for i in self.node.outputs], file=sio) print(" node.outputs:", [(str(i), id(i)) for i in self.node.outputs], file=sio)
print(" view_map:", getattr(self.node.op, "view_map", {}), file=sio) print(" view_map:", self.node.op.view_map, file=sio)
print(" destroy_map:", getattr(self.node.op, "destroy_map", {}), file=sio) print(" destroy_map:", self.node.op.destroy_map, file=sio)
print(" aliased output:", self.output_idx, file=sio) print(" aliased output:", self.output_idx, file=sio)
print(" aliased output storage:", self.out_storage, file=sio) print(" aliased output storage:", self.out_storage, file=sio)
if self.in_alias_idx: if self.in_alias_idx:
...@@ -554,12 +554,12 @@ def debugprint( ...@@ -554,12 +554,12 @@ def debugprint(
r_name = "" r_name = ""
if print_destroy_map: if print_destroy_map:
destroy_map_str = str(getattr(r.owner.op, "destroy_map", "")) destroy_map_str = str(r.owner.op.destroy_map)
else: else:
destroy_map_str = "" destroy_map_str = ""
if print_view_map: if print_view_map:
view_map_str = str(getattr(r.owner.op, "view_map", "")) view_map_str = str(r.owner.op.view_map)
else: else:
view_map_str = "" view_map_str = ""
if destroy_map_str and destroy_map_str != "{}": if destroy_map_str and destroy_map_str != "{}":
...@@ -742,13 +742,13 @@ def _check_inputs( ...@@ -742,13 +742,13 @@ def _check_inputs(
""" """
destroyed_idx_list = [] destroyed_idx_list = []
destroy_map = getattr(node.op, "destroy_map", {}) destroy_map = node.op.destroy_map
for o_pos, i_pos_list in destroy_map.items(): for o_pos, i_pos_list in destroy_map.items():
destroyed_idx_list.extend(i_pos_list) destroyed_idx_list.extend(i_pos_list)
destroyed_res_list = [node.inputs[i] for i in destroyed_idx_list] destroyed_res_list = [node.inputs[i] for i in destroyed_idx_list]
actually_inplace_outputs = [] actually_inplace_outputs = []
dmap = getattr(node.op, "destroy_map", {}) dmap = node.op.destroy_map
for oo, ii in dmap.items(): for oo, ii in dmap.items():
var = node.outputs[oo] var = node.outputs[oo]
out_var = storage_map[var][0] out_var = storage_map[var][0]
...@@ -769,7 +769,7 @@ def _check_inputs( ...@@ -769,7 +769,7 @@ def _check_inputs(
f"as destroyed was not changed for node '{node}'" f"as destroyed was not changed for node '{node}'"
) )
vmap = getattr(node.op, "view_map", {}) vmap = node.op.view_map
for oo, ii in vmap.items(): for oo, ii in vmap.items():
var = node.outputs[oo] var = node.outputs[oo]
out_var = storage_map[var][0] out_var = storage_map[var][0]
...@@ -836,8 +836,8 @@ def _check_viewmap(fgraph, node, storage_map): ...@@ -836,8 +836,8 @@ def _check_viewmap(fgraph, node, storage_map):
outstorage = storage_map[onode][0] outstorage = storage_map[onode][0]
# first find out which input it aliases # first find out which input it aliases
view_map = getattr(node.op, "view_map", {}) view_map = node.op.view_map
destroy_map = getattr(node.op, "destroy_map", {}) destroy_map = node.op.destroy_map
# In theory, aesara's view_map only allows for 1 output to # In theory, aesara's view_map only allows for 1 output to
# alias 1 input. Checking for multiple aliases just in # alias 1 input. Checking for multiple aliases just in
...@@ -1395,8 +1395,8 @@ def _check_preallocated_output( ...@@ -1395,8 +1395,8 @@ def _check_preallocated_output(
# Set of inputs that are marked as destroyed or viewed # Set of inputs that are marked as destroyed or viewed
aliased_inputs = set() aliased_inputs = set()
dmap = getattr(node.op, "destroy_map", {}) dmap = node.op.destroy_map
vmap = getattr(node.op, "view_map", {}) vmap = node.op.view_map
for i, r in enumerate(node.inputs): for i, r in enumerate(node.inputs):
if any(i in v for v in chain(dmap.values(), vmap.values())): if any(i in v for v in chain(dmap.values(), vmap.values())):
aliased_inputs.add(r) aliased_inputs.add(r)
...@@ -2082,8 +2082,8 @@ class _Linker(LocalLinker): ...@@ -2082,8 +2082,8 @@ class _Linker(LocalLinker):
clobber = True clobber = True
if thunk_py: if thunk_py:
dmap = getattr(node.op, "destroy_map", {}) dmap = node.op.destroy_map
vmap = getattr(node.op, "view_map", {}) vmap = node.op.view_map
for i, r in enumerate(node.inputs): for i, r in enumerate(node.inputs):
# if thunk_py ran, and we still got # if thunk_py ran, and we still got
# this far, it means that the # this far, it means that the
......
...@@ -57,8 +57,8 @@ def alias_root(v): ...@@ -57,8 +57,8 @@ def alias_root(v):
""" """
if v.owner is None: if v.owner is None:
return v return v
vmap = getattr(v.owner.op, "view_map", {}) vmap = v.owner.op.view_map
dmap = getattr(v.owner.op, "destroy_map", {}) dmap = v.owner.op.destroy_map
outpos = v.owner.outputs.index(v) outpos = v.owner.outputs.index(v)
v_views = vmap.get(outpos, []) + dmap.get(outpos, []) v_views = vmap.get(outpos, []) + dmap.get(outpos, [])
if len(v_views) > 1: if len(v_views) > 1:
...@@ -83,8 +83,8 @@ def view_tree_set(fgraph, v, treeset): ...@@ -83,8 +83,8 @@ def view_tree_set(fgraph, v, treeset):
for cl, v_input_pos_to_cl in fgraph.clients[v]: for cl, v_input_pos_to_cl in fgraph.clients[v]:
if cl == "output": if cl == "output":
continue continue
vmap = getattr(cl.op, "view_map", {}) vmap = cl.op.view_map
dmap = getattr(cl.op, "destroy_map", {}) dmap = cl.op.destroy_map
for opos, iposlist in chain(vmap.items(), dmap.items()): for opos, iposlist in chain(vmap.items(), dmap.items()):
if v_input_pos_to_cl in iposlist: if v_input_pos_to_cl in iposlist:
if cl.outputs[opos] not in treeset: if cl.outputs[opos] not in treeset:
...@@ -189,7 +189,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace=False): ...@@ -189,7 +189,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace=False):
fgraph = FunctionGraph(orig_inputs, orig_outputs, update_mapping=update_mapping) fgraph = FunctionGraph(orig_inputs, orig_outputs, update_mapping=update_mapping)
for node in fgraph.apply_nodes: for node in fgraph.apply_nodes:
if getattr(node.op, "destroy_map", None): if node.op.destroy_map:
if not accept_inplace: if not accept_inplace:
raise TypeError( raise TypeError(
"Graph must not contain inplace operations", node, node.op "Graph must not contain inplace operations", node, node.op
......
...@@ -962,8 +962,8 @@ class ProfileStats: ...@@ -962,8 +962,8 @@ class ProfileStats:
if ignore_dmap: if ignore_dmap:
dmap = None dmap = None
else: else:
dmap = getattr(node.op, "destroy_map", None) dmap = node.op.destroy_map
vmap = getattr(node.op, "view_map", None) vmap = node.op.view_map
val = nodes_mem[node] val = nodes_mem[node]
for v in val: for v in val:
...@@ -1125,8 +1125,8 @@ class ProfileStats: ...@@ -1125,8 +1125,8 @@ class ProfileStats:
mem_freed = 0 mem_freed = 0
max_storage = max_mem_count max_storage = max_mem_count
dmap = getattr(node.op, "destroy_map", None) dmap = node.op.destroy_map
vmap = getattr(node.op, "view_map", None) vmap = node.op.view_map
idx = 0 idx = 0
# Update the Python emulating dicts and add the # Update the Python emulating dicts and add the
...@@ -1426,9 +1426,9 @@ class ProfileStats: ...@@ -1426,9 +1426,9 @@ class ProfileStats:
items.sort(key=lambda a: a[1], reverse=True) items.sort(key=lambda a: a[1], reverse=True)
for idx, ((fgraph, node), node_outputs_size) in enumerate(items[:N]): for idx, ((fgraph, node), node_outputs_size) in enumerate(items[:N]):
code = ["c"] * len(node.outputs) code = ["c"] * len(node.outputs)
for out, inp in getattr(node.op, "destroy_map", {}).items(): for out, inp in node.op.destroy_map.items():
code[out] = "i" code[out] = "i"
for out, inp in getattr(node.op, "view_map", {}).items(): for out, inp in node.op.view_map.items():
code[out] = "v" code[out] = "v"
shapes = str(fct_shapes[fgraph][node]) shapes = str(fct_shapes[fgraph][node])
......
...@@ -186,11 +186,11 @@ class PyDotFormatter: ...@@ -186,11 +186,11 @@ class PyDotFormatter:
graph.add_node(pd_var) graph.add_node(pd_var)
edge_params = {} edge_params = {}
if hasattr(node.op, "view_map") and id in reduce( if node.op.view_map and id in reduce(
list.__add__, node.op.view_map.values(), [] list.__add__, node.op.view_map.values(), []
): ):
edge_params["color"] = self.node_colors["output"] edge_params["color"] = self.node_colors["output"]
elif hasattr(node.op, "destroy_map") and id in reduce( elif node.op.destroy_map and id in reduce(
list.__add__, node.op.destroy_map.values(), [] list.__add__, node.op.destroy_map.values(), []
): ):
edge_params["color"] = "red" edge_params["color"] = "red"
......
...@@ -413,11 +413,11 @@ class DestroyHandler(Bookkeeper): # noqa ...@@ -413,11 +413,11 @@ class DestroyHandler(Bookkeeper): # noqa
for (app, idx) in fgraph.clients[protected_var]: for (app, idx) in fgraph.clients[protected_var]:
if app == "output": if app == "output":
continue continue
destroy_maps = getattr(app.op, "destroy_map", {}).values() destroy_maps = app.op.destroy_map.values()
# If True means that the apply node, destroys the protected_var. # If True means that the apply node, destroys the protected_var.
if idx in [dmap for sublist in destroy_maps for dmap in sublist]: if idx in [dmap for sublist in destroy_maps for dmap in sublist]:
return True return True
for var_idx in getattr(app.op, "view_map", {}).keys(): for var_idx in app.op.view_map.keys():
if idx in app.op.view_map[var_idx]: if idx in app.op.view_map[var_idx]:
# We need to recursivly check the destroy_map of all the # We need to recursivly check the destroy_map of all the
# outputs that we have a view_map on. # outputs that we have a view_map on.
...@@ -467,7 +467,7 @@ class DestroyHandler(Bookkeeper): # noqa ...@@ -467,7 +467,7 @@ class DestroyHandler(Bookkeeper): # noqa
- Allow sequence of view. - Allow sequence of view.
- But don't allow to destroy view - But don't allow to destroy view
""" """
dm = getattr(app.op, "destroy_map", None) dm = app.op.destroy_map
if not dm: if not dm:
return return
inputs = set( inputs = set(
...@@ -486,8 +486,8 @@ class DestroyHandler(Bookkeeper): # noqa ...@@ -486,8 +486,8 @@ class DestroyHandler(Bookkeeper): # noqa
elif inp.owner: elif inp.owner:
app2 = inp.owner app2 = inp.owner
inp_idx2 = app2.outputs.index(inp) inp_idx2 = app2.outputs.index(inp)
v = getattr(app2.op, "view_map", {}) v = app2.op.view_map
d = getattr(app2.op, "destroy_map", {}) d = app2.op.destroy_map
if v: if v:
v = v.get(inp_idx2, []) v = v.get(inp_idx2, [])
if len(v) > 0: if len(v) > 0:
...@@ -517,8 +517,8 @@ class DestroyHandler(Bookkeeper): # noqa ...@@ -517,8 +517,8 @@ class DestroyHandler(Bookkeeper): # noqa
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps) # print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
# If it's a destructive op, add it to our watch list # If it's a destructive op, add it to our watch list
dmap = getattr(app.op, "destroy_map", None) dmap = app.op.destroy_map
vmap = getattr(app.op, "view_map", {}) vmap = app.op.view_map
if dmap: if dmap:
self.destroyers.add(app) self.destroyers.add(app)
if self.algo == "fast": if self.algo == "fast":
...@@ -558,7 +558,7 @@ class DestroyHandler(Bookkeeper): # noqa ...@@ -558,7 +558,7 @@ class DestroyHandler(Bookkeeper): # noqa
for input in set(app.inputs): for input in set(app.inputs):
del self.clients[input][app] del self.clients[input][app]
if getattr(app.op, "destroy_map", OrderedDict()): if app.op.destroy_map:
self.destroyers.remove(app) self.destroyers.remove(app)
# Note: leaving empty client dictionaries in the struct. # Note: leaving empty client dictionaries in the struct.
...@@ -566,7 +566,7 @@ class DestroyHandler(Bookkeeper): # noqa ...@@ -566,7 +566,7 @@ class DestroyHandler(Bookkeeper): # noqa
# deleted on_detach(). # deleted on_detach().
# UPDATE self.view_i, self.view_o # UPDATE self.view_i, self.view_o
for o_idx, i_idx_list in getattr(app.op, "view_map", OrderedDict()).items(): for o_idx, i_idx_list in app.op.view_map.items():
if len(i_idx_list) > 1: if len(i_idx_list) > 1:
# destroying this output invalidates multiple inputs # destroying this output invalidates multiple inputs
raise NotImplementedError() raise NotImplementedError()
...@@ -605,7 +605,7 @@ class DestroyHandler(Bookkeeper): # noqa ...@@ -605,7 +605,7 @@ class DestroyHandler(Bookkeeper): # noqa
self.clients[new_r][app] += 1 self.clients[new_r][app] += 1
# UPDATE self.view_i, self.view_o # UPDATE self.view_i, self.view_o
for o_idx, i_idx_list in getattr(app.op, "view_map", OrderedDict()).items(): for o_idx, i_idx_list in app.op.view_map.items():
if len(i_idx_list) > 1: if len(i_idx_list) > 1:
# destroying this output invalidates multiple inputs # destroying this output invalidates multiple inputs
raise NotImplementedError() raise NotImplementedError()
......
...@@ -205,14 +205,14 @@ class FunctionGraph(MetaObject): ...@@ -205,14 +205,14 @@ class FunctionGraph(MetaObject):
node : aesara.graph.basic.Apply node : aesara.graph.basic.Apply
""" """
if hasattr(node.op, "view_map") and not all( if node.op.view_map and not all(
isinstance(view, (list, tuple)) for view in node.op.view_map.values() isinstance(view, (list, tuple)) for view in node.op.view_map.values()
): ):
raise Exception( raise Exception(
f"Op '{node.op}' have a bad view map '{node.op.view_map}'," f"Op '{node.op}' have a bad view map '{node.op.view_map}',"
" the values must be tuples or lists." " the values must be tuples or lists."
) )
if hasattr(node.op, "destroy_map") and not all( if node.op.destroy_map and not all(
isinstance(destroy, (list, tuple)) isinstance(destroy, (list, tuple))
for destroy in node.op.destroy_map.values() for destroy in node.op.destroy_map.values()
): ):
......
...@@ -107,7 +107,7 @@ def compute_test_value(node: Apply): ...@@ -107,7 +107,7 @@ def compute_test_value(node: Apply):
# The original values should not be destroyed, so we copy the values of the # The original values should not be destroyed, so we copy the values of the
# inputs in `destroy_map` # inputs in `destroy_map`
destroyed_inputs_idx = set() destroyed_inputs_idx = set()
if getattr(node.op, "destroy_map", None): if node.op.destroy_map:
for i_pos_list in node.op.destroy_map.values(): for i_pos_list in node.op.destroy_map.values():
destroyed_inputs_idx.update(i_pos_list) destroyed_inputs_idx.update(i_pos_list)
for inp_idx in destroyed_inputs_idx: for inp_idx in destroyed_inputs_idx:
...@@ -167,6 +167,29 @@ class Op(MetaObject): ...@@ -167,6 +167,29 @@ class Op(MetaObject):
""" """
view_map: Dict[int, List[int]] = {}
"""
A ``dict`` that maps output indices to the input indices of which they are
a view.
Examples
========
view_map = {0: [1]} # first output is a view of second input
view_map = {1: [0]} # second output is a view of first input
"""
destroy_map: Dict[int, List[int]] = {}
"""
A ``dict`` that maps output indices to the input indices upon which they
operate in-place.
Examples
========
destroy_map = {0: [1]} # first output operates in-place on second input
destroy_map = {1: [0]} # second output operates in-place on first input
"""
def make_node(self, *inputs: Variable) -> Apply: def make_node(self, *inputs: Variable) -> Apply:
"""Construct an `Apply` node that represent the application of this operation to the given inputs. """Construct an `Apply` node that represent the application of this operation to the given inputs.
......
...@@ -835,7 +835,7 @@ class MergeOptimizer(GlobalOptimizer): ...@@ -835,7 +835,7 @@ class MergeOptimizer(GlobalOptimizer):
[ [
i in flatten(c.op.destroy_map.values()) i in flatten(c.op.destroy_map.values())
for c, i in clients for c, i in clients
if c != "output" and hasattr(c.op, "destroy_map") if c != "output" and c.op.destroy_map
] ]
) )
> 1 > 1
......
...@@ -812,7 +812,7 @@ class NoOutputFromInplace(Feature): ...@@ -812,7 +812,7 @@ class NoOutputFromInplace(Feature):
node = out.owner node = out.owner
op = node.op op = node.op
out_idx = node.outputs.index(out) out_idx = node.outputs.index(out)
if hasattr(op, "destroy_map") and out_idx in op.destroy_map: if op.destroy_map and out_idx in op.destroy_map:
raise aesara.graph.fg.InconsistencyError( raise aesara.graph.fg.InconsistencyError(
"A function graph Feature has requested that outputs of the graph " "A function graph Feature has requested that outputs of the graph "
"be prevented from being the result of in-place " "be prevented from being the result of in-place "
......
...@@ -430,8 +430,8 @@ def raise_with_op( ...@@ -430,8 +430,8 @@ def raise_with_op(
total_size_inputs += sz total_size_inputs += sz
else: else:
# If it is a view, don't count it twice. # If it is a view, don't count it twice.
if getattr(k.owner.op, "view_map", None): vmap = k.owner.op.view_map
vmap = k.owner.op.view_map if vmap:
out_idx = k.owner.outputs.index(k) out_idx = k.owner.outputs.index(k)
data = storage_map[k][0] data = storage_map[k][0]
if out_idx in vmap: if out_idx in vmap:
...@@ -445,14 +445,14 @@ def raise_with_op( ...@@ -445,14 +445,14 @@ def raise_with_op(
# shouldn't be in the storage_map anymore # shouldn't be in the storage_map anymore
# except if there is a special flag used. So # except if there is a special flag used. So
# we still must check it. # we still must check it.
if getattr(k.owner.op, "destroy_map", None): dmap = k.owner.op.destroy_map
vmap = k.owner.op.destroy_map if dmap:
out_idx = k.owner.outputs.index(k) out_idx = k.owner.outputs.index(k)
data = storage_map[k][0] data = storage_map[k][0]
if out_idx in vmap: if out_idx in dmap:
assert len(vmap[out_idx]) == 1 assert len(dmap[out_idx]) == 1
input_data = storage_map[ input_data = storage_map[
k.owner.inputs[vmap[out_idx][0]] k.owner.inputs[dmap[out_idx][0]]
][0] ][0]
if k.type.may_share_memory(data, input_data): if k.type.may_share_memory(data, input_data):
total_size -= sz total_size -= sz
......
...@@ -36,8 +36,8 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend ...@@ -36,8 +36,8 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
for idx in range(len(order)): for idx in range(len(order)):
node = order[idx] node = order[idx]
dmap = getattr(node.op, "destroy_map", None) dmap = node.op.destroy_map
vmap = getattr(node.op, "view_map", None) vmap = node.op.view_map
idx_o = 0 idx_o = 0
for out in node.outputs: for out in node.outputs:
...@@ -574,9 +574,7 @@ class Stack(VM): ...@@ -574,9 +574,7 @@ class Stack(VM):
if ( if (
config.warn__vm_gc_bug config.warn__vm_gc_bug
and current_apply in apply_stack and current_apply in apply_stack
and getattr( and current_apply.op.destroy_map
current_apply.op, "destroy_map", False
)
): ):
warnings.warn( warnings.warn(
"There was a bug that existed in " "There was a bug that existed in "
......
...@@ -997,11 +997,11 @@ def pydotprint( ...@@ -997,11 +997,11 @@ def pydotprint(
param = {} param = {}
if label: if label:
param["label"] = label param["label"] = label
if hasattr(node.op, "view_map") and idx in reduce( if node.op.view_map and idx in reduce(
list.__add__, node.op.view_map.values(), [] list.__add__, node.op.view_map.values(), []
): ):
param["color"] = colorCodes["Output"] param["color"] = colorCodes["Output"]
elif hasattr(node.op, "destroy_map") and idx in reduce( elif node.op.destroy_map and idx in reduce(
list.__add__, node.op.destroy_map.values(), [] list.__add__, node.op.destroy_map.values(), []
): ):
param["color"] = "red" param["color"] = "red"
......
...@@ -794,8 +794,6 @@ class Scan(Op): ...@@ -794,8 +794,6 @@ class Scan(Op):
else: else:
name = "for" name = "for"
aux_txt = "%s" aux_txt = "%s"
if getattr(self, "destroy_map", None) is None:
self.destroy_map = OrderedDict()
if len(self.destroy_map.keys()) > 0: if len(self.destroy_map.keys()) > 0:
# Check if all outputs are inplace # Check if all outputs are inplace
if sorted(self.destroy_map.keys()) == sorted( if sorted(self.destroy_map.keys()) == sorted(
...@@ -1027,7 +1025,7 @@ class Scan(Op): ...@@ -1027,7 +1025,7 @@ class Scan(Op):
cython_inps_is_tensor = np.asarray(self.inps_is_tensor, dtype="int32") cython_inps_is_tensor = np.asarray(self.inps_is_tensor, dtype="int32")
cython_outs_is_tensor = np.asarray(self.outs_is_tensor, dtype="int32") cython_outs_is_tensor = np.asarray(self.outs_is_tensor, dtype="int32")
if hasattr(self, "destroy_map"): if self.destroy_map:
cython_destroy_map = [ cython_destroy_map = [
x in self.destroy_map for x in range(len(node.outputs)) x in self.destroy_map for x in range(len(node.outputs))
] ]
...@@ -1321,8 +1319,6 @@ class Scan(Op): ...@@ -1321,8 +1319,6 @@ class Scan(Op):
(-self.mintaps[idx]) % store_steps[idx] (-self.mintaps[idx]) % store_steps[idx]
for idx in range(self.n_outs + self.n_nit_sot) for idx in range(self.n_outs + self.n_nit_sot)
] ]
if not getattr(self, "destroy_map", None):
self.destroy_map = OrderedDict()
# 2.1 Create storage space for outputs # 2.1 Create storage space for outputs
for idx in range(self.n_outs): for idx in range(self.n_outs):
if idx in self.destroy_map: if idx in self.destroy_map:
......
...@@ -1119,7 +1119,7 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -1119,7 +1119,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
# Get the indices of this client's inputs on which it # Get the indices of this client's inputs on which it
# operates inplace # operates inplace
if hasattr(client.op, "destroy_map"): if client.op.destroy_map:
# This flattens the content of destroy_map.values() # This flattens the content of destroy_map.values()
# which is a list of lists # which is a list of lists
inplace_inp_indices = sum(client.op.destroy_map.values(), []) inplace_inp_indices = sum(client.op.destroy_map.values(), [])
......
...@@ -20,7 +20,7 @@ def test_no_output_from_implace(): ...@@ -20,7 +20,7 @@ def test_no_output_from_implace():
# using a mode that does not include the optimization # using a mode that does not include the optimization
fct_no_opt = aesara.function([x, y], b, mode="FAST_RUN") fct_no_opt = aesara.function([x, y], b, mode="FAST_RUN")
op = fct_no_opt.maker.fgraph.outputs[0].owner.op op = fct_no_opt.maker.fgraph.outputs[0].owner.op
assert hasattr(op, "destroy_map") and 0 in op.destroy_map assert op.destroy_map and 0 in op.destroy_map
# Ensure that the elemwise op that produces the output is not inplace when # Ensure that the elemwise op that produces the output is not inplace when
# using a mode that includes the optimization # using a mode that includes the optimization
...@@ -29,7 +29,7 @@ def test_no_output_from_implace(): ...@@ -29,7 +29,7 @@ def test_no_output_from_implace():
fct_opt = aesara.function([x, y], b, mode=mode_opt) fct_opt = aesara.function([x, y], b, mode=mode_opt)
op = fct_opt.maker.fgraph.outputs[0].owner.op op = fct_opt.maker.fgraph.outputs[0].owner.op
assert not hasattr(op, "destroy_map") or 0 not in op.destroy_map assert not op.destroy_map or 0 not in op.destroy_map
def test_including(): def test_including():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论