提交 ea070b61 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move is_same_graph_with_merge and is_same_graph to aesara.graph.opt_utils

上级 fa74d7e3
......@@ -30,9 +30,10 @@ from aesara.graph.basic import (
vars_between,
)
from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import PreserveVariableAttributes, is_same_graph
from aesara.graph.features import PreserveVariableAttributes
from aesara.graph.fg import FunctionGraph, InconsistencyError
from aesara.graph.op import ops_with_inner_function
from aesara.graph.opt_utils import is_same_graph
from aesara.graph.utils import get_variable_trace_string
from aesara.link.basic import Container
from aesara.link.utils import raise_with_op
......
import copy
import inspect
import sys
import time
......@@ -10,13 +9,7 @@ import numpy as np
import aesara
from aesara.configdefaults import config
from aesara.graph.basic import (
Variable,
equal_computations,
graph_inputs,
io_toposort,
vars_between,
)
from aesara.graph.basic import Variable, io_toposort
class AlreadyThere(Exception):
......@@ -819,154 +812,3 @@ class NoOutputFromInplace(Feature):
f"operations. This has prevented the output {out} from "
"being computed by modifying another variable in-place."
)
def is_same_graph_with_merge(var1, var2, givens=None):
"""
Merge-based implementation of `aesara.graph.basic.is_same_graph`.
See help on `aesara.graph.basic.is_same_graph` for additional documentation.
"""
from aesara.graph.opt import MergeOptimizer
if givens is None:
givens = {}
# Copy variables since the MergeOptimizer will modify them.
copied = copy.deepcopy([var1, var2, givens])
vars = copied[0:2]
givens = copied[2]
# Create FunctionGraph.
inputs = list(graph_inputs(vars))
# The clone isn't needed as we did a deepcopy and we cloning will
# break the mapping in givens.
fgraph = aesara.graph.fg.FunctionGraph(inputs, vars, clone=False)
# Perform Variable substitution.
for to_replace, replace_by in givens.items():
fgraph.replace(to_replace, replace_by)
# Perform merge optimization.
MergeOptimizer().optimize(fgraph)
# When two variables perform the same computations, they will have the same
# owner in the optimized graph.
# We need to be careful with the special case where the owner is None,
# which happens when the graph is made of a single Variable.
# We also need to make sure we replace a Variable if it is present in
# `givens`.
vars_replaced = [givens.get(v, v) for v in fgraph.outputs]
o1, o2 = [v.owner for v in vars_replaced]
if o1 is None and o2 is None:
# Comparing two single-Variable graphs: they are equal if they are
# the same Variable.
return vars_replaced[0] == vars_replaced[1]
else:
return o1 is o2
def is_same_graph(var1, var2, givens=None):
"""
Return True iff Variables `var1` and `var2` perform the same computation.
By 'performing the same computation', we mean that they must share the same
graph, so that for instance this function will return False when comparing
(x * (y * z)) with ((x * y) * z).
The current implementation is not efficient since, when possible, it
verifies equality by calling two different functions that are expected to
return the same output. The goal is to verify this assumption, to
eventually get rid of one of them in the future.
Parameters
----------
var1
The first Variable to compare.
var2
The second Variable to compare.
givens
Similar to the `givens` argument of `aesara.function`, it can be used
to perform substitutions in the computational graph of `var1` and
`var2`. This argument is associated to neither `var1` nor `var2`:
substitutions may affect both graphs if the substituted variable
is present in both.
Examples
--------
====== ====== ====== ======
var1 var2 givens output
====== ====== ====== ======
x + 1 x + 1 {} True
x + 1 y + 1 {} False
x + 1 y + 1 {x: y} True
====== ====== ====== ======
"""
use_equal_computations = True
if givens is None:
givens = {}
if not isinstance(givens, dict):
givens = dict(givens)
# Get result from the merge-based function.
rval1 = is_same_graph_with_merge(var1=var1, var2=var2, givens=givens)
if givens:
# We need to build the `in_xs` and `in_ys` lists. To do this, we need
# to be able to tell whether a variable belongs to the computational
# graph of `var1` or `var2`.
# The typical case we want to handle is when `to_replace` belongs to
# one of these graphs, and `replace_by` belongs to the other one. In
# other situations, the current implementation of `equal_computations`
# is probably not appropriate, so we do not call it.
ok = True
in_xs = []
in_ys = []
# Compute the sets of all variables found in each computational graph.
inputs_var = list(map(graph_inputs, ([var1], [var2])))
all_vars = [
set(vars_between(v_i, v_o))
for v_i, v_o in ((inputs_var[0], [var1]), (inputs_var[1], [var2]))
]
def in_var(x, k):
# Return True iff `x` is in computation graph of variable `vark`.
return x in all_vars[k - 1]
for to_replace, replace_by in givens.items():
# Map a substitution variable to the computational graphs it
# belongs to.
inside = {
v: [in_var(v, k) for k in (1, 2)] for v in (to_replace, replace_by)
}
if (
inside[to_replace][0]
and not inside[to_replace][1]
and inside[replace_by][1]
and not inside[replace_by][0]
):
# Substitute variable in `var1` by one from `var2`.
in_xs.append(to_replace)
in_ys.append(replace_by)
elif (
inside[to_replace][1]
and not inside[to_replace][0]
and inside[replace_by][0]
and not inside[replace_by][1]
):
# Substitute variable in `var2` by one from `var1`.
in_xs.append(replace_by)
in_ys.append(to_replace)
else:
ok = False
break
if not ok:
# We cannot directly use `equal_computations`.
use_equal_computations = False
else:
in_xs = None
in_ys = None
if use_equal_computations:
rval2 = equal_computations(xs=[var1], ys=[var2], in_xs=in_xs, in_ys=in_ys)
assert rval2 == rval1
return rval1
import copy
import aesara
from aesara.graph.basic import equal_computations, graph_inputs, vars_between
def is_same_graph_with_merge(var1, var2, givens=None):
"""
Merge-based implementation of `aesara.graph.basic.is_same_graph`.
See help on `aesara.graph.basic.is_same_graph` for additional documentation.
"""
from aesara.graph.opt import MergeOptimizer
if givens is None:
givens = {}
# Copy variables since the MergeOptimizer will modify them.
copied = copy.deepcopy([var1, var2, givens])
vars = copied[0:2]
givens = copied[2]
# Create FunctionGraph.
inputs = list(graph_inputs(vars))
# The clone isn't needed as we did a deepcopy and we cloning will
# break the mapping in givens.
fgraph = aesara.graph.fg.FunctionGraph(inputs, vars, clone=False)
# Perform Variable substitution.
for to_replace, replace_by in givens.items():
fgraph.replace(to_replace, replace_by)
# Perform merge optimization.
MergeOptimizer().optimize(fgraph)
# When two variables perform the same computations, they will have the same
# owner in the optimized graph.
# We need to be careful with the special case where the owner is None,
# which happens when the graph is made of a single Variable.
# We also need to make sure we replace a Variable if it is present in
# `givens`.
vars_replaced = [givens.get(v, v) for v in fgraph.outputs]
o1, o2 = [v.owner for v in vars_replaced]
if o1 is None and o2 is None:
# Comparing two single-Variable graphs: they are equal if they are
# the same Variable.
return vars_replaced[0] == vars_replaced[1]
else:
return o1 is o2
def is_same_graph(var1, var2, givens=None):
"""
Return True iff Variables `var1` and `var2` perform the same computation.
By 'performing the same computation', we mean that they must share the same
graph, so that for instance this function will return False when comparing
(x * (y * z)) with ((x * y) * z).
The current implementation is not efficient since, when possible, it
verifies equality by calling two different functions that are expected to
return the same output. The goal is to verify this assumption, to
eventually get rid of one of them in the future.
Parameters
----------
var1
The first Variable to compare.
var2
The second Variable to compare.
givens
Similar to the `givens` argument of `aesara.function`, it can be used
to perform substitutions in the computational graph of `var1` and
`var2`. This argument is associated to neither `var1` nor `var2`:
substitutions may affect both graphs if the substituted variable
is present in both.
Examples
--------
====== ====== ====== ======
var1 var2 givens output
====== ====== ====== ======
x + 1 x + 1 {} True
x + 1 y + 1 {} False
x + 1 y + 1 {x: y} True
====== ====== ====== ======
"""
use_equal_computations = True
if givens is None:
givens = {}
if not isinstance(givens, dict):
givens = dict(givens)
# Get result from the merge-based function.
rval1 = is_same_graph_with_merge(var1=var1, var2=var2, givens=givens)
if givens:
# We need to build the `in_xs` and `in_ys` lists. To do this, we need
# to be able to tell whether a variable belongs to the computational
# graph of `var1` or `var2`.
# The typical case we want to handle is when `to_replace` belongs to
# one of these graphs, and `replace_by` belongs to the other one. In
# other situations, the current implementation of `equal_computations`
# is probably not appropriate, so we do not call it.
ok = True
in_xs = []
in_ys = []
# Compute the sets of all variables found in each computational graph.
inputs_var = list(map(graph_inputs, ([var1], [var2])))
all_vars = [
set(vars_between(v_i, v_o))
for v_i, v_o in ((inputs_var[0], [var1]), (inputs_var[1], [var2]))
]
def in_var(x, k):
# Return True iff `x` is in computation graph of variable `vark`.
return x in all_vars[k - 1]
for to_replace, replace_by in givens.items():
# Map a substitution variable to the computational graphs it
# belongs to.
inside = {
v: [in_var(v, k) for k in (1, 2)] for v in (to_replace, replace_by)
}
if (
inside[to_replace][0]
and not inside[to_replace][1]
and inside[replace_by][1]
and not inside[replace_by][0]
):
# Substitute variable in `var1` by one from `var2`.
in_xs.append(to_replace)
in_ys.append(replace_by)
elif (
inside[to_replace][1]
and not inside[to_replace][0]
and inside[replace_by][0]
and not inside[replace_by][1]
):
# Substitute variable in `var2` by one from `var1`.
in_xs.append(replace_by)
in_ys.append(to_replace)
else:
ok = False
break
if not ok:
# We cannot directly use `equal_computations`.
use_equal_computations = False
else:
in_xs = None
in_ys = None
if use_equal_computations:
rval2 = equal_computations(xs=[var1], ys=[var2], in_xs=in_xs, in_ys=in_ys)
assert rval2 == rval1
return rval1
from aesara.graph.basic import Apply, Variable
from aesara.graph.features import NodeFinder, is_same_graph
from aesara.graph.features import NodeFinder
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.type import Type
from aesara.tensor.math import neg
from aesara.tensor.type import vectors
class TestNodeFinder:
......@@ -84,137 +82,3 @@ class TestNodeFinder:
for type, num in ((add, 4), (sigmoid, 3), (dot, 1)):
if not len([t for t in g.get_nodes(type)]) == num:
raise Exception("Expected: %i times %s" % (num, type))
class TestIsSameGraph:
def check(self, expected):
"""
Core function to perform comparison.
:param expected: A list of tuples (v1, v2, ((g1, o1), ..., (gN, oN)))
with:
- `v1` and `v2` two Variables (the graphs to be compared)
- `gj` a `givens` dictionary to give as input to `is_same_graph`
- `oj` the expected output of `is_same_graph(v1, v2, givens=gj)`
This function also tries to call `is_same_graph` by inverting `v1` and
`v2`, and ensures the output remains the same.
"""
for v1, v2, go in expected:
for gj, oj in go:
r1 = is_same_graph(v1, v2, givens=gj)
assert r1 == oj
r2 = is_same_graph(v2, v1, givens=gj)
assert r2 == oj
def test_single_var(self):
# Test `is_same_graph` with some trivial graphs (one Variable).
x, y, z = vectors("x", "y", "z")
self.check(
[
(x, x, (({}, True),)),
(
x,
y,
(
({}, False),
({y: x}, True),
),
),
(x, neg(x), (({}, False),)),
(x, neg(y), (({}, False),)),
]
)
def test_full_graph(self):
# Test `is_same_graph` with more complex graphs.
x, y, z = vectors("x", "y", "z")
t = x * y
self.check(
[
(x * 2, x * 2, (({}, True),)),
(
x * 2,
y * 2,
(
({}, False),
({y: x}, True),
),
),
(
x * 2,
y * 2,
(
({}, False),
({x: y}, True),
),
),
(
x * 2,
y * 3,
(
({}, False),
({y: x}, False),
),
),
(
t * 2,
z * 2,
(
({}, False),
({t: z}, True),
),
),
(
t * 2,
z * 2,
(
({}, False),
({z: t}, True),
),
),
(x * (y * z), (x * y) * z, (({}, False),)),
]
)
def test_merge_only(self):
# Test `is_same_graph` when `equal_computations` cannot be used.
x, y, z = vectors("x", "y", "z")
t = x * y
self.check(
[
(x, t, (({}, False), ({t: x}, True))),
(
t * 2,
x * 2,
(
({}, False),
({t: x}, True),
),
),
(
x * x,
x * y,
(
({}, False),
({y: x}, True),
),
),
(
x * x,
x * y,
(
({}, False),
({y: x}, True),
),
),
(
x * x + z,
x * y + t,
(({}, False), ({y: x}, False), ({y: x, t: z}, True)),
),
],
)
from aesara.graph.opt_utils import is_same_graph
from aesara.tensor.math import neg
from aesara.tensor.type import vectors
class TestIsSameGraph:
def check(self, expected):
"""
Core function to perform comparison.
:param expected: A list of tuples (v1, v2, ((g1, o1), ..., (gN, oN)))
with:
- `v1` and `v2` two Variables (the graphs to be compared)
- `gj` a `givens` dictionary to give as input to `is_same_graph`
- `oj` the expected output of `is_same_graph(v1, v2, givens=gj)`
This function also tries to call `is_same_graph` by inverting `v1` and
`v2`, and ensures the output remains the same.
"""
for v1, v2, go in expected:
for gj, oj in go:
r1 = is_same_graph(v1, v2, givens=gj)
assert r1 == oj
r2 = is_same_graph(v2, v1, givens=gj)
assert r2 == oj
def test_single_var(self):
# Test `is_same_graph` with some trivial graphs (one Variable).
x, y, z = vectors("x", "y", "z")
self.check(
[
(x, x, (({}, True),)),
(
x,
y,
(
({}, False),
({y: x}, True),
),
),
(x, neg(x), (({}, False),)),
(x, neg(y), (({}, False),)),
]
)
def test_full_graph(self):
# Test `is_same_graph` with more complex graphs.
x, y, z = vectors("x", "y", "z")
t = x * y
self.check(
[
(x * 2, x * 2, (({}, True),)),
(
x * 2,
y * 2,
(
({}, False),
({y: x}, True),
),
),
(
x * 2,
y * 2,
(
({}, False),
({x: y}, True),
),
),
(
x * 2,
y * 3,
(
({}, False),
({y: x}, False),
),
),
(
t * 2,
z * 2,
(
({}, False),
({t: z}, True),
),
),
(
t * 2,
z * 2,
(
({}, False),
({z: t}, True),
),
),
(x * (y * z), (x * y) * z, (({}, False),)),
]
)
def test_merge_only(self):
# Test `is_same_graph` when `equal_computations` cannot be used.
x, y, z = vectors("x", "y", "z")
t = x * y
self.check(
[
(x, t, (({}, False), ({t: x}, True))),
(
t * 2,
x * 2,
(
({}, False),
({t: x}, True),
),
),
(
x * x,
x * y,
(
({}, False),
({y: x}, True),
),
),
(
x * x,
x * y,
(
({}, False),
({y: x}, True),
),
),
(
x * x + z,
x * y + t,
(({}, False), ({y: x}, False), ({y: x, t: z}, True)),
),
],
)
......@@ -17,9 +17,9 @@ from aesara.compile.mode import Mode, get_default_mode, get_mode
from aesara.compile.ops import DeepCopyOp, deep_copy_op
from aesara.configdefaults import config
from aesara.graph.basic import Constant
from aesara.graph.features import is_same_graph
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import LocalOptGroup, TopoOptimizer, check_stack_trace, out2in
from aesara.graph.opt_utils import is_same_graph
from aesara.graph.optdb import Query
from aesara.misc.safe_asarray import _asarray
from aesara.tensor import inplace
......
......@@ -12,8 +12,8 @@ import aesara.tensor.basic as aet
from aesara.compile import DeepCopyOp, shared
from aesara.compile.io import In
from aesara.configdefaults import config
from aesara.graph.features import is_same_graph
from aesara.graph.op import get_test_value
from aesara.graph.opt_utils import is_same_graph
from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.math import exp, isinf
from aesara.tensor.math import sum as aet_sum
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论