提交 459c570d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Replace unification framework with logical-unification, etuples, and cons

These changes only affect `PatternSub`, which now no longer allows constraints on non (logic) variable terms in its patterns. Likewise, the `values_eq_approx` and `skip_identities_fn` options are no longer supported.
上级 ac7986d7
...@@ -20,8 +20,6 @@ from functools import partial, reduce ...@@ -20,8 +20,6 @@ from functools import partial, reduce
from itertools import chain from itertools import chain
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import aesara import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph import destroyhandler as dh from aesara.graph import destroyhandler as dh
...@@ -32,6 +30,7 @@ from aesara.graph.basic import ( ...@@ -32,6 +30,7 @@ from aesara.graph.basic import (
applys_between, applys_between,
io_toposort, io_toposort,
nodes_constructed, nodes_constructed,
vars_between,
) )
from aesara.graph.features import Feature, NodeFinder from aesara.graph.features import Feature, NodeFinder
from aesara.graph.fg import FunctionGraph, InconsistencyError from aesara.graph.fg import FunctionGraph, InconsistencyError
...@@ -1597,15 +1596,13 @@ class OpRemove(LocalOptimizer): ...@@ -1597,15 +1596,13 @@ class OpRemove(LocalOptimizer):
class PatternSub(LocalOptimizer): class PatternSub(LocalOptimizer):
""" """Replace all occurrences of an input pattern with an output pattern.
@todo update The input and output patterns have the following syntax:
Replaces all occurrences of the input pattern by the output pattern:
input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...) input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
input_pattern ::= dict(pattern = <input_pattern>, input_pattern ::= dict(pattern = <input_pattern>,
constraint = <constraint>) constraint = <constraint>)
sub_pattern ::= input_pattern sub_pattern ::= input_pattern
sub_pattern ::= string sub_pattern ::= string
sub_pattern ::= a Constant instance sub_pattern ::= a Constant instance
...@@ -1651,8 +1648,6 @@ class PatternSub(LocalOptimizer): ...@@ -1651,8 +1648,6 @@ class PatternSub(LocalOptimizer):
skip_identities_fn : TODO skip_identities_fn : TODO
name : name :
Allows to override this optimizer name. Allows to override this optimizer name.
pdb : bool
If True, we invoke pdb when the first node in the pattern matches.
tracks : optional tracks : optional
The values that :meth:`self.tracks` will return. Useful to speed up The values that :meth:`self.tracks` will return. Useful to speed up
optimization sometimes. optimization sometimes.
...@@ -1675,7 +1670,7 @@ class PatternSub(LocalOptimizer): ...@@ -1675,7 +1670,7 @@ class PatternSub(LocalOptimizer):
PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x')) PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x'))
PatternSub((boggle, {'pattern': 'x', PatternSub((boggle, {'pattern': 'x',
'constraint': lambda expr: expr.type == scrabble}), 'constraint': lambda expr: expr.type == scrabble}),
(scrabble, 'x')) (scrabble, 'x'))
""" """
...@@ -1686,13 +1681,15 @@ class PatternSub(LocalOptimizer): ...@@ -1686,13 +1681,15 @@ class PatternSub(LocalOptimizer):
allow_multiple_clients=False, allow_multiple_clients=False,
skip_identities_fn=None, skip_identities_fn=None,
name=None, name=None,
pdb=False,
tracks=(), tracks=(),
get_nodes=None, get_nodes=None,
values_eq_approx=None, values_eq_approx=None,
): ):
self.in_pattern = in_pattern from aesara.graph.unify import convert_strs_to_vars
self.out_pattern = out_pattern
var_map = {}
self.in_pattern = convert_strs_to_vars(in_pattern, var_map=var_map)
self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map)
self.values_eq_approx = values_eq_approx self.values_eq_approx = values_eq_approx
if isinstance(in_pattern, (list, tuple)): if isinstance(in_pattern, (list, tuple)):
self.op = self.in_pattern[0] self.op = self.in_pattern[0]
...@@ -1709,7 +1706,6 @@ class PatternSub(LocalOptimizer): ...@@ -1709,7 +1706,6 @@ class PatternSub(LocalOptimizer):
self.skip_identities_fn = skip_identities_fn self.skip_identities_fn = skip_identities_fn
if name: if name:
self.__name__ = name self.__name__ = name
self.pdb = pdb
self._tracks = tracks self._tracks = tracks
self.get_nodes = get_nodes self.get_nodes = get_nodes
if tracks != (): if tracks != ():
...@@ -1729,7 +1725,16 @@ class PatternSub(LocalOptimizer): ...@@ -1729,7 +1725,16 @@ class PatternSub(LocalOptimizer):
If it does, it constructs ``out_pattern`` and performs the replacement. If it does, it constructs ``out_pattern`` and performs the replacement.
""" """
from aesara.graph import unify from etuples.core import ExpressionTuple
from unification import reify, unify
# TODO: We shouldn't need to iterate like this.
if not self.allow_multiple_clients and any(
len(fgraph.clients.get(v)) > 1
for v in vars_between(fgraph.inputs, node.outputs)
if v not in fgraph.inputs
):
return False
if get_nodes and self.get_nodes is not None: if get_nodes and self.get_nodes is not None:
for real_node in self.get_nodes(fgraph, node): for real_node in self.get_nodes(fgraph, node):
...@@ -1741,99 +1746,16 @@ class PatternSub(LocalOptimizer): ...@@ -1741,99 +1746,16 @@ class PatternSub(LocalOptimizer):
if node.op != self.op: if node.op != self.op:
return False return False
# TODO: if we remove pdb, do this speed things up?
def match(pattern, expr, u, allow_multiple_clients=False, pdb=False):
# TODO move outside match
def retry_with_equiv():
if not self.skip_identities_fn:
return False
expr_equiv = self.skip_identities_fn(expr)
if expr_equiv is None:
return False
# TODO: Not sure how to handle multiple_clients flag
return match(
pattern,
expr_equiv,
u,
allow_multiple_clients=allow_multiple_clients,
)
if isinstance(pattern, (list, tuple)):
if expr.owner is None:
return False
if not (expr.owner.op == pattern[0]) or (
not allow_multiple_clients and len(fgraph.clients[expr]) > 1
):
return retry_with_equiv()
if len(pattern) - 1 != len(expr.owner.inputs):
return retry_with_equiv()
for p, v in zip(pattern[1:], expr.owner.inputs):
u = match(p, v, u, self.allow_multiple_clients)
if not u:
return False
elif isinstance(pattern, dict):
try:
real_pattern = pattern["pattern"]
except KeyError:
raise KeyError(
f"Malformed pattern: {pattern} (expected key 'pattern')"
)
constraint = pattern.get("constraint", lambda expr: True)
if constraint(expr):
return match(
real_pattern,
expr,
u,
pattern.get("allow_multiple_clients", allow_multiple_clients),
)
else:
return retry_with_equiv()
elif isinstance(pattern, str):
v = unify.Var(pattern)
if u[v] is not v and u[v] is not expr:
return retry_with_equiv()
else:
u = u.merge(expr, v)
elif isinstance(pattern, (int, float)) and isinstance(expr, Constant):
if np.all(aesara.tensor.constant(pattern).value == expr.value):
return u
else:
return retry_with_equiv()
elif (
isinstance(pattern, Constant)
and isinstance(expr, Constant)
and pattern.equals(expr)
):
return u
else:
return retry_with_equiv()
if pdb:
import pdb
pdb.set_trace() s = unify(self.in_pattern, node.out)
return u
u = match(self.in_pattern, node.out, unify.Unification(), True, self.pdb) if s is False:
if not u:
return False return False
def build(pattern, u): ret = reify(self.out_pattern, s)
if isinstance(pattern, (list, tuple)):
args = [build(p, u) for p in pattern[1:]]
return pattern[0](*args)
elif isinstance(pattern, str):
return u[unify.Var(pattern)]
elif isinstance(pattern, (int, float)):
return pattern
else:
return pattern.clone()
ret = build(self.out_pattern, u)
if isinstance(ret, (int, float)): if isinstance(ret, ExpressionTuple):
# TODO: Should we convert these to constants explicitly? ret = ret.evaled_obj
return [ret]
if self.values_eq_approx: if self.values_eq_approx:
ret.tag.values_eq_approx = self.values_eq_approx ret.tag.values_eq_approx = self.values_eq_approx
......
差异被折叠。
...@@ -45,7 +45,15 @@ Programming Language :: Python :: 3.9 ...@@ -45,7 +45,15 @@ Programming Language :: Python :: 3.9
""" """
CLASSIFIERS = [_f for _f in CLASSIFIERS.split("\n") if _f] CLASSIFIERS = [_f for _f in CLASSIFIERS.split("\n") if _f]
install_requires = ["numpy>=1.17.0", "scipy>=0.14", "filelock"] install_requires = [
"numpy>=1.17.0",
"scipy>=0.14",
"filelock",
"etuples",
"logical-unification",
"cons",
]
if sys.version_info[0:2] < (3, 7): if sys.version_info[0:2] < (3, 7):
install_requires += ["dataclasses"] install_requires += ["dataclasses"]
......
...@@ -162,7 +162,7 @@ class TestPatternOptimizer: ...@@ -162,7 +162,7 @@ class TestPatternOptimizer:
TopoPatternOptimizer((op1, (op1, "1")), (op1, "1"), ign=False).optimize(g) TopoPatternOptimizer((op1, (op1, "1")), (op1, "1"), ign=False).optimize(g)
assert str(g) == "FunctionGraph(Op1(x))" assert str(g) == "FunctionGraph(Op1(x))"
def test_constant_unification(self): def test_constant(self):
x = Constant(MyType(), 2, name="x") x = Constant(MyType(), 2, name="x")
y = MyVariable("y") y = MyVariable("y")
z = Constant(MyType(), 2, name="z") z = Constant(MyType(), 2, name="z")
...@@ -192,6 +192,9 @@ class TestPatternOptimizer: ...@@ -192,6 +192,9 @@ class TestPatternOptimizer:
PatternOptimizer((op1, "x", "y"), (op3, "x", "y")).optimize(g) PatternOptimizer((op1, "x", "y"), (op3, "x", "y")).optimize(g)
assert str(g) == "FunctionGraph(Op3(x, x))" assert str(g) == "FunctionGraph(Op3(x, x))"
@pytest.mark.xfail(
reason="This pattern & constraint case isn't used and doesn't make much sense."
)
def test_match_same_illegal(self): def test_match_same_illegal(self):
x, y, z = inputs() x, y, z = inputs()
e = op2(op1(x, x), op1(x, y)) e = op2(op1(x, x), op1(x, y))
...@@ -206,9 +209,10 @@ class TestPatternOptimizer: ...@@ -206,9 +209,10 @@ class TestPatternOptimizer:
).optimize(g) ).optimize(g)
assert str(g) == "FunctionGraph(Op2(Op1(x, x), Op3(x, y)))" assert str(g) == "FunctionGraph(Op2(Op1(x, x), Op3(x, y)))"
def test_multi(self): def test_allow_multiple_clients(self):
x, y, z = inputs() x, y, z = inputs()
e0 = op1(x, y) e0 = op1(x, y)
# `e0` has multiple clients (i.e. the `op4` and `op3` nodes)
e = op3(op4(e0), e0) e = op3(op4(e0), e0)
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op4, (op1, "x", "y")), (op3, "x", "y")).optimize(g) PatternOptimizer((op4, (op1, "x", "y")), (op3, "x", "y")).optimize(g)
...@@ -224,17 +228,6 @@ class TestPatternOptimizer: ...@@ -224,17 +228,6 @@ class TestPatternOptimizer:
assert str_g == "FunctionGraph(Op4(z, y))" assert str_g == "FunctionGraph(Op4(z, y))"
# def test_multi_ingraph(self):
# # known to fail
# x, y, z = inputs()
# e0 = op1(x, y)
# e = op4(e0, e0)
# g = FunctionGraph([x, y, z], [e])
# PatternOptimizer((op4, (op1, 'x', 'y'), (op1, 'x', 'y')),
# (op3, 'x', 'y')).optimize(g)
# assert str(g) == "FunctionGraph(Op3(x, y))"
def OpSubOptimizer(op1, op2): def OpSubOptimizer(op1, op2):
return OpKeyOptimizer(OpSub(op1, op2)) return OpKeyOptimizer(OpSub(op1, op2))
...@@ -454,6 +447,8 @@ class TestMergeOptimizer: ...@@ -454,6 +447,8 @@ class TestMergeOptimizer:
class TestEquilibrium: class TestEquilibrium:
def test_1(self): def test_1(self):
x, y, z = map(MyVariable, "xyz") x, y, z = map(MyVariable, "xyz")
# TODO FIXME: These `Op`s don't have matching/consistent `__prop__`s
# and `__init__`s, so they can't be `etuplized` correctly
e = op3(op4(x, y)) e = op3(op4(x, y))
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
# print g # print g
...@@ -632,8 +627,9 @@ def test_patternsub_values_eq_approx(out_pattern, tracks): ...@@ -632,8 +627,9 @@ def test_patternsub_values_eq_approx(out_pattern, tracks):
assert output is x assert output is x
assert output.tag.values_eq_approx is values_eq_approx_always_true assert output.tag.values_eq_approx is values_eq_approx_always_true
else: else:
assert isinstance(output, Constant) # The replacement types do not match, so the substitution should've
assert not hasattr(output.tag, "value_eq_approx") # failed
assert output is e
@pytest.mark.parametrize("out_pattern", [(op1, "x"), "x"]) @pytest.mark.parametrize("out_pattern", [(op1, "x"), "x"])
......
import numpy as np
import pytest
from cons import car, cdr
from cons.core import ConsError
from etuples import apply, etuple, etuplize
from etuples.core import ExpressionTuple
from unification import reify, unify, var
from unification.variable import Var
import aesara.scalar as aes
import aesara.tensor as at
from aesara.graph.basic import Apply, Constant, equal_computations
from aesara.graph.op import Op
from aesara.graph.unify import ConstrainedVar, convert_strs_to_vars
from aesara.tensor.type import TensorType
from tests.graph.utils import MyType
class CustomOp(Op):
    """A simple parameterized `Op` with `__props__`, so instances compare equal
    by their ``a`` value and can be deconstructed (e.g. by `car`/`cdr`)."""

    __props__ = ("a",)

    def __init__(self, a):
        self.a = a

    def make_node(self, *inputs):
        """Build an `Apply` node with a single vector output."""
        node_inputs = list(inputs)
        node_outputs = [at.vector()]
        return Apply(self, node_inputs, node_outputs)

    def perform(self, node, inputs, outputs):
        # Never evaluated in these tests.
        raise NotImplementedError()
