提交 980ecacf authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add AtomicVariable and NominalVariable classes

上级 96366f99
"""Core graph classes."""
import abc
import warnings
from collections import deque
from copy import copy
......@@ -10,6 +11,7 @@ from typing import (
Deque,
Dict,
Generator,
Hashable,
Iterable,
Iterator,
List,
......@@ -396,6 +398,14 @@ class Variable(Node):
def owner(self, value) -> None:
self._owner = value
@property
def index(self):
return self._index
@index.setter
def index(self, value):
self._index = value
def __init__(
self,
type,
......@@ -411,7 +421,7 @@ class Variable(Node):
if owner is not None and not isinstance(owner, Apply):
raise TypeError("owner must be an Apply instance", owner)
self._owner = owner
self.owner = owner
if index is not None and not isinstance(index, int):
raise TypeError("index must be an int", index)
......@@ -586,7 +596,98 @@ class Variable(Node):
return d
class Constant(Variable):
class AtomicVariable(Variable):
"""A node type that has no ancestors and should never be considered an input to a graph."""
def __init__(self, type, **kwargs):
super().__init__(type, **kwargs)
@abc.abstractmethod
def signature(self):
...
def merge_signature(self):
return self.signature()
def equals(self, other):
"""
This does what `__eq__` would normally do, but `Variable` and `Apply`
should always be hashable by `id`.
"""
return isinstance(other, type(self)) and self.signature() == other.signature()
@property
def owner(self):
return None
@owner.setter
def owner(self, value):
if value is not None:
raise ValueError("AtomicVariable instances cannot have an owner.")
@property
def index(self):
return None
@index.setter
def index(self, value):
if value is not None:
raise ValueError("AtomicVariable instances cannot have an index.")
class NominalVariable(AtomicVariable):
"""A variable that enables alpha-equivalent comparisons."""
__instances__: Dict[Hashable, type] = {}
def __new__(cls, id, typ, **kwargs):
if (id, typ) not in cls.__instances__:
var_type = typ.variable_type
type_name = f"Nominal{var_type.__name__}"
def _reduce(self):
return cls, (self.id, self.type)
def _str(self):
return f"*{self.id}-{var_type.__str__(self)}"
new_type = type(
type_name, (cls, var_type), {"__reduce__": _reduce, "__str__": _str}
)
res = super().__new__(new_type)
cls.__instances__[(id, typ)] = res
return cls.__instances__[(id, typ)]
def __init__(self, id, typ, **kwargs):
self.id = id
super().__init__(typ, **kwargs)
def clone(self):
return self
def __eq__(self, other):
if self is other:
return True
return (
type(self) == type(other)
and self.id == other.id
and self.type == other.type
)
def __hash__(self):
return hash((type(self), self.id, self.type))
def __repr__(self):
return f"{type(self).__name__}({repr(self.id)}, {repr(self.type)})"
def signature(self):
return (self.type, self.id)
class Constant(AtomicVariable):
"""A `Variable` with a fixed `data` field.
`Constant` nodes make numerous optimizations possible (e.g. constant
......@@ -602,23 +703,16 @@ class Constant(Variable):
# __slots__ = ['data']
def __init__(self, type, data, name=None):
super().__init__(type, None, None, name)
super().__init__(type, name=name)
self.data = type.filter(data)
add_tag_trace(self)
def get_test_value(self):
return self.data
def equals(self, other):
# this does what __eq__ should do, but Variable and Apply should always be hashable by id
return isinstance(other, Constant) and self.signature() == other.signature()
def signature(self):
return (self.type, self.data)
def merge_signature(self):
return self.signature()
def __str__(self):
if self.name is not None:
return self.name
......@@ -641,9 +735,9 @@ class Constant(Variable):
if value is not None:
raise ValueError("Constant instances cannot have an owner.")
value = property(lambda self: self.data, doc="read-only data access method")
# index is not defined, because the `owner` attribute must necessarily be None
@property
def value(self):
return self.data
def walk(
......
......@@ -18,7 +18,7 @@ from typing_extensions import Literal
import aesara
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, Node, Variable, applys_between
from aesara.graph.basic import Apply, AtomicVariable, Node, Variable, applys_between
from aesara.graph.basic import as_string as graph_as_string
from aesara.graph.basic import clone_get_equiv, graph_inputs, io_toposort, vars_between
from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate
......@@ -101,7 +101,9 @@ class FunctionGraph(MetaObject):
raise ValueError("No outputs specified")
if inputs is None:
inputs = [i for i in graph_inputs(outputs) if not isinstance(i, Constant)]
inputs = [
i for i in graph_inputs(outputs) if not isinstance(i, AtomicVariable)
]
if clone:
_memo = clone_get_equiv(
......@@ -306,7 +308,7 @@ class FunctionGraph(MetaObject):
self.import_node(var.owner, reason=reason, import_missing=import_missing)
elif (
var.owner is None
and not isinstance(var, Constant)
and not isinstance(var, AtomicVariable)
and var not in self.inputs
):
from aesara.graph.null_type import NullType
......@@ -354,7 +356,7 @@ class FunctionGraph(MetaObject):
for var in node.inputs:
if (
var.owner is None
and not isinstance(var, Constant)
and not isinstance(var, AtomicVariable)
and var not in self.inputs
):
if import_missing:
......@@ -515,9 +517,6 @@ class FunctionGraph(MetaObject):
)
for node, i in list(self.clients[var]):
assert (node == "output" and self.outputs[i] is var) or (
isinstance(node, Apply) and node.inputs[i] is var
)
self.change_node_input(
node, i, new_var, reason=reason, import_missing=import_missing
)
......@@ -839,7 +838,7 @@ class FunctionGraph(MetaObject):
if (
variable.owner is None
and variable not in self.inputs
and not isinstance(variable, Constant)
and not isinstance(variable, AtomicVariable)
):
raise Exception(f"Undeclared input: {variable}")
for cl_node, i in self.clients[variable]:
......
......@@ -19,13 +19,12 @@ from functools import _compose_mro, partial, reduce # type: ignore
from itertools import chain
from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing_extensions import TypeAlias
import aesara
from aesara.configdefaults import config
from aesara.graph import destroyhandler as dh
from aesara.graph.basic import (
Apply,
AtomicVariable,
Constant,
Variable,
applys_between,
......@@ -500,12 +499,9 @@ class MergeFeature(Feature):
assert not hasattr(fgraph, "merge_feature")
fgraph.merge_feature = self
# For constants
self.seen_constants = set()
# variable -> signature (for constants)
self.const_sig = AssocList()
# signature -> variable (for constants)
self.const_sig_inv = AssocList()
self.seen_atomics = set()
self.atomic_sig = AssocList()
self.atomic_sig_inv = AssocList()
# For all Apply nodes
# Set of distinct (not mergeable) nodes
......@@ -539,17 +535,13 @@ class MergeFeature(Feature):
self.nodes_seen.discard(node)
self.process_node(fgraph, node)
# Since we are in on_change_input, node should have inputs.
if not isinstance(node, str):
assert node.inputs
if isinstance(new_r, Constant):
self.process_constant(fgraph, new_r)
if isinstance(new_r, AtomicVariable):
self.process_atomic(fgraph, new_r)
def on_import(self, fgraph, node, reason):
for c in node.inputs:
if isinstance(c, Constant):
self.process_constant(fgraph, c)
if isinstance(c, AtomicVariable):
self.process_atomic(fgraph, c)
self.process_node(fgraph, node)
......@@ -558,19 +550,19 @@ class MergeFeature(Feature):
if not node.inputs:
self.noinput_nodes.discard(node)
for c in node.inputs:
if isinstance(c, Constant) and (len(fgraph.clients[c]) <= 1):
if isinstance(c, AtomicVariable) and len(fgraph.clients[c]) <= 1:
# This was the last node using this constant
sig = self.const_sig[c]
self.const_sig.discard(c)
self.const_sig_inv.discard(sig)
self.seen_constants.discard(id(c))
def process_constant(self, fgraph, c):
"""Check if a constant `c` can be merged, and queue that replacement."""
if id(c) in self.seen_constants:
sig = self.atomic_sig[c]
self.atomic_sig.discard(c)
self.atomic_sig_inv.discard(sig)
self.seen_atomics.discard(id(c))
def process_atomic(self, fgraph, c):
"""Check if an atomic `c` can be merged, and queue that replacement."""
if id(c) in self.seen_atomics:
return
sig = c.merge_signature()
other_c = self.const_sig_inv.get(sig, None)
other_c = self.atomic_sig_inv.get(sig, None)
if other_c is not None:
# multiple names will clobber each other..
# we adopt convention to keep the last name
......@@ -579,9 +571,9 @@ class MergeFeature(Feature):
self.scheduled.append([[(c, other_c, "merge")]])
else:
# this is a new constant
self.const_sig[c] = sig
self.const_sig_inv[sig] = c
self.seen_constants.add(id(c))
self.atomic_sig[c] = sig
self.atomic_sig_inv[sig] = c
self.seen_atomics.add(id(c))
def process_node(self, fgraph, node):
r"""Check if a `node` can be merged, and queue that replacement.
......@@ -602,6 +594,7 @@ class MergeFeature(Feature):
# using `node.inputs[0]` will make us look at more nodes on
# average, so by picking the smallest clients list, we might speed
# things up?
clients = sorted(
(fgraph.clients[inp] for inp in node.inputs), key=lambda x: len(x)
)[0]
......@@ -616,6 +609,7 @@ class MergeFeature(Feature):
replacement_candidates = []
for candidate in merge_candidates:
if candidate is node:
continue
if len(node.inputs) != len(candidate.inputs):
......@@ -658,9 +652,10 @@ class MergeOptimizer(GlobalOptimizer):
one are transferred to the other and one of them is removed from the graph.
This procedure is carried out in input-to-output order throughout the graph.
The first step of merging is constant-merging, so that all clients of an
``int(1)`` for example, are transferred to just one particular instance of
``int(1)``.
The first step of merging is atomic variable-merging, so that all clients of a
:class:`Constant` like ``int(1)``, are transferred to just one particular
instance of ``int(1)``. :class:`NominalVariable`\s are not merged individually
like this; only the nodes that use them are.
"""
......@@ -678,7 +673,7 @@ class MergeOptimizer(GlobalOptimizer):
callbacks_before = fgraph.execute_callbacks_times.copy()
nb_merged = 0
nb_constant = 0
nb_atomic = 0
while sched:
pairs_list = sched.pop()
success = True
......@@ -739,8 +734,8 @@ class MergeOptimizer(GlobalOptimizer):
pairs = [(pairs[0][1], pairs[0][0])]
try:
# If they're all `Constant`s, there's no need to call validate.
if all(isinstance(old, Constant) for old, _ in pairs):
# If they're all `AtomicVariable`s, there's no need to call validate.
if all(isinstance(old, AtomicVariable) for old, _ in pairs):
fgraph.replace_all(pairs, reason="MergeOptimizer")
else:
fgraph.replace_all_validate(pairs, reason="MergeOptimizer")
......@@ -753,8 +748,8 @@ class MergeOptimizer(GlobalOptimizer):
if success:
nb_merged += len(pairs)
if isinstance(pairs[0][0], Constant):
nb_constant += 1
if isinstance(pairs[0][0], AtomicVariable):
nb_atomic += 1
break
if fgraph.profile:
......@@ -782,7 +777,7 @@ class MergeOptimizer(GlobalOptimizer):
callback_time,
callbacks_time,
nb_merged,
nb_constant,
nb_atomic,
)
def __str__(self):
......@@ -798,14 +793,14 @@ class MergeOptimizer(GlobalOptimizer):
callback_time,
callbacks_time,
nb_merged,
nb_constant,
nb_atomic,
) = prof
blanc = " " * level
print(blanc, "MergeOptimizer", file=stream)
print(
blanc,
f" nb fail={nb_fail:5d} merged={nb_merged:5d} constant={nb_constant:5d}",
f" nb fail={nb_fail:5d} merged={nb_merged:5d} atomic={nb_atomic:5d}",
file=stream,
)
print(
......@@ -836,7 +831,7 @@ class MergeOptimizer(GlobalOptimizer):
callback_time = merge_none_number(prof1[3], prof2[3])
callbacks_time = merge_dict(prof1[4], prof2[4])
nb_merged = prof1[5] + prof2[5]
nb_constant = prof1[6] + prof2[6]
nb_atomic = prof1[6] + prof2[6]
return (
nb_fail,
replace_time,
......@@ -844,7 +839,7 @@ class MergeOptimizer(GlobalOptimizer):
callback_time,
callbacks_time,
nb_merged,
nb_constant,
nb_atomic,
)
......@@ -1127,7 +1122,7 @@ class LocalOptTracker:
def __init__(self):
self.tracked_instances: Dict[Op, List[LocalOptimizer]] = {}
self.tracked_types: Dict[TypeAlias, List[LocalOptimizer]] = {}
self.tracked_types: Dict[type, List[LocalOptimizer]] = {}
self.untracked_opts: List[LocalOptimizer] = []
def add_tracker(self, rw: LocalOptimizer):
......
......@@ -14,7 +14,13 @@ import numpy as np
from aesara.compile.compilelock import lock_ctx
from aesara.configdefaults import config
from aesara.graph.basic import Constant, NoParams, io_toposort, vars_between
from aesara.graph.basic import (
AtomicVariable,
Constant,
NoParams,
io_toposort,
vars_between,
)
from aesara.graph.callcache import CallCache
from aesara.link.basic import Container, Linker, LocalLinker, PerformLinker
from aesara.link.c.cmodule import (
......@@ -641,7 +647,7 @@ class CLinker(Linker):
self.orphans = list(
r
for r in self.variables
if isinstance(r, Constant) and r not in self.inputs
if isinstance(r, AtomicVariable) and r not in self.inputs
)
# C type constants (aesara.scalar.ScalarType). They don't request an object
self.consts = []
......@@ -730,7 +736,7 @@ class CLinker(Linker):
[get_c_declare, get_c_extract, get_c_cleanup],
]
elif variable in self.orphans:
if not isinstance(variable, Constant):
if not isinstance(variable, AtomicVariable):
raise TypeError(
"All orphans to CLinker must be Constant instances. "
f"Got {variable}"
......@@ -1404,7 +1410,7 @@ class CLinker(Linker):
# It is important that a variable (i)
# yield a 'position' that reflects its role in code_gen()
if isinstance(i, Constant): # orphans
if isinstance(i, AtomicVariable): # orphans
if id(i) not in constant_ids:
isig = (i.signature(), topological_pos, i_idx)
# If the Aesara constant provides a strong hash
......@@ -1634,7 +1640,10 @@ class CLinker(Linker):
]
in_storage = [x for i, x in enumerate(in_storage) if i not in dupidx]
if storage_map is None:
orphd = [[orphan.data] for orphan in self.orphans]
orphd = [
[orphan.data] if isinstance(orphan, Constant) else []
for orphan in self.orphans
]
else:
orphd = [storage_map[orphan] for orphan in self.orphans]
......
......@@ -8,6 +8,7 @@ from aesara import config, function, shared
from aesara import tensor as at
from aesara.graph.basic import (
Apply,
NominalVariable,
Variable,
ancestors,
applys_between,
......@@ -52,6 +53,9 @@ class MyType(Type):
def __eq__(self, other):
return isinstance(other, MyType) and other.thingy == self.thingy
def __hash__(self):
return hash((type(self), self.thingy))
def __str__(self):
return f"R{self.thingy}"
......@@ -694,3 +698,76 @@ def test_clone_new_inputs():
assert z_node_new.outputs[0].type.shape == (1,)
assert z_node_new.inputs[0].type.shape == (1,)
assert z_node_new.inputs[1].type.shape == (1,)
def test_NominalVariable():
type1 = MyType(1)
nv1 = NominalVariable(1, type1)
nv2 = NominalVariable(1, type1)
assert nv1 is nv2
assert nv1.equals(nv2)
assert hash(nv1) == hash(nv2)
type2 = MyType(2)
nv3 = NominalVariable(1, type2)
assert not nv1.equals(nv3)
assert hash(nv1) != hash(nv3)
type3 = MyType(1)
assert type3 == type1
nv4 = NominalVariable(1, type3)
assert nv1 is nv4
assert nv1.equals(nv4)
assert hash(nv1) == hash(nv4)
nv5 = NominalVariable(2, type3)
assert not nv4.equals(nv5)
assert hash(nv4) != hash(nv5)
assert repr(nv5) == f"NominalVariable(2, {repr(type3)})"
assert nv5.signature() == (type3, 2)
nv5_pkld = pickle.dumps(nv5)
nv5_unpkld = pickle.loads(nv5_pkld)
assert type(nv5_unpkld) is type(nv5)
assert nv5_unpkld.equals(nv5)
assert nv5_unpkld is nv5
nv5_clone = nv5.clone()
assert type(nv5_clone) is type(nv5)
assert nv5_clone.equals(nv5)
assert nv5_clone is nv5
def test_NominalVariable_create_variable_type():
ttype = TensorType("float64", (None, None))
ntv = NominalVariable(0, ttype)
assert isinstance(ntv, TensorVariable)
assert isinstance(ntv, NominalVariable)
assert ntv.ndim == 2
assert ntv.broadcastable == (False, False)
assert ntv.dtype == "float64"
ntv2 = NominalVariable(0, ttype)
assert type(ntv2) is type(ntv)
assert ntv2.equals(ntv)
assert ntv2 is ntv
ntv_pkld = pickle.dumps(ntv)
ntv_unpkld = pickle.loads(ntv_pkld)
assert type(ntv_unpkld) is type(ntv)
assert ntv_unpkld.equals(ntv)
assert ntv_unpkld is ntv
......@@ -4,9 +4,19 @@ import numpy as np
import pytest
from aesara.configdefaults import config
from aesara.graph.basic import NominalVariable
from aesara.graph.fg import FunctionGraph
from aesara.graph.utils import MissingInputError
from tests.graph.utils import MyConstant, MyOp, MyVariable, MyVariable2, op1, op2, op3
from tests.graph.utils import (
MyConstant,
MyOp,
MyType,
MyVariable,
MyVariable2,
op1,
op2,
op3,
)
class TestFunctionGraph:
......@@ -683,3 +693,18 @@ class TestFunctionGraph:
assert not fg.variables
assert not fg.apply_nodes
assert fg.clients == {var1: [], var2: []}
def test_nominals(self):
t1 = MyType()
nm = NominalVariable(1, t1)
nm2 = NominalVariable(2, t1)
v1 = op1(nm, nm2)
fg = FunctionGraph(outputs=[v1], clone=False)
assert nm not in fg.inputs
assert nm2 not in fg.inputs
assert nm in fg.variables
assert nm2 in fg.variables
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论