提交 be0e13a9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Reduce number of access to self.clients in FunctionGraph

To speedup hot rewrite loops
上级 36c55f5b
...@@ -232,13 +232,13 @@ class FunctionGraph(MetaObject): ...@@ -232,13 +232,13 @@ class FunctionGraph(MetaObject):
entry for `var` in `self.clients`. entry for `var` in `self.clients`.
""" """
clients = self.clients
removal_stack = [(var, client_to_remove)] removal_stack = [(var, client_to_remove)]
while removal_stack: while removal_stack:
var, client_to_remove = removal_stack.pop() var, client_to_remove = removal_stack.pop()
try: try:
var_clients = self.clients[var] var_clients = clients[var]
var_clients.remove(client_to_remove) var_clients.remove(client_to_remove)
except ValueError: except ValueError:
# In this case, the original `var` could've been removed from # In this case, the original `var` could've been removed from
...@@ -256,9 +256,7 @@ class FunctionGraph(MetaObject): ...@@ -256,9 +256,7 @@ class FunctionGraph(MetaObject):
self.variables.remove(var) self.variables.remove(var)
else: else:
apply_node = var.owner apply_node = var.owner
if not any( if not any(clients[output] for output in apply_node.outputs):
output for output in apply_node.outputs if self.clients[output]
):
# The `Apply` node is not used and is not an output, so we # The `Apply` node is not used and is not an output, so we
# remove it and its outputs # remove it and its outputs
if not hasattr(apply_node.tag, "removed_by"): if not hasattr(apply_node.tag, "removed_by"):
...@@ -276,7 +274,7 @@ class FunctionGraph(MetaObject): ...@@ -276,7 +274,7 @@ class FunctionGraph(MetaObject):
removal_stack.append((in_var, (apply_node, i))) removal_stack.append((in_var, (apply_node, i)))
if remove_if_empty: if remove_if_empty:
del self.clients[var] del clients[var]
def import_var( def import_var(
self, var: Variable, reason: str | None = None, import_missing: bool = False self, var: Variable, reason: str | None = None, import_missing: bool = False
...@@ -563,10 +561,11 @@ class FunctionGraph(MetaObject): ...@@ -563,10 +561,11 @@ class FunctionGraph(MetaObject):
node.tag.removed_by.append(str(reason)) node.tag.removed_by.append(str(reason))
# Remove the outputs of the node (i.e. everything "below" it) # Remove the outputs of the node (i.e. everything "below" it)
clients = self.clients
for out in node.outputs: for out in node.outputs:
self.variables.remove(out) self.variables.remove(out)
out_clients = self.clients.get(out, ()) out_clients = clients.get(out, ())
while out_clients: while out_clients:
out_client, out_idx = out_clients.pop() out_client, out_idx = out_clients.pop()
...@@ -590,13 +589,12 @@ class FunctionGraph(MetaObject): ...@@ -590,13 +589,12 @@ class FunctionGraph(MetaObject):
assert isinstance(out_client, Apply) assert isinstance(out_client, Apply)
self.remove_node(out_client, reason=reason) self.remove_node(out_client, reason=reason)
if out in self.clients: clients.pop(out, None)
del self.clients[out]
# Remove all the arrows pointing to this `node`, and any orphaned # Remove all the arrows pointing to this `node`, and any orphaned
# variables created by removing those arrows # variables created by removing those arrows
for inp_idx, inp in enumerate(node.inputs): for inp_idx, inp in enumerate(node.inputs):
inp_clients: list[ClientType] = self.clients.get(inp, []) inp_clients: list[ClientType] = clients.get(inp, [])
arrow = (node, inp_idx) arrow = (node, inp_idx)
...@@ -810,12 +808,13 @@ class FunctionGraph(MetaObject): ...@@ -810,12 +808,13 @@ class FunctionGraph(MetaObject):
raise Exception( raise Exception(
f"The following nodes are inappropriately cached:\nmissing: {nodes_missing}\nin excess: {nodes_excess}" f"The following nodes are inappropriately cached:\nmissing: {nodes_missing}\nin excess: {nodes_excess}"
) )
clients = self.clients
for node in nodes: for node in nodes:
for i, variable in enumerate(node.inputs): for i, variable in enumerate(node.inputs):
clients = self.clients[variable] var_clients = clients[variable]
if (node, i) not in clients: if (node, i) not in var_clients:
raise Exception( raise Exception(
f"Inconsistent clients list {(node, i)} in {clients}" f"Inconsistent clients list {(node, i)} in {var_clients}"
) )
variables = set(vars_between(self.inputs, self.outputs)) variables = set(vars_between(self.inputs, self.outputs))
if set(self.variables) != variables: if set(self.variables) != variables:
...@@ -831,7 +830,7 @@ class FunctionGraph(MetaObject): ...@@ -831,7 +830,7 @@ class FunctionGraph(MetaObject):
and not isinstance(variable, AtomicVariable) and not isinstance(variable, AtomicVariable)
): ):
raise Exception(f"Undeclared input: {variable}") raise Exception(f"Undeclared input: {variable}")
for cl_node, i in self.clients[variable]: for cl_node, i in clients[variable]:
if cl_node == "output": if cl_node == "output":
if self.outputs[i] is not variable: if self.outputs[i] is not variable:
raise Exception( raise Exception(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论