class CustomOpNoPropsNoEq(Op):
    """An `Op` carrying state but with neither `__props__` nor custom
    `__eq__`/`__hash__`; instances only compare by identity."""

    def __init__(self, a):
        self.a = a

    def make_node(self, *inputs):
        """Build an `Apply` node with a single vector output."""
        node_inputs = list(inputs)
        node_outputs = [at.vector()]
        return Apply(self, node_inputs, node_outputs)

    def perform(self, node, inputs, outputs):
        # Never evaluated in these tests.
        raise NotImplementedError()
class CustomOpNoProps(CustomOpNoPropsNoEq):
    """Like `CustomOpNoPropsNoEq`, but comparable and hashable via ``a``
    (still without `__props__`, so it can't be etuplized)."""

    def __eq__(self, other):
        # Guard clause: different concrete classes never compare equal.
        if not (type(self) == type(other)):
            return False
        return self.a == other.a

    def __hash__(self):
        return hash((type(self), self.a))
def test_cons():
    """Exercise `car`/`cdr` deconstruction of graphs, `Op`s, and `Type`s."""
    x = at.vector("x")
    y = at.vector("y")
    z = x + y

    # An applied node decomposes into its `Op` (car) and its inputs (cdr).
    assert car(z) == z.owner.op
    assert cdr(z) == [x, y]

    # Variables without owners are atomic and can't be deconstructed.
    with pytest.raises(ConsError):
        car(x)

    with pytest.raises(ConsError):
        cdr(x)

    # An `Op` with `__props__` decomposes into its class and prop values.
    props_op = CustomOp(1)
    assert car(props_op) == CustomOp
    assert cdr(props_op) == (1,)

    # A `Type` decomposes into its class and constructor arguments.
    tensor_type = TensorType("float32", [True, False])
    assert car(tensor_type) == TensorType
    assert cdr(tensor_type) == ("float32", (True, False))

    # An `Op` without `__props__` is atomic.
    no_props_op = CustomOpNoProps(1)
    with pytest.raises(ConsError):
        car(no_props_op)

    with pytest.raises(ConsError):
        cdr(no_props_op)

    scalar_type = aes.float64
    scalar_car = car(scalar_type)
    scalar_cdr = cdr(scalar_type)
    assert scalar_car is type(scalar_type)
    assert scalar_cdr == [scalar_type.dtype]

    vector_type = at.lvector
    vector_car = car(vector_type)
    vector_cdr = cdr(vector_type)
    assert vector_car is type(vector_type)
    assert vector_cdr == [vector_type.dtype, vector_type.broadcastable]
