提交 857cda05 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Virgile Andreani

Refactor graph/rewriting/utils.py

上级 0484e1e2
...@@ -44,32 +44,23 @@ def rewrite_graph( ...@@ -44,32 +44,23 @@ def rewrite_graph(
""" """
from pytensor.compile import optdb from pytensor.compile import optdb
return_fgraph = False
if isinstance(graph, FunctionGraph): if isinstance(graph, FunctionGraph):
fgraph = graph fgraph = graph
return_fgraph = True
else: else:
if isinstance(graph, list | tuple): outputs = [graph] if isinstance(graph, Variable) else graph
outputs = graph
else:
assert isinstance(graph, Variable)
outputs = [graph]
fgraph = FunctionGraph(outputs=outputs, clone=clone) fgraph = FunctionGraph(outputs=outputs, clone=clone)
query_rewrites = optdb.query(RewriteDatabaseQuery(include=include, **kwargs)) query_rewrites = optdb.query(RewriteDatabaseQuery(include=include, **kwargs))
_ = query_rewrites.rewrite(fgraph) query_rewrites.rewrite(fgraph)
if custom_rewrite: if custom_rewrite is not None:
custom_rewrite.rewrite(fgraph) custom_rewrite.rewrite(fgraph)
if return_fgraph: if isinstance(graph, FunctionGraph):
return fgraph return fgraph
else: if isinstance(graph, Variable):
if isinstance(graph, list | tuple):
return fgraph.outputs
else:
return fgraph.outputs[0] return fgraph.outputs[0]
return fgraph.outputs
def is_same_graph_with_merge( def is_same_graph_with_merge(
...@@ -90,14 +81,10 @@ def is_same_graph_with_merge( ...@@ -90,14 +81,10 @@ def is_same_graph_with_merge(
""" """
from pytensor.graph.rewriting.basic import MergeOptimizer from pytensor.graph.rewriting.basic import MergeOptimizer
if givens is None: givens = {} if givens is None else dict(givens)
givens = {}
givens = dict(givens)
# Copy variables since the MergeOptimizer will modify them. # Copy variables since the MergeOptimizer will modify them.
copied = copy.deepcopy((var1, var2, givens)) *vars, givens = copy.deepcopy((var1, var2, givens))
vars = copied[0:2]
givens = copied[2]
# Create FunctionGraph. # Create FunctionGraph.
inputs = list(graph_inputs(vars)) inputs = list(graph_inputs(vars))
# The clone isn't needed as we did a deepcopy and we cloning will # The clone isn't needed as we did a deepcopy and we cloning will
...@@ -120,7 +107,6 @@ def is_same_graph_with_merge( ...@@ -120,7 +107,6 @@ def is_same_graph_with_merge(
# Comparing two single-Variable graphs: they are equal if they are # Comparing two single-Variable graphs: they are equal if they are
# the same Variable. # the same Variable.
return vars_replaced[0] == vars_replaced[1] return vars_replaced[0] == vars_replaced[1]
else:
return o1 is o2 return o1 is o2
...@@ -171,16 +157,16 @@ def is_same_graph( ...@@ -171,16 +157,16 @@ def is_same_graph(
====== ====== ====== ====== ====== ====== ====== ======
""" """
use_equal_computations = True givens = {} if givens is None else dict(givens)
if givens is None:
givens = {}
givens = dict(givens)
# Get result from the merge-based function. # Get result from the merge-based function.
rval1 = is_same_graph_with_merge(var1=var1, var2=var2, givens=givens) rval1 = is_same_graph_with_merge(var1=var1, var2=var2, givens=givens)
if givens: if not givens:
rval2 = equal_computations(xs=[var1], ys=[var2])
assert rval1 == rval2
return rval1
# We need to build the `in_xs` and `in_ys` lists. To do this, we need # We need to build the `in_xs` and `in_ys` lists. To do this, we need
# to be able to tell whether a variable belongs to the computational # to be able to tell whether a variable belongs to the computational
# graph of `var1` or `var2`. # graph of `var1` or `var2`.
...@@ -188,27 +174,19 @@ def is_same_graph( ...@@ -188,27 +174,19 @@ def is_same_graph(
# one of these graphs, and `replace_by` belongs to the other one. In # one of these graphs, and `replace_by` belongs to the other one. In
# other situations, the current implementation of `equal_computations` # other situations, the current implementation of `equal_computations`
# is probably not appropriate, so we do not call it. # is probably not appropriate, so we do not call it.
ok = True use_equal_computations = True
in_xs = [] in_xs = []
in_ys = [] in_ys = []
# Compute the sets of all variables found in each computational graph. # Compute the sets of all variables found in each computational graph.
inputs_var1 = graph_inputs([var1]) inputs_var1 = graph_inputs([var1])
inputs_var2 = graph_inputs([var2]) inputs_var2 = graph_inputs([var2])
all_vars = [ all_vars1 = set(vars_between(inputs_var1, [var1]))
set(vars_between(v_i, v_o)) all_vars2 = set(vars_between(inputs_var2, [var2]))
for v_i, v_o in ((inputs_var1, [var1]), (inputs_var2, [var2]))
]
def in_var(x, k):
# Return True iff `x` is in computation graph of variable `vark`.
return x in all_vars[k - 1]
for to_replace, replace_by in givens.items(): for to_replace, replace_by in givens.items():
# Map a substitution variable to the computational graphs it # Map a substitution variable to the computational graphs it
# belongs to. # belongs to.
inside = { inside = {v: [v in all_vars1, v in all_vars2] for v in (to_replace, replace_by)}
v: [in_var(v, k) for k in (1, 2)] for v in (to_replace, replace_by)
}
if ( if (
inside[to_replace][0] inside[to_replace][0]
and not inside[to_replace][1] and not inside[to_replace][1]
...@@ -228,14 +206,9 @@ def is_same_graph( ...@@ -228,14 +206,9 @@ def is_same_graph(
in_xs.append(replace_by) in_xs.append(replace_by)
in_ys.append(to_replace) in_ys.append(to_replace)
else: else:
ok = False
break
if not ok:
# We cannot directly use `equal_computations`.
use_equal_computations = False use_equal_computations = False
else: break
in_xs = None
in_ys = None
if use_equal_computations: if use_equal_computations:
rval2 = equal_computations(xs=[var1], ys=[var2], in_xs=in_xs, in_ys=in_ys) rval2 = equal_computations(xs=[var1], ys=[var2], in_xs=in_xs, in_ys=in_ys)
assert rval2 == rval1 assert rval2 == rval1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论