提交 b0a1d33c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add view_map and destroy_map variables and docstrings to Op

上级 cae78759
......@@ -166,7 +166,7 @@ class BadDestroyMap(DebugModeError):
print(" node:", self.node, file=sio)
print(" perform:", self.perform, file=sio)
print(" node.inputs:", [(str(i), id(i)) for i in self.node.inputs], file=sio)
print(" destroy_map:", getattr(self.node.op, "destroy_map", {}), file=sio)
print(" destroy_map:", self.node.op.destroy_map, file=sio)
print(" changed input idx:", self.idx, file=sio)
print(" changed input type:", self.node.inputs[self.idx].type, file=sio)
print(" repr (old val):", repr(self.old_val), file=sio)
......@@ -250,8 +250,8 @@ class BadViewMap(DebugModeError):
print(" node:", self.node, file=sio)
print(" node.inputs:", [(str(i), id(i)) for i in self.node.inputs], file=sio)
print(" node.outputs:", [(str(i), id(i)) for i in self.node.outputs], file=sio)
print(" view_map:", getattr(self.node.op, "view_map", {}), file=sio)
print(" destroy_map:", getattr(self.node.op, "destroy_map", {}), file=sio)
print(" view_map:", self.node.op.view_map, file=sio)
print(" destroy_map:", self.node.op.destroy_map, file=sio)
print(" aliased output:", self.output_idx, file=sio)
print(" aliased output storage:", self.out_storage, file=sio)
if self.in_alias_idx:
......@@ -554,12 +554,12 @@ def debugprint(
r_name = ""
if print_destroy_map:
destroy_map_str = str(getattr(r.owner.op, "destroy_map", ""))
destroy_map_str = str(r.owner.op.destroy_map)
else:
destroy_map_str = ""
if print_view_map:
view_map_str = str(getattr(r.owner.op, "view_map", ""))
view_map_str = str(r.owner.op.view_map)
else:
view_map_str = ""
if destroy_map_str and destroy_map_str != "{}":
......@@ -742,13 +742,13 @@ def _check_inputs(
"""
destroyed_idx_list = []
destroy_map = getattr(node.op, "destroy_map", {})
destroy_map = node.op.destroy_map
for o_pos, i_pos_list in destroy_map.items():
destroyed_idx_list.extend(i_pos_list)
destroyed_res_list = [node.inputs[i] for i in destroyed_idx_list]
actually_inplace_outputs = []
dmap = getattr(node.op, "destroy_map", {})
dmap = node.op.destroy_map
for oo, ii in dmap.items():
var = node.outputs[oo]
out_var = storage_map[var][0]
......@@ -769,7 +769,7 @@ def _check_inputs(
f"as destroyed was not changed for node '{node}'"
)
vmap = getattr(node.op, "view_map", {})
vmap = node.op.view_map
for oo, ii in vmap.items():
var = node.outputs[oo]
out_var = storage_map[var][0]
......@@ -836,8 +836,8 @@ def _check_viewmap(fgraph, node, storage_map):
outstorage = storage_map[onode][0]
# first find out which input it aliases
view_map = getattr(node.op, "view_map", {})
destroy_map = getattr(node.op, "destroy_map", {})
view_map = node.op.view_map
destroy_map = node.op.destroy_map
# In theory, aesara's view_map only allows for 1 output to
# alias 1 input. Checking for multiple aliases just in
......@@ -1395,8 +1395,8 @@ def _check_preallocated_output(
# Set of inputs that are marked as destroyed or viewed
aliased_inputs = set()
dmap = getattr(node.op, "destroy_map", {})
vmap = getattr(node.op, "view_map", {})
dmap = node.op.destroy_map
vmap = node.op.view_map
for i, r in enumerate(node.inputs):
if any(i in v for v in chain(dmap.values(), vmap.values())):
aliased_inputs.add(r)
......@@ -2082,8 +2082,8 @@ class _Linker(LocalLinker):
clobber = True
if thunk_py:
dmap = getattr(node.op, "destroy_map", {})
vmap = getattr(node.op, "view_map", {})
dmap = node.op.destroy_map
vmap = node.op.view_map
for i, r in enumerate(node.inputs):
# if thunk_py ran, and we still got
# this far, it means that the
......
......@@ -57,8 +57,8 @@ def alias_root(v):
"""
if v.owner is None:
return v
vmap = getattr(v.owner.op, "view_map", {})
dmap = getattr(v.owner.op, "destroy_map", {})
vmap = v.owner.op.view_map
dmap = v.owner.op.destroy_map
outpos = v.owner.outputs.index(v)
v_views = vmap.get(outpos, []) + dmap.get(outpos, [])
if len(v_views) > 1:
......@@ -83,8 +83,8 @@ def view_tree_set(fgraph, v, treeset):
for cl, v_input_pos_to_cl in fgraph.clients[v]:
if cl == "output":
continue
vmap = getattr(cl.op, "view_map", {})
dmap = getattr(cl.op, "destroy_map", {})
vmap = cl.op.view_map
dmap = cl.op.destroy_map
for opos, iposlist in chain(vmap.items(), dmap.items()):
if v_input_pos_to_cl in iposlist:
if cl.outputs[opos] not in treeset:
......@@ -189,7 +189,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace=False):
fgraph = FunctionGraph(orig_inputs, orig_outputs, update_mapping=update_mapping)
for node in fgraph.apply_nodes:
if getattr(node.op, "destroy_map", None):
if node.op.destroy_map:
if not accept_inplace:
raise TypeError(
"Graph must not contain inplace operations", node, node.op
......
......@@ -962,8 +962,8 @@ class ProfileStats:
if ignore_dmap:
dmap = None
else:
dmap = getattr(node.op, "destroy_map", None)
vmap = getattr(node.op, "view_map", None)
dmap = node.op.destroy_map
vmap = node.op.view_map
val = nodes_mem[node]
for v in val:
......@@ -1125,8 +1125,8 @@ class ProfileStats:
mem_freed = 0
max_storage = max_mem_count
dmap = getattr(node.op, "destroy_map", None)
vmap = getattr(node.op, "view_map", None)
dmap = node.op.destroy_map
vmap = node.op.view_map
idx = 0
# Update the Python emulating dicts and add the
......@@ -1426,9 +1426,9 @@ class ProfileStats:
items.sort(key=lambda a: a[1], reverse=True)
for idx, ((fgraph, node), node_outputs_size) in enumerate(items[:N]):
code = ["c"] * len(node.outputs)
for out, inp in getattr(node.op, "destroy_map", {}).items():
for out, inp in node.op.destroy_map.items():
code[out] = "i"
for out, inp in getattr(node.op, "view_map", {}).items():
for out, inp in node.op.view_map.items():
code[out] = "v"
shapes = str(fct_shapes[fgraph][node])
......
......@@ -186,11 +186,11 @@ class PyDotFormatter:
graph.add_node(pd_var)
edge_params = {}
if hasattr(node.op, "view_map") and id in reduce(
if node.op.view_map and id in reduce(
list.__add__, node.op.view_map.values(), []
):
edge_params["color"] = self.node_colors["output"]
elif hasattr(node.op, "destroy_map") and id in reduce(
elif node.op.destroy_map and id in reduce(
list.__add__, node.op.destroy_map.values(), []
):
edge_params["color"] = "red"
......
......@@ -413,11 +413,11 @@ class DestroyHandler(Bookkeeper): # noqa
for (app, idx) in fgraph.clients[protected_var]:
if app == "output":
continue
destroy_maps = getattr(app.op, "destroy_map", {}).values()
destroy_maps = app.op.destroy_map.values()
# If True means that the apply node, destroys the protected_var.
if idx in [dmap for sublist in destroy_maps for dmap in sublist]:
return True
for var_idx in getattr(app.op, "view_map", {}).keys():
for var_idx in app.op.view_map.keys():
if idx in app.op.view_map[var_idx]:
# We need to recursivly check the destroy_map of all the
# outputs that we have a view_map on.
......@@ -467,7 +467,7 @@ class DestroyHandler(Bookkeeper): # noqa
- Allow sequence of view.
- But don't allow to destroy view
"""
dm = getattr(app.op, "destroy_map", None)
dm = app.op.destroy_map
if not dm:
return
inputs = set(
......@@ -486,8 +486,8 @@ class DestroyHandler(Bookkeeper): # noqa
elif inp.owner:
app2 = inp.owner
inp_idx2 = app2.outputs.index(inp)
v = getattr(app2.op, "view_map", {})
d = getattr(app2.op, "destroy_map", {})
v = app2.op.view_map
d = app2.op.destroy_map
if v:
v = v.get(inp_idx2, [])
if len(v) > 0:
......@@ -517,8 +517,8 @@ class DestroyHandler(Bookkeeper): # noqa
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
# If it's a destructive op, add it to our watch list
dmap = getattr(app.op, "destroy_map", None)
vmap = getattr(app.op, "view_map", {})
dmap = app.op.destroy_map
vmap = app.op.view_map
if dmap:
self.destroyers.add(app)
if self.algo == "fast":
......@@ -558,7 +558,7 @@ class DestroyHandler(Bookkeeper): # noqa
for input in set(app.inputs):
del self.clients[input][app]
if getattr(app.op, "destroy_map", OrderedDict()):
if app.op.destroy_map:
self.destroyers.remove(app)
# Note: leaving empty client dictionaries in the struct.
......@@ -566,7 +566,7 @@ class DestroyHandler(Bookkeeper): # noqa
# deleted on_detach().
# UPDATE self.view_i, self.view_o
for o_idx, i_idx_list in getattr(app.op, "view_map", OrderedDict()).items():
for o_idx, i_idx_list in app.op.view_map.items():
if len(i_idx_list) > 1:
# destroying this output invalidates multiple inputs
raise NotImplementedError()
......@@ -605,7 +605,7 @@ class DestroyHandler(Bookkeeper): # noqa
self.clients[new_r][app] += 1
# UPDATE self.view_i, self.view_o
for o_idx, i_idx_list in getattr(app.op, "view_map", OrderedDict()).items():
for o_idx, i_idx_list in app.op.view_map.items():
if len(i_idx_list) > 1:
# destroying this output invalidates multiple inputs
raise NotImplementedError()
......
......@@ -205,14 +205,14 @@ class FunctionGraph(MetaObject):
node : aesara.graph.basic.Apply
"""
if hasattr(node.op, "view_map") and not all(
if node.op.view_map and not all(
isinstance(view, (list, tuple)) for view in node.op.view_map.values()
):
raise Exception(
f"Op '{node.op}' have a bad view map '{node.op.view_map}',"
" the values must be tuples or lists."
)
if hasattr(node.op, "destroy_map") and not all(
if node.op.destroy_map and not all(
isinstance(destroy, (list, tuple))
for destroy in node.op.destroy_map.values()
):
......
......@@ -107,7 +107,7 @@ def compute_test_value(node: Apply):
# The original values should not be destroyed, so we copy the values of the
# inputs in `destroy_map`
destroyed_inputs_idx = set()
if getattr(node.op, "destroy_map", None):
if node.op.destroy_map:
for i_pos_list in node.op.destroy_map.values():
destroyed_inputs_idx.update(i_pos_list)
for inp_idx in destroyed_inputs_idx:
......@@ -167,6 +167,29 @@ class Op(MetaObject):
"""
view_map: Dict[int, List[int]] = {}
"""
A ``dict`` that maps output indices to the input indices of which they are
a view.
Examples
========
view_map = {0: [1]} # first output is a view of second input
view_map = {1: [0]} # second output is a view of first input
"""
destroy_map: Dict[int, List[int]] = {}
"""
A ``dict`` that maps output indices to the input indices upon which they
operate in-place.
Examples
========
destroy_map = {0: [1]} # first output operates in-place on second input
destroy_map = {1: [0]} # second output operates in-place on first input
"""
def make_node(self, *inputs: Variable) -> Apply:
"""Construct an `Apply` node that represent the application of this operation to the given inputs.
......
......@@ -835,7 +835,7 @@ class MergeOptimizer(GlobalOptimizer):
[
i in flatten(c.op.destroy_map.values())
for c, i in clients
if c != "output" and hasattr(c.op, "destroy_map")
if c != "output" and c.op.destroy_map
]
)
> 1
......
......@@ -812,7 +812,7 @@ class NoOutputFromInplace(Feature):
node = out.owner
op = node.op
out_idx = node.outputs.index(out)
if hasattr(op, "destroy_map") and out_idx in op.destroy_map:
if op.destroy_map and out_idx in op.destroy_map:
raise aesara.graph.fg.InconsistencyError(
"A function graph Feature has requested that outputs of the graph "
"be prevented from being the result of in-place "
......
......@@ -430,8 +430,8 @@ def raise_with_op(
total_size_inputs += sz
else:
# If it is a view, don't count it twice.
if getattr(k.owner.op, "view_map", None):
vmap = k.owner.op.view_map
vmap = k.owner.op.view_map
if vmap:
out_idx = k.owner.outputs.index(k)
data = storage_map[k][0]
if out_idx in vmap:
......@@ -445,14 +445,14 @@ def raise_with_op(
# shouldn't be in the storage_map anymore
# except if there is a special flag used. So
# we still must check it.
if getattr(k.owner.op, "destroy_map", None):
vmap = k.owner.op.destroy_map
dmap = k.owner.op.destroy_map
if dmap:
out_idx = k.owner.outputs.index(k)
data = storage_map[k][0]
if out_idx in vmap:
assert len(vmap[out_idx]) == 1
if out_idx in dmap:
assert len(dmap[out_idx]) == 1
input_data = storage_map[
k.owner.inputs[vmap[out_idx][0]]
k.owner.inputs[dmap[out_idx][0]]
][0]
if k.type.may_share_memory(data, input_data):
total_size -= sz
......
......@@ -36,8 +36,8 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
for idx in range(len(order)):
node = order[idx]
dmap = getattr(node.op, "destroy_map", None)
vmap = getattr(node.op, "view_map", None)
dmap = node.op.destroy_map
vmap = node.op.view_map
idx_o = 0
for out in node.outputs:
......@@ -574,9 +574,7 @@ class Stack(VM):
if (
config.warn__vm_gc_bug
and current_apply in apply_stack
and getattr(
current_apply.op, "destroy_map", False
)
and current_apply.op.destroy_map
):
warnings.warn(
"There was a bug that existed in "
......
......@@ -997,11 +997,11 @@ def pydotprint(
param = {}
if label:
param["label"] = label
if hasattr(node.op, "view_map") and idx in reduce(
if node.op.view_map and idx in reduce(
list.__add__, node.op.view_map.values(), []
):
param["color"] = colorCodes["Output"]
elif hasattr(node.op, "destroy_map") and idx in reduce(
elif node.op.destroy_map and idx in reduce(
list.__add__, node.op.destroy_map.values(), []
):
param["color"] = "red"
......
......@@ -794,8 +794,6 @@ class Scan(Op):
else:
name = "for"
aux_txt = "%s"
if getattr(self, "destroy_map", None) is None:
self.destroy_map = OrderedDict()
if len(self.destroy_map.keys()) > 0:
# Check if all outputs are inplace
if sorted(self.destroy_map.keys()) == sorted(
......@@ -1027,7 +1025,7 @@ class Scan(Op):
cython_inps_is_tensor = np.asarray(self.inps_is_tensor, dtype="int32")
cython_outs_is_tensor = np.asarray(self.outs_is_tensor, dtype="int32")
if hasattr(self, "destroy_map"):
if self.destroy_map:
cython_destroy_map = [
x in self.destroy_map for x in range(len(node.outputs))
]
......@@ -1321,8 +1319,6 @@ class Scan(Op):
(-self.mintaps[idx]) % store_steps[idx]
for idx in range(self.n_outs + self.n_nit_sot)
]
if not getattr(self, "destroy_map", None):
self.destroy_map = OrderedDict()
# 2.1 Create storage space for outputs
for idx in range(self.n_outs):
if idx in self.destroy_map:
......
......@@ -1119,7 +1119,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
# Get the indices of this client's inputs on which it
# operates inplace
if hasattr(client.op, "destroy_map"):
if client.op.destroy_map:
# This flattens the content of destroy_map.values()
# which is a list of lists
inplace_inp_indices = sum(client.op.destroy_map.values(), [])
......
......@@ -20,7 +20,7 @@ def test_no_output_from_implace():
# using a mode that does not include the optimization
fct_no_opt = aesara.function([x, y], b, mode="FAST_RUN")
op = fct_no_opt.maker.fgraph.outputs[0].owner.op
assert hasattr(op, "destroy_map") and 0 in op.destroy_map
assert op.destroy_map and 0 in op.destroy_map
# Ensure that the elemwise op that produces the output is not inplace when
# using a mode that includes the optimization
......@@ -29,7 +29,7 @@ def test_no_output_from_implace():
fct_opt = aesara.function([x, y], b, mode=mode_opt)
op = fct_opt.maker.fgraph.outputs[0].owner.op
assert not hasattr(op, "destroy_map") or 0 not in op.destroy_map
assert not op.destroy_map or 0 not in op.destroy_map
def test_including():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论