def test_etuples():
    """Check that Aesara graphs can be built from and converted to etuples."""
    x_at = at.vector("x")
    y_at = at.vector("y")

    z_at = etuple(x_at, y_at)

    # `apply` constructs the graph for the given `Op` and etuple of inputs
    res = apply(at.add, z_at)
    assert res.owner.op == at.add
    assert res.owner.inputs == [x_at, y_at]

    w_at = etuple(at.add, x_at, y_at)

    # Evaluating an etuple whose head is an `Op` builds the same graph
    res = w_at.evaled_obj
    assert res.owner.op == at.add
    assert res.owner.inputs == [x_at, y_at]

    # This `Op` doesn't expand into an `etuple` (i.e. it's "atomic")
    op1_np = CustomOpNoProps(1)

    res = apply(op1_np, z_at)
    assert res.owner.op == op1_np

    # Applied atomic `Op`s still etuplize at the graph level: the head of the
    # resulting etuple is the `Op` instance itself
    q_at = op1_np(x_at, y_at)
    res = etuplize(q_at)
    assert res[0] == op1_np

    # ...but the atomic `Op` on its own cannot be etuplized
    with pytest.raises(TypeError):
        etuplize(op1_np)

    class MyMultiOutOp(Op):
        # Dummy `Op` with two outputs, used to check multi-output `apply`
        def make_node(self, *inputs):
            outputs = [MyType()(), MyType()()]
            return Apply(self, list(inputs), outputs)

        def perform(self, node, inputs, outputs):
            outputs[0] = np.array(inputs[0])
            outputs[1] = np.array(inputs[0])

    x_at = at.vector("x")
    op1_np = MyMultiOutOp()

    # `apply` on a multi-output `Op` should return all of its outputs
    res = apply(op1_np, etuple(x_at))
    assert len(res) == 2
    assert res[0].owner.op == op1_np
    assert res[1].owner.op == op1_np
