提交 6a0c00a7 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Virgile Andreani

Type two functions in graph/rewriting/utils.py

上级 56c30e03
......@@ -986,7 +986,7 @@ def explicit_graph_inputs(
def vars_between(
ins: Collection[Variable], outs: Iterable[Variable]
ins: Iterable[Variable], outs: Iterable[Variable]
) -> Generator[Variable, None, None]:
r"""Extract the `Variable`\s within the sub-graph between input and output nodes.
......@@ -1006,6 +1006,8 @@ def vars_between(
"""
ins = set(ins)
def expand(r: Variable) -> Iterable[Variable] | None:
if r.owner and r not in ins:
return reversed(r.owner.inputs + r.owner.outputs)
......
......@@ -72,7 +72,16 @@ def rewrite_graph(
return fgraph.outputs[0]
def is_same_graph_with_merge(var1, var2, givens=None):
def is_same_graph_with_merge(
var1: Variable,
var2: Variable,
givens: (
list[tuple[Variable, Variable]]
| tuple[tuple[Variable, Variable], ...]
| dict[Variable, Variable]
| None
) = None,
) -> bool:
"""
Merge-based implementation of `pytensor.graph.basic.is_same_graph`.
......@@ -83,8 +92,10 @@ def is_same_graph_with_merge(var1, var2, givens=None):
if givens is None:
givens = {}
givens = dict(givens)
# Copy variables since the MergeOptimizer will modify them.
copied = copy.deepcopy([var1, var2, givens])
copied = copy.deepcopy((var1, var2, givens))
vars = copied[0:2]
givens = copied[2]
# Create FunctionGraph.
......@@ -113,7 +124,16 @@ def is_same_graph_with_merge(var1, var2, givens=None):
return o1 is o2
def is_same_graph(var1, var2, givens=None):
def is_same_graph(
var1: Variable,
var2: Variable,
givens: (
list[tuple[Variable, Variable]]
| tuple[tuple[Variable, Variable], ...]
| dict[Variable, Variable]
| None
) = None,
) -> bool:
"""
Return True iff Variables `var1` and `var2` perform the same computation.
......@@ -155,8 +175,6 @@ def is_same_graph(var1, var2, givens=None):
if givens is None:
givens = {}
if not isinstance(givens, dict):
givens = dict(givens)
# Get result from the merge-based function.
......@@ -174,10 +192,11 @@ def is_same_graph(var1, var2, givens=None):
in_xs = []
in_ys = []
# Compute the sets of all variables found in each computational graph.
inputs_var = list(map(graph_inputs, ([var1], [var2])))
inputs_var1 = graph_inputs([var1])
inputs_var2 = graph_inputs([var2])
all_vars = [
set(vars_between(v_i, v_o))
for v_i, v_o in ((inputs_var[0], [var1]), (inputs_var[1], [var2]))
for v_i, v_o in ((inputs_var1, [var1]), (inputs_var2, [var2]))
]
def in_var(x, k):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论