提交 be0e13a9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Reduce number of access to self.clients in FunctionGraph

To speedup hot rewrite loops
上级 36c55f5b
......@@ -232,13 +232,13 @@ class FunctionGraph(MetaObject):
entry for `var` in `self.clients`.
"""
clients = self.clients
removal_stack = [(var, client_to_remove)]
while removal_stack:
var, client_to_remove = removal_stack.pop()
try:
var_clients = self.clients[var]
var_clients = clients[var]
var_clients.remove(client_to_remove)
except ValueError:
# In this case, the original `var` could've been removed from
......@@ -256,9 +256,7 @@ class FunctionGraph(MetaObject):
self.variables.remove(var)
else:
apply_node = var.owner
if not any(
output for output in apply_node.outputs if self.clients[output]
):
if not any(clients[output] for output in apply_node.outputs):
# The `Apply` node is not used and is not an output, so we
# remove it and its outputs
if not hasattr(apply_node.tag, "removed_by"):
......@@ -276,7 +274,7 @@ class FunctionGraph(MetaObject):
removal_stack.append((in_var, (apply_node, i)))
if remove_if_empty:
del self.clients[var]
del clients[var]
def import_var(
self, var: Variable, reason: str | None = None, import_missing: bool = False
......@@ -563,10 +561,11 @@ class FunctionGraph(MetaObject):
node.tag.removed_by.append(str(reason))
# Remove the outputs of the node (i.e. everything "below" it)
clients = self.clients
for out in node.outputs:
self.variables.remove(out)
out_clients = self.clients.get(out, ())
out_clients = clients.get(out, ())
while out_clients:
out_client, out_idx = out_clients.pop()
......@@ -590,13 +589,12 @@ class FunctionGraph(MetaObject):
assert isinstance(out_client, Apply)
self.remove_node(out_client, reason=reason)
if out in self.clients:
del self.clients[out]
clients.pop(out, None)
# Remove all the arrows pointing to this `node`, and any orphaned
# variables created by removing those arrows
for inp_idx, inp in enumerate(node.inputs):
inp_clients: list[ClientType] = self.clients.get(inp, [])
inp_clients: list[ClientType] = clients.get(inp, [])
arrow = (node, inp_idx)
......@@ -810,12 +808,13 @@ class FunctionGraph(MetaObject):
raise Exception(
f"The following nodes are inappropriately cached:\nmissing: {nodes_missing}\nin excess: {nodes_excess}"
)
clients = self.clients
for node in nodes:
for i, variable in enumerate(node.inputs):
clients = self.clients[variable]
if (node, i) not in clients:
var_clients = clients[variable]
if (node, i) not in var_clients:
raise Exception(
f"Inconsistent clients list {(node, i)} in {clients}"
f"Inconsistent clients list {(node, i)} in {var_clients}"
)
variables = set(vars_between(self.inputs, self.outputs))
if set(self.variables) != variables:
......@@ -831,7 +830,7 @@ class FunctionGraph(MetaObject):
and not isinstance(variable, AtomicVariable)
):
raise Exception(f"Undeclared input: {variable}")
for cl_node, i in self.clients[variable]:
for cl_node, i in clients[variable]:
if cl_node == "output":
if self.outputs[i] is not variable:
raise Exception(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论