def test_unify_Variable():
    """Unification between Aesara `Variable`s and etuple patterns."""
    x_at = at.vector("x")
    y_at = at.vector("y")

    z_at = x_at + y_at

    # `Variable`, `Variable`: identical graphs unify with an empty substitution
    s = unify(z_at, z_at)
    assert s == {}

    # These `Variable`s have no owners
    v1 = MyType()()
    v2 = MyType()()

    assert v1 != v2

    # Distinct owner-less variables cannot unify
    s = unify(v1, v2)
    assert s is False

    op_lv = var()
    z_pat_et = etuple(op_lv, x_at, y_at)

    # `Variable`, `ExpressionTuple`: the logic variable binds to the `Op`
    s = unify(z_at, z_pat_et, {})
    assert op_lv in s
    assert s[op_lv] == z_at.owner.op

    # Reifying the pattern under `s` reproduces the original graph
    res = reify(z_pat_et, s)
    assert isinstance(res, ExpressionTuple)
    assert equal_computations([res.evaled_obj], [z_at])

    z_et = etuple(at.add, x_at, y_at)

    # `ExpressionTuple`, `ExpressionTuple`
    s = unify(z_et, z_pat_et, {})
    assert op_lv in s
    assert s[op_lv] == z_et[0]

    res = reify(z_pat_et, s)
    assert isinstance(res, ExpressionTuple)
    assert equal_computations([res.evaled_obj], [z_et.evaled_obj])

    # `ExpressionTuple`, `Variable`: a composite pattern can't match an
    # owner-less variable
    s = unify(z_et, x_at, {})
    assert s is False

    # This `Op` doesn't expand into an `ExpressionTuple`
    op1_np = CustomOpNoProps(1)
    q_at = op1_np(x_at, y_at)

    a_lv = var()
    b_lv = var()

    # `Variable`, `ExpressionTuple`: inputs bind to the logic variables
    s = unify(q_at, etuple(op1_np, a_lv, b_lv))
    assert s[a_lv] == x_at
    assert s[b_lv] == y_at
