提交 459c570d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Replace unification framework with logical-unification, etuples, and cons

These changes only affect `PatternSub`, which now no longer allows constraints on non (logic) variable terms in its patterns. Likewise, the `values_eq_approx` and `skip_identities_fn` options are no longer supported.
上级 ac7986d7
......@@ -20,8 +20,6 @@ from functools import partial, reduce
from itertools import chain
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import aesara
from aesara.configdefaults import config
from aesara.graph import destroyhandler as dh
......@@ -32,6 +30,7 @@ from aesara.graph.basic import (
applys_between,
io_toposort,
nodes_constructed,
vars_between,
)
from aesara.graph.features import Feature, NodeFinder
from aesara.graph.fg import FunctionGraph, InconsistencyError
......@@ -1597,15 +1596,13 @@ class OpRemove(LocalOptimizer):
class PatternSub(LocalOptimizer):
"""
"""Replace all occurrences of an input pattern with an output pattern.
@todo update
Replaces all occurrences of the input pattern by the output pattern:
The input and output patterns have the following syntax:
input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
input_pattern ::= dict(pattern = <input_pattern>,
constraint = <constraint>)
constraint = <constraint>)
sub_pattern ::= input_pattern
sub_pattern ::= string
sub_pattern ::= a Constant instance
......@@ -1651,8 +1648,6 @@ class PatternSub(LocalOptimizer):
skip_identities_fn : TODO
name :
Allows to override this optimizer name.
pdb : bool
If True, we invoke pdb when the first node in the pattern matches.
tracks : optional
The values that :meth:`self.tracks` will return. Useful to speed up
optimization sometimes.
......@@ -1675,7 +1670,7 @@ class PatternSub(LocalOptimizer):
PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x'))
PatternSub((boggle, {'pattern': 'x',
'constraint': lambda expr: expr.type == scrabble}),
(scrabble, 'x'))
(scrabble, 'x'))
"""
......@@ -1686,13 +1681,15 @@ class PatternSub(LocalOptimizer):
allow_multiple_clients=False,
skip_identities_fn=None,
name=None,
pdb=False,
tracks=(),
get_nodes=None,
values_eq_approx=None,
):
self.in_pattern = in_pattern
self.out_pattern = out_pattern
from aesara.graph.unify import convert_strs_to_vars
var_map = {}
self.in_pattern = convert_strs_to_vars(in_pattern, var_map=var_map)
self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map)
self.values_eq_approx = values_eq_approx
if isinstance(in_pattern, (list, tuple)):
self.op = self.in_pattern[0]
......@@ -1709,7 +1706,6 @@ class PatternSub(LocalOptimizer):
self.skip_identities_fn = skip_identities_fn
if name:
self.__name__ = name
self.pdb = pdb
self._tracks = tracks
self.get_nodes = get_nodes
if tracks != ():
......@@ -1729,7 +1725,16 @@ class PatternSub(LocalOptimizer):
If it does, it constructs ``out_pattern`` and performs the replacement.
"""
from aesara.graph import unify
from etuples.core import ExpressionTuple
from unification import reify, unify
# TODO: We shouldn't need to iterate like this.
if not self.allow_multiple_clients and any(
len(fgraph.clients.get(v)) > 1
for v in vars_between(fgraph.inputs, node.outputs)
if v not in fgraph.inputs
):
return False
if get_nodes and self.get_nodes is not None:
for real_node in self.get_nodes(fgraph, node):
......@@ -1741,99 +1746,16 @@ class PatternSub(LocalOptimizer):
if node.op != self.op:
return False
# TODO: if we remove pdb, do this speed things up?
def match(pattern, expr, u, allow_multiple_clients=False, pdb=False):
# TODO move outside match
def retry_with_equiv():
if not self.skip_identities_fn:
return False
expr_equiv = self.skip_identities_fn(expr)
if expr_equiv is None:
return False
# TODO: Not sure how to handle multiple_clients flag
return match(
pattern,
expr_equiv,
u,
allow_multiple_clients=allow_multiple_clients,
)
if isinstance(pattern, (list, tuple)):
if expr.owner is None:
return False
if not (expr.owner.op == pattern[0]) or (
not allow_multiple_clients and len(fgraph.clients[expr]) > 1
):
return retry_with_equiv()
if len(pattern) - 1 != len(expr.owner.inputs):
return retry_with_equiv()
for p, v in zip(pattern[1:], expr.owner.inputs):
u = match(p, v, u, self.allow_multiple_clients)
if not u:
return False
elif isinstance(pattern, dict):
try:
real_pattern = pattern["pattern"]
except KeyError:
raise KeyError(
f"Malformed pattern: {pattern} (expected key 'pattern')"
)
constraint = pattern.get("constraint", lambda expr: True)
if constraint(expr):
return match(
real_pattern,
expr,
u,
pattern.get("allow_multiple_clients", allow_multiple_clients),
)
else:
return retry_with_equiv()
elif isinstance(pattern, str):
v = unify.Var(pattern)
if u[v] is not v and u[v] is not expr:
return retry_with_equiv()
else:
u = u.merge(expr, v)
elif isinstance(pattern, (int, float)) and isinstance(expr, Constant):
if np.all(aesara.tensor.constant(pattern).value == expr.value):
return u
else:
return retry_with_equiv()
elif (
isinstance(pattern, Constant)
and isinstance(expr, Constant)
and pattern.equals(expr)
):
return u
else:
return retry_with_equiv()
if pdb:
import pdb
pdb.set_trace()
return u
s = unify(self.in_pattern, node.out)
u = match(self.in_pattern, node.out, unify.Unification(), True, self.pdb)
if not u:
if s is False:
return False
def build(pattern, u):
if isinstance(pattern, (list, tuple)):
args = [build(p, u) for p in pattern[1:]]
return pattern[0](*args)
elif isinstance(pattern, str):
return u[unify.Var(pattern)]
elif isinstance(pattern, (int, float)):
return pattern
else:
return pattern.clone()
ret = build(self.out_pattern, u)
ret = reify(self.out_pattern, s)
if isinstance(ret, (int, float)):
# TODO: Should we convert these to constants explicitly?
return [ret]
if isinstance(ret, ExpressionTuple):
ret = ret.evaled_obj
if self.values_eq_approx:
ret.tag.values_eq_approx = self.values_eq_approx
......
......@@ -10,599 +10,284 @@ that satisfies the constraints. That's useful for pattern matching.
"""
from copy import copy
from functools import partial
from typing import Dict
from collections.abc import Mapping
from numbers import Number
from typing import Dict, Optional, Tuple, Union
import numpy as np
from cons.core import ConsError, _car, _cdr
from etuples import apply, etuple, etuplize
from etuples.core import ExpressionTuple
from unification.core import _unify, assoc
from unification.utils import transitive_get as walk
from unification.variable import Var, isvar, var
class Keyword:
def __init__(self, name, nonzero=True):
self.name = name
self.nonzero = nonzero
from aesara.graph.basic import Constant, Variable
from aesara.graph.op import Op
from aesara.graph.type import Type
def __nonzero__(self):
# Python 2.x
return self.__bool__()
def __bool__(self):
# Python 3.x
return self.nonzero
def __str__(self):
return f"<{self.name}>"
def __repr__(self):
return f"<{self.name}>"
ABORT = Keyword("ABORT", False)
RETRY = Keyword("RETRY", False)
FAILURE = Keyword("FAILURE", False)
simple_types = (int, str, float, bool, type(None), Keyword)
ANY_TYPE = Keyword("ANY_TYPE")
FALL_THROUGH = Keyword("FALL_THROUGH")
def comm_guard(type1, type2):
def wrap(f):
old_f = f.__globals__[f.__name__]
def new_f(arg1, arg2, *rest):
if (type1 is ANY_TYPE or isinstance(arg1, type1)) and (
type2 is ANY_TYPE or isinstance(arg2, type2)
):
pass
elif (type1 is ANY_TYPE or isinstance(arg2, type1)) and (
type2 is ANY_TYPE or isinstance(arg1, type2)
):
arg1, arg2 = arg2, arg1
else:
return old_f(arg1, arg2, *rest)
variable = f(arg1, arg2, *rest)
if variable is FALL_THROUGH:
return old_f(arg1, arg2, *rest)
else:
return variable
new_f.__name__ = f.__name__
def typename(type):
if isinstance(type, Keyword):
return str(type)
elif isinstance(type, (tuple, list)):
return "(" + ", ".join([x.__name__ for x in type]) + ")"
else:
return type.__name__
new_f.__doc__ = (
str(old_f.__doc__)
+ "\n"
+ ", ".join([typename(type) for type in (type1, type2)])
+ "\n"
+ str(f.__doc__ or "")
)
return new_f
return wrap
def type_guard(type1):
def wrap(f):
old_f = f.__globals__[f.__name__]
def new_f(arg1, *rest):
if type1 is ANY_TYPE or isinstance(arg1, type1):
variable = f(arg1, *rest)
if variable is FALL_THROUGH:
return old_f(arg1, *rest)
else:
return variable
else:
return old_f(arg1, *rest)
new_f.__name__ = f.__name__
def typename(type):
if isinstance(type, Keyword):
return str(type)
elif isinstance(type, (tuple, list)):
return "(" + ", ".join([x.__name__ for x in type]) + ")"
else:
return type.__name__
new_f.__doc__ = (
str(old_f.__doc__)
+ "\n"
+ ", ".join([typename(type) for type in (type1,)])
+ "\n"
+ str(f.__doc__ or "")
)
return new_f
return wrap
class Variable:
"""
Serves as a base class of variables for the purpose of unification.
"Unification" here basically means matching two patterns, see the
module-level docstring.
Behavior for unifying various types of variables should be added as
overloadings of the 'unify' function.
Notes
-----
There are two Variable classes in aesara and this is the more rarely used
one.
This class is used internally by the PatternSub optimization,
and possibly other subroutines that have to perform graph queries.
If that doesn't sound like what you're doing, the Variable class you
want is probably aesara.graph.basic.Variable.
"""
def __init__(self, name="?"):
self.name = name
def __str__(self):
return (
self.__class__.__name__
+ "("
+ ", ".join(
"{}={}".format(key, value) for key, value in self.__dict__.items()
)
+ ")"
)
def __repr__(self):
return str(self)
def eval_if_etuple(x):
if isinstance(x, ExpressionTuple):
return x.evaled_obj
return x
class FreeVariable(Variable):
"""
This Variable can take any value.
"""
class BoundVariable(Variable):
"""
This Variable is bound to a value accessible via the value field.
class ConstrainedVar(Var):
"""A logical variable with a constraint.
These will unify with other `Var`s regardless of the constraints.
"""
def __init__(self, name, value):
super().__init__(name=name)
self.value = value
__slots__ = ("constraint",)
def __new__(cls, constraint, token=None, prefix=""):
if token is None:
token = f"{prefix}_{Var._id}"
Var._id += 1
class OrVariable(Variable):
"""
This Variable could be any value from a finite list of values,
accessible via the options field.
"""
key = (token, constraint)
obj = cls._refs.get(key, None)
def __init__(self, name, options):
super().__init__(name=name)
self.options = options
if obj is None:
obj = object.__new__(cls)
obj.token = token
obj.constraint = constraint
cls._refs[key] = obj
return obj
class NotVariable(Variable):
"""
This Variable can take any value but a finite amount of forbidden
values, accessible via the not_options field.
"""
def __eq__(self, other):
if type(self) == type(other):
return self.token == other.token and self.constraint == other.constraint
return NotImplemented
def __init__(self, name, not_options):
super().__init__(name=name)
self.not_options = not_options
class VariableInList: # not a subclass of Variable
"""
This special kind of variable is matched against a list and unifies
an inner Variable to an OrVariable of the values in the list.
For example, if we unify VariableInList(FreeVariable('x')) to [1,2,3],
the 'x' variable is unified to an OrVariable('?', [1,2,3]).
"""
def __init__(self, variable):
self.variable = variable
def __hash__(self):
return hash((type(self), self.token, self.constraint))
def __str__(self):
return f"~{self.token} [{self.constraint}]"
_all: Dict = {}
def __repr__(self):
return f"ConstrainedVar({repr(self.constraint)}, {self.token})"
def var_lookup(vartype, name, *args, **kwargs):
sig = (vartype, name)
if sig in _all:
return _all[sig]
def car_Variable(x):
if x.owner:
return x.owner.op
else:
v = vartype(name, *args)
_all[sig] = v
return v
raise ConsError("Not a cons pair.")
Var = partial(var_lookup, FreeVariable)
V = Var
OrV = partial(var_lookup, OrVariable)
NV = partial(var_lookup, NotVariable)
_car.add((Variable,), car_Variable)
class Unification:
"""
This class represents a possible unification of a group of variables
with each other or with tangible values.
Parameters
----------
inplace : bool
If inplace is False, the merge method will return a new Unification
that is independent from the previous one (which allows backtracking).
"""
def __init__(self, inplace=False):
self.unif = {}
self.inplace = inplace
def merge(self, new_best, *vars):
"""
Links all the specified vars to a Variable that represents their
unification.
"""
if self.inplace:
U = self
else:
# Copy all the unification data.
U = Unification(self.inplace)
for var, (best, pool) in self.unif.items():
# The pool of a variable is the set of all the variables that
# are unified to it (all the variables that must have the same
# value). The best is the Variable that represents a set of
# values common to all the variables in the pool.
U.unif[var] = (best, pool)
# We create a new pool for our new set of unified variables, initially
# containing vars and new_best
new_pool = set(vars)
new_pool.add(new_best)
for var in copy(new_pool):
best, pool = U.unif.get(var, (var, set()))
# We now extend the new pool to contain the pools of all the variables.
new_pool.update(pool)
# All variables get the new pool.
for var in new_pool:
U.unif[var] = (new_best, new_pool)
return U
def __getitem__(self, v):
"""
For a variable v, returns a Variable that represents the tightest
set of possible values it can take.
"""
return self.unif.get(v, (v, None))[0]
def unify_walk(a, b, U):
"""
unify_walk(a, b, U) returns an Unification where a and b are unified,
given the unification that already exists in the Unification U. If the
unification fails, it returns False.
There are two ways to expand the functionality of unify_walk. The first way
is:
@comm_guard(type_of_a, type_of_b)
def unify_walk(a, b, U):
...
A function defined as such will be executed whenever the types of a and b
match the declaration. Note that comm_guard automatically guarantees that
your function is commutative: it will try to match the types of a, b or
b, a.
It is recommended to define unify_walk in that fashion for new types of
Variable because different types of Variable interact a lot with each other,
e.g. when unifying an OrVariable with a NotVariable, etc. You can return
the special marker FALL_THROUGH to indicate that you want to relay execution
to the next match of the type signature. The definitions of unify_walk are
tried in the reverse order of their declaration.
Another way is to override __unify_walk__ in an user-defined class.
Limitations: cannot embed a Variable in another (the functionality could
be added if required)
Here is a list of unification rules with their associated behavior:
"""
if a.__class__ != b.__class__:
return False
elif a == b:
return U
def cdr_Variable(x):
if x.owner:
x_e = etuple(_car(x), *x.owner.inputs, evaled_obj=x)
else:
return False
raise ConsError("Not a cons pair.")
return x_e[1:]
@comm_guard(FreeVariable, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_walk(fv, o, U):
"""
FreeV is unified to BoundVariable(other_object).
"""
v = BoundVariable("?", o)
return U.merge(v, fv)
@comm_guard(BoundVariable, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_walk(bv, o, U):
"""
The unification succeed iff BV.value == other_object.
"""
if bv.value == o:
return U
else:
return False
_cdr.add((Variable,), cdr_Variable)
@comm_guard(OrVariable, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_walk(ov, o, U):
"""
The unification succeeds iff other_object in OrV.options.
def car_Op(x):
if hasattr(x, "__props__"):
return type(x)
"""
if o in ov.options:
v = BoundVariable("?", o)
return U.merge(v, ov)
else:
return False
raise ConsError("Not a cons pair.")
@comm_guard(NotVariable, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_walk(nv, o, U):
"""
The unification succeeds iff other_object not in NV.not_options.
_car.add((Op,), car_Op)
"""
if o in nv.not_options:
return False
else:
v = BoundVariable("?", o)
return U.merge(v, nv)
def cdr_Op(x):
if not hasattr(x, "__props__"):
raise ConsError("Not a cons pair.")
@comm_guard(FreeVariable, Variable) # type: ignore[no-redef] # noqa
def unify_walk(fv, v, U):
"""
Both variables are unified.
x_e = etuple(
_car(x),
*[getattr(x, p) for p in getattr(x, "__props__", ())],
evaled_obj=x,
)
return x_e[1:]
"""
v = U[v]
return U.merge(v, fv)
_cdr.add((Op,), cdr_Op)
@comm_guard(BoundVariable, Variable) # type: ignore[no-redef] # noqa
def unify_walk(bv, v, U):
"""
V is unified to BV.value.
"""
return unify_walk(v, bv.value, U)
def car_Type(x):
return type(x)
@comm_guard(OrVariable, OrVariable) # type: ignore[no-redef] # noqa
def unify_walk(a, b, U):
"""
OrV(list1) == OrV(list2) == OrV(intersection(list1, list2))
_car.add((Type,), car_Type)
"""
opt = a.options.intersection(b.options)
if not opt:
return False
elif len(opt) == 1:
v = BoundVariable("?", opt[0])
else:
v = OrVariable("?", opt)
return U.merge(v, a, b)
def cdr_Type(x):
x_e = etuple(
_car(x), *[getattr(x, p) for p in getattr(x, "__props__", ())], evaled_obj=x
)
return x_e[1:]
@comm_guard(NotVariable, NotVariable) # type: ignore[no-redef] # noqa
def unify_walk(a, b, U):
"""
NV(list1) == NV(list2) == NV(union(list1, list2))
"""
opt = a.not_options.union(b.not_options)
v = NotVariable("?", opt)
return U.merge(v, a, b)
_cdr.add((Type,), cdr_Type)
@comm_guard(OrVariable, NotVariable) # type: ignore[no-redef] # noqa
def unify_walk(o, n, U):
r"""
OrV(list1) == NV(list2) == OrV(list1 \ list2)
def apply_Op_ExpressionTuple(op, etuple_arg):
res = op.make_node(*etuple_arg)
"""
opt = [x for x in o.options if x not in n.not_options]
if not opt:
return False
elif len(opt) == 1:
v = BoundVariable("?", opt[0])
else:
v = OrVariable("?", opt)
return U.merge(v, o, n)
try:
return res.default_output()
except ValueError:
return res.outputs
@comm_guard(VariableInList, (list, tuple)) # type: ignore[no-redef] # noqa
def unify_walk(vil, l, U):
"""
Unifies VIL's inner Variable to OrV(list).
apply.add((Op, ExpressionTuple), apply_Op_ExpressionTuple)
"""
v = vil.variable
ov = OrVariable("?", l)
return unify_walk(v, ov, U)
def _unify_etuplize_first_arg(u, v, s):
try:
u_et = etuplize(u, shallow=True)
yield _unify(u_et, v, s)
except TypeError:
yield False
return
@comm_guard((list, tuple), (list, tuple)) # type: ignore[no-redef] # noqa
def unify_walk(l1, l2, U):
"""
Tries to unify each corresponding pair of elements from l1 and l2.
"""
if len(l1) != len(l2):
return False
for x1, x2 in zip(l1, l2):
U = unify_walk(x1, x2, U)
if U is False:
return False
return U
_unify.add((Op, ExpressionTuple, Mapping), _unify_etuplize_first_arg)
_unify.add(
(ExpressionTuple, Op, Mapping), lambda u, v, s: _unify_etuplize_first_arg(v, u, s)
)
_unify.add((Type, ExpressionTuple, Mapping), _unify_etuplize_first_arg)
_unify.add(
(ExpressionTuple, Type, Mapping), lambda u, v, s: _unify_etuplize_first_arg(v, u, s)
)
@comm_guard(dict, dict) # type: ignore[no-redef] # noqa
def unify_walk(d1, d2, U):
"""
Tries to unify values of corresponding keys.
"""
for (k1, v1) in d1.items():
if k1 in d2:
U = unify_walk(v1, d2[k1], U)
if U is False:
return False
return U
def _unify_Variable_Variable(u, v, s):
# Avoid converting to `etuple`s, when possible
if u == v:
yield s
return
if not u.owner and not v.owner:
yield False
return
@comm_guard(ANY_TYPE, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_walk(a, b, U):
"""
Checks for the existence of the __unify_walk__ method for one of
the objects.
yield _unify(
etuplize(u, shallow=True) if u.owner else u,
etuplize(v, shallow=True) if v.owner else v,
s,
)
"""
if (
not isinstance(a, Variable)
and not isinstance(b, Variable)
and hasattr(a, "__unify_walk__")
):
return a.__unify_walk__(b, U)
else:
return FALL_THROUGH
_unify.add((Variable, Variable, Mapping), _unify_Variable_Variable)
@comm_guard(Variable, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_walk(v, o, U):
"""
This simply checks if the Var has an unification in U and uses it
instead of the Var. If the Var is already its tighest unification,
falls through.
"""
best_v = U[v]
if v is not best_v:
return unify_walk(
o, best_v, U
) # reverse argument order so if o is a Variable this block of code is run again
def _unify_Constant_Constant(u, v, s):
# XXX: This ignores shape and type differences. It's only implemented this
# way for backward compatibility
if np.array_equiv(u.data, v.data):
yield s
else:
return FALL_THROUGH # call the next version of unify_walk that matches the type signature
class FVar:
def __init__(self, fn, *args):
self.fn = fn
self.args = args
def __call__(self, u):
return self.fn(*[unify_build(arg, u) for arg in self.args])
def unify_merge(a, b, U):
return a
@comm_guard(Variable, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_merge(v, o, U):
return v
yield False
@comm_guard(BoundVariable, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_merge(bv, o, U):
return bv.value
_unify.add((Constant, Constant, Mapping), _unify_Constant_Constant)
@comm_guard(VariableInList, (list, tuple)) # type: ignore[no-redef] # noqa
def unify_merge(vil, l, U):
return [unify_merge(x, x, U) for x in l]
def _unify_Variable_ExpressionTuple(u, v, s):
# `Constant`s are "atomic"
if not u.owner:
yield False
return
@comm_guard((list, tuple), (list, tuple)) # type: ignore[no-redef] # noqa
def unify_merge(l1, l2, U):
return [unify_merge(x1, x2, U) for x1, x2 in zip(l1, l2)]
yield _unify(etuplize(u, shallow=True), v, s)
@comm_guard(dict, dict) # type: ignore[no-redef] # noqa
def unify_merge(d1, d2, U):
d = d1.__class__()
for k1, v1 in d1.items():
if k1 in d2:
d[k1] = unify_merge(v1, d2[k1], U)
else:
d[k1] = unify_merge(v1, v1, U)
for k2, v2 in d2.items():
if k2 not in d1:
d[k2] = unify_merge(v2, v2, U)
return d
_unify.add(
(Variable, ExpressionTuple, Mapping),
_unify_Variable_ExpressionTuple,
)
_unify.add(
(ExpressionTuple, Variable, Mapping),
lambda u, v, s: _unify_Variable_ExpressionTuple(v, u, s),
)
@comm_guard(FVar, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_merge(vs, o, U):
return vs(U)
@_unify.register(ConstrainedVar, (ConstrainedVar, Var, object), Mapping)
def _unify_ConstrainedVar_object(u, v, s):
u_w = walk(u, s)
@comm_guard(ANY_TYPE, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_merge(a, b, U):
if (
not isinstance(a, Variable)
and not isinstance(b, Variable)
and hasattr(a, "__unify_merge__")
):
return a.__unify_merge__(b, U)
if isvar(v):
v_w = walk(v, s)
else:
return FALL_THROUGH
@comm_guard(Variable, ANY_TYPE) # type: ignore[no-redef] # noqa
def unify_merge(v, o, U):
"""
This simply checks if the Var has an unification in U and uses it
instead of the Var. If the Var is already its tighest unification,
falls through.
"""
best_v = U[v]
if v is not best_v:
return unify_merge(
o, best_v, U
) # reverse argument order so if o is a Variable this block of code is run again
else:
return FALL_THROUGH # call the next version of unify_walk that matches the type signature
def unify_build(x, U):
return unify_merge(x, x, U)
def unify(a, b):
U = unify_walk(a, b, Unification())
if not U:
return None, False
v_w = v
if u_w == v_w:
yield s
elif isvar(u_w):
if (
not isvar(v_w)
and isinstance(u_w, ConstrainedVar)
and not u_w.constraint(eval_if_etuple(v_w))
):
yield False
return
yield assoc(s, u_w, v_w)
elif isvar(v_w):
if (
not isvar(u_w)
and isinstance(v_w, ConstrainedVar)
and not v_w.constraint(eval_if_etuple(u_w))
):
yield False
return
yield assoc(s, v_w, u_w)
else:
return unify_merge(a, b, U), U
yield _unify(u_w, v_w, s)
_unify.add((object, ConstrainedVar, Mapping), _unify_ConstrainedVar_object)
def convert_strs_to_vars(
x: Union[Tuple, str, Dict], var_map: Optional[Dict[str, Var]] = None
) -> Union[ExpressionTuple, Var]:
r"""Convert tuples and strings to `etuple`\s and logic variables, respectively.
Constrained logic variables are specified via `dict`s with the keys
`"pattern"`, which specifies the logic variable as a string, and
`"constraint"`, which provides the `Callable` constraint.
"""
if var_map is None:
var_map = {}
def _convert(y):
if isinstance(y, str):
v = var_map.get(y, var(y))
var_map[y] = v
return v
elif isinstance(y, dict):
pattern = y["pattern"]
if not isinstance(pattern, str):
raise TypeError(
"Constraints can only be assigned to logic variables (i.e. strings)"
)
constraint = y["constraint"]
v = var_map.get(pattern, ConstrainedVar(constraint, pattern))
var_map[pattern] = v
return v
elif isinstance(y, tuple):
return etuple(*tuple(_convert(e) for e in y))
elif isinstance(y, (Number, np.ndarray)):
from aesara.tensor import as_tensor_variable
return as_tensor_variable(y)
return y
return _convert(x)
......@@ -45,7 +45,15 @@ Programming Language :: Python :: 3.9
"""
CLASSIFIERS = [_f for _f in CLASSIFIERS.split("\n") if _f]
install_requires = ["numpy>=1.17.0", "scipy>=0.14", "filelock"]
install_requires = [
"numpy>=1.17.0",
"scipy>=0.14",
"filelock",
"etuples",
"logical-unification",
"cons",
]
if sys.version_info[0:2] < (3, 7):
install_requires += ["dataclasses"]
......
......@@ -162,7 +162,7 @@ class TestPatternOptimizer:
TopoPatternOptimizer((op1, (op1, "1")), (op1, "1"), ign=False).optimize(g)
assert str(g) == "FunctionGraph(Op1(x))"
def test_constant_unification(self):
def test_constant(self):
x = Constant(MyType(), 2, name="x")
y = MyVariable("y")
z = Constant(MyType(), 2, name="z")
......@@ -192,6 +192,9 @@ class TestPatternOptimizer:
PatternOptimizer((op1, "x", "y"), (op3, "x", "y")).optimize(g)
assert str(g) == "FunctionGraph(Op3(x, x))"
@pytest.mark.xfail(
reason="This pattern & constraint case isn't used and doesn't make much sense."
)
def test_match_same_illegal(self):
x, y, z = inputs()
e = op2(op1(x, x), op1(x, y))
......@@ -206,9 +209,10 @@ class TestPatternOptimizer:
).optimize(g)
assert str(g) == "FunctionGraph(Op2(Op1(x, x), Op3(x, y)))"
def test_multi(self):
def test_allow_multiple_clients(self):
x, y, z = inputs()
e0 = op1(x, y)
# `e0` has multiple clients (i.e. the `op4` and `op3` nodes)
e = op3(op4(e0), e0)
g = FunctionGraph([x, y, z], [e])
PatternOptimizer((op4, (op1, "x", "y")), (op3, "x", "y")).optimize(g)
......@@ -224,17 +228,6 @@ class TestPatternOptimizer:
assert str_g == "FunctionGraph(Op4(z, y))"
# def test_multi_ingraph(self):
# # known to fail
# x, y, z = inputs()
# e0 = op1(x, y)
# e = op4(e0, e0)
# g = FunctionGraph([x, y, z], [e])
# PatternOptimizer((op4, (op1, 'x', 'y'), (op1, 'x', 'y')),
# (op3, 'x', 'y')).optimize(g)
# assert str(g) == "FunctionGraph(Op3(x, y))"
def OpSubOptimizer(op1, op2):
return OpKeyOptimizer(OpSub(op1, op2))
......@@ -454,6 +447,8 @@ class TestMergeOptimizer:
class TestEquilibrium:
def test_1(self):
x, y, z = map(MyVariable, "xyz")
# TODO FIXME: These `Op`s don't have matching/consistent `__prop__`s
# and `__init__`s, so they can't be `etuplized` correctly
e = op3(op4(x, y))
g = FunctionGraph([x, y, z], [e])
# print g
......@@ -632,8 +627,9 @@ def test_patternsub_values_eq_approx(out_pattern, tracks):
assert output is x
assert output.tag.values_eq_approx is values_eq_approx_always_true
else:
assert isinstance(output, Constant)
assert not hasattr(output.tag, "value_eq_approx")
# The replacement types do not match, so the substitution should've
# failed
assert output is e
@pytest.mark.parametrize("out_pattern", [(op1, "x"), "x"])
......
import numpy as np
import pytest
from cons import car, cdr
from cons.core import ConsError
from etuples import apply, etuple, etuplize
from etuples.core import ExpressionTuple
from unification import reify, unify, var
from unification.variable import Var
import aesara.scalar as aes
import aesara.tensor as at
from aesara.graph.basic import Apply, Constant, equal_computations
from aesara.graph.op import Op
from aesara.graph.unify import ConstrainedVar, convert_strs_to_vars
from aesara.tensor.type import TensorType
from tests.graph.utils import MyType
class CustomOp(Op):
__props__ = ("a",)
def __init__(self, a):
self.a = a
def make_node(self, *inputs):
return Apply(self, list(inputs), [at.vector()])
def perform(self, node, inputs, outputs):
raise NotImplementedError()
class CustomOpNoPropsNoEq(Op):
def __init__(self, a):
self.a = a
def make_node(self, *inputs):
return Apply(self, list(inputs), [at.vector()])
def perform(self, node, inputs, outputs):
raise NotImplementedError()
class CustomOpNoProps(CustomOpNoPropsNoEq):
def __eq__(self, other):
return type(self) == type(other) and self.a == other.a
def __hash__(self):
return hash((type(self), self.a))
def test_cons():
x_at = at.vector("x")
y_at = at.vector("y")
z_at = x_at + y_at
res = car(z_at)
assert res == z_at.owner.op
res = cdr(z_at)
assert res == [x_at, y_at]
with pytest.raises(ConsError):
car(x_at)
with pytest.raises(ConsError):
cdr(x_at)
op1 = CustomOp(1)
assert car(op1) == CustomOp
assert cdr(op1) == (1,)
tt1 = TensorType("float32", [True, False])
assert car(tt1) == TensorType
assert cdr(tt1) == ("float32", (True, False))
op1_np = CustomOpNoProps(1)
with pytest.raises(ConsError):
car(op1_np)
with pytest.raises(ConsError):
cdr(op1_np)
atype_at = aes.float64
car_res = car(atype_at)
cdr_res = cdr(atype_at)
assert car_res is type(atype_at)
assert cdr_res == [atype_at.dtype]
atype_at = at.lvector
car_res = car(atype_at)
cdr_res = cdr(atype_at)
assert car_res is type(atype_at)
assert cdr_res == [atype_at.dtype, atype_at.broadcastable]
def test_etuples():
x_at = at.vector("x")
y_at = at.vector("y")
z_at = etuple(x_at, y_at)
res = apply(at.add, z_at)
assert res.owner.op == at.add
assert res.owner.inputs == [x_at, y_at]
w_at = etuple(at.add, x_at, y_at)
res = w_at.evaled_obj
assert res.owner.op == at.add
assert res.owner.inputs == [x_at, y_at]
# This `Op` doesn't expand into an `etuple` (i.e. it's "atomic")
op1_np = CustomOpNoProps(1)
res = apply(op1_np, z_at)
assert res.owner.op == op1_np
q_at = op1_np(x_at, y_at)
res = etuplize(q_at)
assert res[0] == op1_np
with pytest.raises(TypeError):
etuplize(op1_np)
class MyMultiOutOp(Op):
def make_node(self, *inputs):
outputs = [MyType()(), MyType()()]
return Apply(self, list(inputs), outputs)
def perform(self, node, inputs, outputs):
outputs[0] = np.array(inputs[0])
outputs[1] = np.array(inputs[0])
x_at = at.vector("x")
op1_np = MyMultiOutOp()
res = apply(op1_np, etuple(x_at))
assert len(res) == 2
assert res[0].owner.op == op1_np
assert res[1].owner.op == op1_np
def test_unify_Variable():
x_at = at.vector("x")
y_at = at.vector("y")
z_at = x_at + y_at
# `Variable`, `Variable`
s = unify(z_at, z_at)
assert s == {}
# These `Variable`s have no owners
v1 = MyType()()
v2 = MyType()()
assert v1 != v2
s = unify(v1, v2)
assert s is False
op_lv = var()
z_pat_et = etuple(op_lv, x_at, y_at)
# `Variable`, `ExpressionTuple`
s = unify(z_at, z_pat_et, {})
assert op_lv in s
assert s[op_lv] == z_at.owner.op
res = reify(z_pat_et, s)
assert isinstance(res, ExpressionTuple)
assert equal_computations([res.evaled_obj], [z_at])
z_et = etuple(at.add, x_at, y_at)
# `ExpressionTuple`, `ExpressionTuple`
s = unify(z_et, z_pat_et, {})
assert op_lv in s
assert s[op_lv] == z_et[0]
res = reify(z_pat_et, s)
assert isinstance(res, ExpressionTuple)
assert equal_computations([res.evaled_obj], [z_et.evaled_obj])
# `ExpressionTuple`, `Variable`
s = unify(z_et, x_at, {})
assert s is False
# This `Op` doesn't expand into an `ExpressionTuple`
op1_np = CustomOpNoProps(1)
q_at = op1_np(x_at, y_at)
a_lv = var()
b_lv = var()
# `Variable`, `ExpressionTuple`
s = unify(q_at, etuple(op1_np, a_lv, b_lv))
assert s[a_lv] == x_at
assert s[b_lv] == y_at
def test_unify_Op():
# These `Op`s expand into `ExpressionTuple`s
op1 = CustomOp(1)
op2 = CustomOp(1)
# `Op`, `Op`
s = unify(op1, op2)
assert s == {}
# `ExpressionTuple`, `Op`
s = unify(etuplize(op1), op2)
assert s == {}
# These `Op`s don't expand into `ExpressionTuple`s
op1_np = CustomOpNoProps(1)
op2_np = CustomOpNoProps(1)
s = unify(op1_np, op2_np)
assert s == {}
# Same, but this one also doesn't implement `__eq__`
op1_np_neq = CustomOpNoPropsNoEq(1)
s = unify(op1_np_neq, etuplize(op1))
assert s is False
def test_unify_Constant():
# Make sure `Constant` unification works
c1_at = at.as_tensor(np.r_[1, 2])
c2_at = at.as_tensor(np.r_[1, 2])
# `Constant`, `Constant`
s = unify(c1_at, c2_at)
assert s == {}
def test_unify_Type():
t1 = TensorType(np.float64, (True, False))
t2 = TensorType(np.float64, (True, False))
# `Type`, `Type`
s = unify(t1, t2)
assert s == {}
# `Type`, `ExpressionTuple`
s = unify(t1, etuple(TensorType, "float64", (True, False)))
assert s == {}
from aesara.scalar.basic import Scalar
st1 = Scalar(np.float64)
st2 = Scalar(np.float64)
s = unify(st1, st2)
assert s == {}
def test_ConstrainedVar():
cvar = ConstrainedVar(lambda x: isinstance(x, str))
assert repr(cvar).startswith("ConstrainedVar(")
assert repr(cvar).endswith(f", {cvar.token})")
s = unify(cvar, 1)
assert s is False
s = unify(1, cvar)
assert s is False
s = unify(cvar, "hi")
assert s[cvar] == "hi"
s = unify("hi", cvar)
assert s[cvar] == "hi"
x_lv = var()
s = unify(cvar, x_lv)
assert s == {cvar: x_lv}
s = unify(cvar, x_lv, {x_lv: "hi"})
assert s[cvar] == "hi"
s = unify(x_lv, cvar, {x_lv: "hi"})
assert s[cvar] == "hi"
s_orig = {cvar: "hi", x_lv: "hi"}
s = unify(x_lv, cvar, s_orig)
assert s == s_orig
s_orig = {cvar: "hi", x_lv: "bye"}
s = unify(x_lv, cvar, s_orig)
assert s is False
x_at = at.vector("x")
y_at = at.vector("y")
op1_np = CustomOpNoProps(1)
r_at = etuple(op1_np, x_at, y_at)
def constraint(x):
return isinstance(x, tuple)
a_lv = ConstrainedVar(constraint)
res = reify(etuple(op1_np, a_lv), {a_lv: r_at})
assert res[1] == r_at
def test_convert_strs_to_vars():
res = convert_strs_to_vars("a")
assert isinstance(res, Var)
assert res.token == "a"
x_at = at.vector()
y_at = at.vector()
res = convert_strs_to_vars((("a", x_at), y_at))
assert res == etuple(etuple(var("a"), x_at), y_at)
def constraint(x):
return isinstance(x, str)
res = convert_strs_to_vars(
(({"pattern": "a", "constraint": constraint}, x_at), y_at)
)
assert res == etuple(etuple(ConstrainedVar(constraint, "a"), x_at), y_at)
# Make sure constrained logic variables are the same across distinct uses
# of their string names
res = convert_strs_to_vars(({"pattern": "a", "constraint": constraint}, "a"))
assert res[0] is res[1]
var_map = {"a": var("a")}
res = convert_strs_to_vars(("a",), var_map=var_map)
assert res[0] is var_map["a"]
# Make sure numbers and NumPy arrays are converted
val = np.r_[1, 2]
res = convert_strs_to_vars((val,))
assert isinstance(res[0], Constant)
assert np.array_equal(res[0].data, val)
......@@ -7,7 +7,7 @@ from aesara.graph.type import Type
def is_variable(x):
if not isinstance(x, Variable):
raise TypeError("not a Variable", x)
raise TypeError(f"not a Variable: {x}")
return x
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论