提交 aa7e4d6b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Speedup FunctionGraph methods

上级 066307f0
......@@ -24,7 +24,6 @@ from pytensor.graph.traversal import (
vars_between,
)
from pytensor.graph.utils import MetaObject, MissingInputError, TestValueError
from pytensor.misc.ordered_set import OrderedSet
ClientType = tuple[Apply, int]
......@@ -133,7 +132,6 @@ class FunctionGraph(MetaObject):
features = []
self._features: list[Feature] = []
# All apply nodes in the subgraph defined by inputs and
# outputs are cached in this field
self.apply_nodes: set[Apply] = set()
......@@ -161,7 +159,8 @@ class FunctionGraph(MetaObject):
"input's owner or use graph.clone."
)
self.add_input(in_var, check=False)
self.inputs.append(in_var)
self.clients.setdefault(in_var, [])
for output in outputs:
self.add_output(output, reason="init")
......@@ -189,16 +188,6 @@ class FunctionGraph(MetaObject):
return
self.inputs.append(var)
self.setup_var(var)
def setup_var(self, var: Variable) -> None:
"""Set up a variable so it belongs to this `FunctionGraph`.
Parameters
----------
var : pytensor.graph.basic.Variable
"""
self.clients.setdefault(var, [])
def get_clients(self, var: Variable) -> list[ClientType]:
......@@ -322,10 +311,11 @@ class FunctionGraph(MetaObject):
"""
# Imports the owners of the variables
if var.owner and var.owner not in self.apply_nodes:
self.import_node(var.owner, reason=reason, import_missing=import_missing)
apply = var.owner
if apply is not None and apply not in self.apply_nodes:
self.import_node(apply, reason=reason, import_missing=import_missing)
elif (
var.owner is None
apply is None
and not isinstance(var, AtomicVariable)
and var not in self.inputs
):
......@@ -336,10 +326,11 @@ class FunctionGraph(MetaObject):
f"Computation graph contains a NaN. {var.type.why_null}"
)
if import_missing:
self.add_input(var)
self.inputs.append(var)
self.clients.setdefault(var, [])
else:
raise MissingInputError(f"Undeclared input: {var}", variable=var)
self.setup_var(var)
self.clients.setdefault(var, [])
self.variables.add(var)
def import_node(
......@@ -356,29 +347,29 @@ class FunctionGraph(MetaObject):
apply_node : Apply
The node to be imported.
check : bool
Check that the inputs for the imported nodes are also present in
the `FunctionGraph`.
Check that the inputs for the imported nodes are also present in the `FunctionGraph`.
reason : str
The name of the optimization or operation in progress.
import_missing : bool
Add missing inputs instead of raising an exception.
"""
# We import the nodes in topological order. We only are interested in
# new nodes, so we use all variables we know of as if they were the
# input set. (The functions in the graph module only use the input set
# to know where to stop going down.)
new_nodes = tuple(toposort(apply_node.outputs, blockers=self.variables))
# new nodes, so we use all nodes we know of as inputs to interrupt the toposort
self_variables = self.variables
self_clients = self.clients
self_apply_nodes = self.apply_nodes
self_inputs = self.inputs
for node in toposort(apply_node.outputs, blockers=self_variables):
if check:
for node in new_nodes:
for var in node.inputs:
if (
var.owner is None
and not isinstance(var, AtomicVariable)
and var not in self.inputs
and var not in self_inputs
):
if import_missing:
self.add_input(var)
self_inputs.append(var)
self_clients.setdefault(var, [])
else:
error_msg = (
f"Input {node.inputs.index(var)} ({var})"
......@@ -390,20 +381,20 @@ class FunctionGraph(MetaObject):
)
raise MissingInputError(error_msg, variable=var)
for node in new_nodes:
assert node not in self.apply_nodes
self.apply_nodes.add(node)
if not hasattr(node.tag, "imported_by"):
node.tag.imported_by = []
node.tag.imported_by.append(str(reason))
self_apply_nodes.add(node)
tag = node.tag
if not hasattr(tag, "imported_by"):
tag.imported_by = [str(reason)]
else:
tag.imported_by.append(str(reason))
for output in node.outputs:
self.setup_var(output)
self.variables.add(output)
for i, input in enumerate(node.inputs):
if input not in self.variables:
self.setup_var(input)
self.variables.add(input)
self.add_client(input, (node, i))
self_clients.setdefault(output, [])
self_variables.add(output)
for i, inp in enumerate(node.inputs):
if inp not in self_variables:
self_clients.setdefault(inp, [])
self_variables.add(inp)
self_clients[inp].append((node, i))
self.execute_callbacks("on_import", node, reason)
def change_node_input(
......@@ -457,7 +448,7 @@ class FunctionGraph(MetaObject):
self.outputs[node.op.idx] = new_var
self.import_var(new_var, reason=reason, import_missing=import_missing)
self.add_client(new_var, (node, i))
self.clients[new_var].append((node, i))
self.remove_client(r, (node, i), reason=reason)
# Precondition: the substitution is semantically valid However it may
# introduce cycles to the graph, in which case the transaction will be
......@@ -756,10 +747,6 @@ class FunctionGraph(MetaObject):
:meth:`FunctionGraph.orderings`.
"""
if len(self.apply_nodes) < 2:
# No sorting is necessary
return list(self.apply_nodes)
return list(toposort_with_orderings(self.outputs, orderings=self.orderings()))
def orderings(self) -> dict[Apply, list[Apply]]:
......@@ -779,29 +766,17 @@ class FunctionGraph(MetaObject):
take care of computing the dependencies by itself.
"""
assert isinstance(self._features, list)
all_orderings: list[dict] = []
for feature in self._features:
if hasattr(feature, "orderings"):
orderings = feature.orderings(self)
if not isinstance(orderings, dict):
raise TypeError(
"Non-deterministic return value from "
+ str(feature.orderings)
+ ". Nondeterministic object is "
+ str(orderings)
)
if len(orderings) > 0:
all_orderings.append(orderings)
for node, prereqs in orderings.items():
if not isinstance(prereqs, list | OrderedSet):
raise TypeError(
"prereqs must be a type with a "
"deterministic iteration order, or toposort "
" will be non-deterministic."
all_orderings: list[dict] = [
orderings
for feature in self._features
if (
hasattr(feature, "orderings") and (orderings := feature.orderings(self))
)
if len(all_orderings) == 1:
]
if not all_orderings:
return {}
elif len(all_orderings) == 1:
# If there is only 1 ordering, we reuse it directly.
return all_orderings[0].copy()
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论