def test_unify_Op():
    """Unification between `Op` instances and their etuplized forms."""
    # These `Op`s expand into `ExpressionTuple`s
    first_op = CustomOp(1)
    second_op = CustomOp(1)

    # `Op`, `Op`: equal instances unify with an empty substitution
    assert unify(first_op, second_op) == {}

    # `ExpressionTuple`, `Op`
    assert unify(etuplize(first_op), second_op) == {}

    # These `Op`s don't expand into `ExpressionTuple`s, but they do compare
    # equal, so they still unify
    first_np = CustomOpNoProps(1)
    second_np = CustomOpNoProps(1)

    assert unify(first_np, second_np) == {}

    # Same, but this one also doesn't implement `__eq__`, so nothing matches
    no_eq_op = CustomOpNoPropsNoEq(1)
    assert unify(no_eq_op, etuplize(first_op)) is False
def test_unify_Constant():
    """Two distinct `Constant`s holding equal data should unify."""
    first_const = at.as_tensor(np.r_[1, 2])
    second_const = at.as_tensor(np.r_[1, 2])

    # `Constant`, `Constant`: equality of data suffices
    assert unify(first_const, second_const) == {}
def test_unify_Type():
    """Unification between equal `Type` instances and `Type` patterns."""
    first_tt = TensorType(np.float64, (True, False))
    second_tt = TensorType(np.float64, (True, False))

    # `Type`, `Type`
    assert unify(first_tt, second_tt) == {}

    # `Type`, `ExpressionTuple`: a type matches its deconstructed form
    tt_pattern = etuple(TensorType, "float64", (True, False))
    assert unify(first_tt, tt_pattern) == {}

    from aesara.scalar.basic import Scalar

    # Equal scalar `Type`s unify as well
    first_st = Scalar(np.float64)
    second_st = Scalar(np.float64)
    assert unify(first_st, second_st) == {}
