提交 857cda05 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Virgile Andreani

Refactor graph/rewriting/utils.py

上级 0484e1e2
......@@ -44,32 +44,23 @@ def rewrite_graph(
"""
from pytensor.compile import optdb
return_fgraph = False
if isinstance(graph, FunctionGraph):
fgraph = graph
return_fgraph = True
else:
if isinstance(graph, list | tuple):
outputs = graph
else:
assert isinstance(graph, Variable)
outputs = [graph]
outputs = [graph] if isinstance(graph, Variable) else graph
fgraph = FunctionGraph(outputs=outputs, clone=clone)
query_rewrites = optdb.query(RewriteDatabaseQuery(include=include, **kwargs))
_ = query_rewrites.rewrite(fgraph)
query_rewrites.rewrite(fgraph)
if custom_rewrite:
if custom_rewrite is not None:
custom_rewrite.rewrite(fgraph)
if return_fgraph:
if isinstance(graph, FunctionGraph):
return fgraph
else:
if isinstance(graph, list | tuple):
return fgraph.outputs
else:
if isinstance(graph, Variable):
return fgraph.outputs[0]
return fgraph.outputs
def is_same_graph_with_merge(
......@@ -90,14 +81,10 @@ def is_same_graph_with_merge(
"""
from pytensor.graph.rewriting.basic import MergeOptimizer
if givens is None:
givens = {}
givens = dict(givens)
givens = {} if givens is None else dict(givens)
# Copy variables since the MergeOptimizer will modify them.
copied = copy.deepcopy((var1, var2, givens))
vars = copied[0:2]
givens = copied[2]
*vars, givens = copy.deepcopy((var1, var2, givens))
# Create FunctionGraph.
inputs = list(graph_inputs(vars))
# The clone isn't needed as we did a deepcopy and we cloning will
......@@ -120,7 +107,6 @@ def is_same_graph_with_merge(
# Comparing two single-Variable graphs: they are equal if they are
# the same Variable.
return vars_replaced[0] == vars_replaced[1]
else:
return o1 is o2
......@@ -171,16 +157,16 @@ def is_same_graph(
====== ====== ====== ======
"""
use_equal_computations = True
if givens is None:
givens = {}
givens = dict(givens)
givens = {} if givens is None else dict(givens)
# Get result from the merge-based function.
rval1 = is_same_graph_with_merge(var1=var1, var2=var2, givens=givens)
if givens:
if not givens:
rval2 = equal_computations(xs=[var1], ys=[var2])
assert rval1 == rval2
return rval1
# We need to build the `in_xs` and `in_ys` lists. To do this, we need
# to be able to tell whether a variable belongs to the computational
# graph of `var1` or `var2`.
......@@ -188,27 +174,19 @@ def is_same_graph(
# one of these graphs, and `replace_by` belongs to the other one. In
# other situations, the current implementation of `equal_computations`
# is probably not appropriate, so we do not call it.
ok = True
use_equal_computations = True
in_xs = []
in_ys = []
# Compute the sets of all variables found in each computational graph.
inputs_var1 = graph_inputs([var1])
inputs_var2 = graph_inputs([var2])
all_vars = [
set(vars_between(v_i, v_o))
for v_i, v_o in ((inputs_var1, [var1]), (inputs_var2, [var2]))
]
def in_var(x, k):
# Return True iff `x` is in computation graph of variable `vark`.
return x in all_vars[k - 1]
all_vars1 = set(vars_between(inputs_var1, [var1]))
all_vars2 = set(vars_between(inputs_var2, [var2]))
for to_replace, replace_by in givens.items():
# Map a substitution variable to the computational graphs it
# belongs to.
inside = {
v: [in_var(v, k) for k in (1, 2)] for v in (to_replace, replace_by)
}
inside = {v: [v in all_vars1, v in all_vars2] for v in (to_replace, replace_by)}
if (
inside[to_replace][0]
and not inside[to_replace][1]
......@@ -228,14 +206,9 @@ def is_same_graph(
in_xs.append(replace_by)
in_ys.append(to_replace)
else:
ok = False
break
if not ok:
# We cannot directly use `equal_computations`.
use_equal_computations = False
else:
in_xs = None
in_ys = None
break
if use_equal_computations:
rval2 = equal_computations(xs=[var1], ys=[var2], in_xs=in_xs, in_ys=in_ys)
assert rval2 == rval1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论