def test_ConstrainedVar():
    """A `ConstrainedVar` only unifies with values satisfying its predicate."""
    cvar = ConstrainedVar(lambda x: isinstance(x, str))

    assert repr(cvar).startswith("ConstrainedVar(")
    assert repr(cvar).endswith(f", {cvar.token})")

    # Values failing the constraint don't unify, from either side
    s = unify(cvar, 1)
    assert s is False
    s = unify(1, cvar)
    assert s is False

    # Values passing the constraint unify as usual
    s = unify(cvar, "hi")
    assert s[cvar] == "hi"
    s = unify("hi", cvar)
    assert s[cvar] == "hi"

    # Unifying with an unbound logic variable simply associates the two
    x_lv = var()
    s = unify(cvar, x_lv)
    assert s == {cvar: x_lv}

    # When the other variable is already bound, the constraint is checked
    # against its ground value
    s = unify(cvar, x_lv, {x_lv: "hi"})
    assert s[cvar] == "hi"
    s = unify(x_lv, cvar, {x_lv: "hi"})
    assert s[cvar] == "hi"

    # Consistent pre-existing bindings are preserved...
    s_orig = {cvar: "hi", x_lv: "hi"}
    s = unify(x_lv, cvar, s_orig)
    assert s == s_orig

    # ...and inconsistent ones make unification fail
    s_orig = {cvar: "hi", x_lv: "bye"}
    s = unify(x_lv, cvar, s_orig)
    assert s is False

    x_at = at.vector("x")
    y_at = at.vector("y")
    op1_np = CustomOpNoProps(1)
    r_at = etuple(op1_np, x_at, y_at)

    def constraint(x):
        return isinstance(x, tuple)

    a_lv = ConstrainedVar(constraint)
    # Reification substitutes the constrained variable's bound value
    res = reify(etuple(op1_np, a_lv), {a_lv: r_at})
    assert res[1] == r_at
def test_convert_strs_to_vars():
    """Strings in patterns become logic variables; constraint dicts become
    `ConstrainedVar`s; raw numbers/arrays become `Constant`s."""
    res = convert_strs_to_vars("a")
    assert isinstance(res, Var)
    assert res.token == "a"

    x_at = at.vector()
    y_at = at.vector()

    # Nested tuples are converted recursively into etuples
    res = convert_strs_to_vars((("a", x_at), y_at))
    assert res == etuple(etuple(var("a"), x_at), y_at)

    def constraint(x):
        return isinstance(x, str)

    # A `{"pattern": ..., "constraint": ...}` dict becomes a `ConstrainedVar`
    res = convert_strs_to_vars(
        (({"pattern": "a", "constraint": constraint}, x_at), y_at)
    )
    assert res == etuple(etuple(ConstrainedVar(constraint, "a"), x_at), y_at)

    # Make sure constrained logic variables are the same across distinct uses
    # of their string names
    res = convert_strs_to_vars(({"pattern": "a", "constraint": constraint}, "a"))
    assert res[0] is res[1]

    # An explicit `var_map` pre-seeds the string-to-variable mapping
    var_map = {"a": var("a")}
    res = convert_strs_to_vars(("a",), var_map=var_map)
    assert res[0] is var_map["a"]

    # Make sure numbers and NumPy arrays are converted
    val = np.r_[1, 2]
    res = convert_strs_to_vars((val,))
    assert isinstance(res[0], Constant)
    assert np.array_equal(res[0].data, val)
...@@ -7,7 +7,7 @@ from aesara.graph.type import Type ...@@ -7,7 +7,7 @@ from aesara.graph.type import Type
def is_variable(x): def is_variable(x):
if not isinstance(x, Variable): if not isinstance(x, Variable):
raise TypeError("not a Variable", x) raise TypeError(f"not a Variable: {x}")
return x return x
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论