提交 980ecacf authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add AtomicVariable and NominalVariable classes

上级 96366f99
"""Core graph classes.""" """Core graph classes."""
import abc
import warnings import warnings
from collections import deque from collections import deque
from copy import copy from copy import copy
...@@ -10,6 +11,7 @@ from typing import ( ...@@ -10,6 +11,7 @@ from typing import (
Deque, Deque,
Dict, Dict,
Generator, Generator,
Hashable,
Iterable, Iterable,
Iterator, Iterator,
List, List,
...@@ -396,6 +398,14 @@ class Variable(Node): ...@@ -396,6 +398,14 @@ class Variable(Node):
def owner(self, value) -> None: def owner(self, value) -> None:
self._owner = value self._owner = value
@property
def index(self):
return self._index
@index.setter
def index(self, value):
self._index = value
def __init__( def __init__(
self, self,
type, type,
...@@ -411,7 +421,7 @@ class Variable(Node): ...@@ -411,7 +421,7 @@ class Variable(Node):
if owner is not None and not isinstance(owner, Apply): if owner is not None and not isinstance(owner, Apply):
raise TypeError("owner must be an Apply instance", owner) raise TypeError("owner must be an Apply instance", owner)
self._owner = owner self.owner = owner
if index is not None and not isinstance(index, int): if index is not None and not isinstance(index, int):
raise TypeError("index must be an int", index) raise TypeError("index must be an int", index)
...@@ -586,7 +596,98 @@ class Variable(Node): ...@@ -586,7 +596,98 @@ class Variable(Node):
return d return d
class Constant(Variable): class AtomicVariable(Variable):
"""A node type that has no ancestors and should never be considered an input to a graph."""
def __init__(self, type, **kwargs):
super().__init__(type, **kwargs)
@abc.abstractmethod
def signature(self):
...
def merge_signature(self):
return self.signature()
def equals(self, other):
"""
This does what `__eq__` would normally do, but `Variable` and `Apply`
should always be hashable by `id`.
"""
return isinstance(other, type(self)) and self.signature() == other.signature()
@property
def owner(self):
return None
@owner.setter
def owner(self, value):
if value is not None:
raise ValueError("AtomicVariable instances cannot have an owner.")
@property
def index(self):
return None
@index.setter
def index(self, value):
if value is not None:
raise ValueError("AtomicVariable instances cannot have an index.")
class NominalVariable(AtomicVariable):
"""A variable that enables alpha-equivalent comparisons."""
__instances__: Dict[Hashable, type] = {}
def __new__(cls, id, typ, **kwargs):
if (id, typ) not in cls.__instances__:
var_type = typ.variable_type
type_name = f"Nominal{var_type.__name__}"
def _reduce(self):
return cls, (self.id, self.type)
def _str(self):
return f"*{self.id}-{var_type.__str__(self)}"
new_type = type(
type_name, (cls, var_type), {"__reduce__": _reduce, "__str__": _str}
)
res = super().__new__(new_type)
cls.__instances__[(id, typ)] = res
return cls.__instances__[(id, typ)]
def __init__(self, id, typ, **kwargs):
self.id = id
super().__init__(typ, **kwargs)
def clone(self):
return self
def __eq__(self, other):
if self is other:
return True
return (
type(self) == type(other)
and self.id == other.id
and self.type == other.type
)
def __hash__(self):
return hash((type(self), self.id, self.type))
def __repr__(self):
return f"{type(self).__name__}({repr(self.id)}, {repr(self.type)})"
def signature(self):
return (self.type, self.id)
class Constant(AtomicVariable):
"""A `Variable` with a fixed `data` field. """A `Variable` with a fixed `data` field.
`Constant` nodes make numerous optimizations possible (e.g. constant `Constant` nodes make numerous optimizations possible (e.g. constant
...@@ -602,23 +703,16 @@ class Constant(Variable): ...@@ -602,23 +703,16 @@ class Constant(Variable):
# __slots__ = ['data'] # __slots__ = ['data']
def __init__(self, type, data, name=None): def __init__(self, type, data, name=None):
super().__init__(type, None, None, name) super().__init__(type, name=name)
self.data = type.filter(data) self.data = type.filter(data)
add_tag_trace(self) add_tag_trace(self)
def get_test_value(self): def get_test_value(self):
return self.data return self.data
def equals(self, other):
# this does what __eq__ should do, but Variable and Apply should always be hashable by id
return isinstance(other, Constant) and self.signature() == other.signature()
def signature(self): def signature(self):
return (self.type, self.data) return (self.type, self.data)
def merge_signature(self):
return self.signature()
def __str__(self): def __str__(self):
if self.name is not None: if self.name is not None:
return self.name return self.name
...@@ -641,9 +735,9 @@ class Constant(Variable): ...@@ -641,9 +735,9 @@ class Constant(Variable):
if value is not None: if value is not None:
raise ValueError("Constant instances cannot have an owner.") raise ValueError("Constant instances cannot have an owner.")
value = property(lambda self: self.data, doc="read-only data access method") @property
def value(self):
# index is not defined, because the `owner` attribute must necessarily be None return self.data
def walk( def walk(
......
...@@ -18,7 +18,7 @@ from typing_extensions import Literal ...@@ -18,7 +18,7 @@ from typing_extensions import Literal
import aesara import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, Node, Variable, applys_between from aesara.graph.basic import Apply, AtomicVariable, Node, Variable, applys_between
from aesara.graph.basic import as_string as graph_as_string from aesara.graph.basic import as_string as graph_as_string
from aesara.graph.basic import clone_get_equiv, graph_inputs, io_toposort, vars_between from aesara.graph.basic import clone_get_equiv, graph_inputs, io_toposort, vars_between
from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate
...@@ -101,7 +101,9 @@ class FunctionGraph(MetaObject): ...@@ -101,7 +101,9 @@ class FunctionGraph(MetaObject):
raise ValueError("No outputs specified") raise ValueError("No outputs specified")
if inputs is None: if inputs is None:
inputs = [i for i in graph_inputs(outputs) if not isinstance(i, Constant)] inputs = [
i for i in graph_inputs(outputs) if not isinstance(i, AtomicVariable)
]
if clone: if clone:
_memo = clone_get_equiv( _memo = clone_get_equiv(
...@@ -306,7 +308,7 @@ class FunctionGraph(MetaObject): ...@@ -306,7 +308,7 @@ class FunctionGraph(MetaObject):
self.import_node(var.owner, reason=reason, import_missing=import_missing) self.import_node(var.owner, reason=reason, import_missing=import_missing)
elif ( elif (
var.owner is None var.owner is None
and not isinstance(var, Constant) and not isinstance(var, AtomicVariable)
and var not in self.inputs and var not in self.inputs
): ):
from aesara.graph.null_type import NullType from aesara.graph.null_type import NullType
...@@ -354,7 +356,7 @@ class FunctionGraph(MetaObject): ...@@ -354,7 +356,7 @@ class FunctionGraph(MetaObject):
for var in node.inputs: for var in node.inputs:
if ( if (
var.owner is None var.owner is None
and not isinstance(var, Constant) and not isinstance(var, AtomicVariable)
and var not in self.inputs and var not in self.inputs
): ):
if import_missing: if import_missing:
...@@ -515,9 +517,6 @@ class FunctionGraph(MetaObject): ...@@ -515,9 +517,6 @@ class FunctionGraph(MetaObject):
) )
for node, i in list(self.clients[var]): for node, i in list(self.clients[var]):
assert (node == "output" and self.outputs[i] is var) or (
isinstance(node, Apply) and node.inputs[i] is var
)
self.change_node_input( self.change_node_input(
node, i, new_var, reason=reason, import_missing=import_missing node, i, new_var, reason=reason, import_missing=import_missing
) )
...@@ -839,7 +838,7 @@ class FunctionGraph(MetaObject): ...@@ -839,7 +838,7 @@ class FunctionGraph(MetaObject):
if ( if (
variable.owner is None variable.owner is None
and variable not in self.inputs and variable not in self.inputs
and not isinstance(variable, Constant) and not isinstance(variable, AtomicVariable)
): ):
raise Exception(f"Undeclared input: {variable}") raise Exception(f"Undeclared input: {variable}")
for cl_node, i in self.clients[variable]: for cl_node, i in self.clients[variable]:
......
...@@ -19,13 +19,12 @@ from functools import _compose_mro, partial, reduce # type: ignore ...@@ -19,13 +19,12 @@ from functools import _compose_mro, partial, reduce # type: ignore
from itertools import chain from itertools import chain
from typing import Dict, List, Optional, Sequence, Tuple, Union from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing_extensions import TypeAlias
import aesara import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph import destroyhandler as dh from aesara.graph import destroyhandler as dh
from aesara.graph.basic import ( from aesara.graph.basic import (
Apply, Apply,
AtomicVariable,
Constant, Constant,
Variable, Variable,
applys_between, applys_between,
...@@ -500,12 +499,9 @@ class MergeFeature(Feature): ...@@ -500,12 +499,9 @@ class MergeFeature(Feature):
assert not hasattr(fgraph, "merge_feature") assert not hasattr(fgraph, "merge_feature")
fgraph.merge_feature = self fgraph.merge_feature = self
# For constants self.seen_atomics = set()
self.seen_constants = set() self.atomic_sig = AssocList()
# variable -> signature (for constants) self.atomic_sig_inv = AssocList()
self.const_sig = AssocList()
# signature -> variable (for constants)
self.const_sig_inv = AssocList()
# For all Apply nodes # For all Apply nodes
# Set of distinct (not mergeable) nodes # Set of distinct (not mergeable) nodes
...@@ -539,17 +535,13 @@ class MergeFeature(Feature): ...@@ -539,17 +535,13 @@ class MergeFeature(Feature):
self.nodes_seen.discard(node) self.nodes_seen.discard(node)
self.process_node(fgraph, node) self.process_node(fgraph, node)
# Since we are in on_change_input, node should have inputs. if isinstance(new_r, AtomicVariable):
if not isinstance(node, str): self.process_atomic(fgraph, new_r)
assert node.inputs
if isinstance(new_r, Constant):
self.process_constant(fgraph, new_r)
def on_import(self, fgraph, node, reason): def on_import(self, fgraph, node, reason):
for c in node.inputs: for c in node.inputs:
if isinstance(c, Constant): if isinstance(c, AtomicVariable):
self.process_constant(fgraph, c) self.process_atomic(fgraph, c)
self.process_node(fgraph, node) self.process_node(fgraph, node)
...@@ -558,19 +550,19 @@ class MergeFeature(Feature): ...@@ -558,19 +550,19 @@ class MergeFeature(Feature):
if not node.inputs: if not node.inputs:
self.noinput_nodes.discard(node) self.noinput_nodes.discard(node)
for c in node.inputs: for c in node.inputs:
if isinstance(c, Constant) and (len(fgraph.clients[c]) <= 1): if isinstance(c, AtomicVariable) and len(fgraph.clients[c]) <= 1:
# This was the last node using this constant # This was the last node using this constant
sig = self.const_sig[c] sig = self.atomic_sig[c]
self.const_sig.discard(c) self.atomic_sig.discard(c)
self.const_sig_inv.discard(sig) self.atomic_sig_inv.discard(sig)
self.seen_constants.discard(id(c)) self.seen_atomics.discard(id(c))
def process_constant(self, fgraph, c): def process_atomic(self, fgraph, c):
"""Check if a constant `c` can be merged, and queue that replacement.""" """Check if an atomic `c` can be merged, and queue that replacement."""
if id(c) in self.seen_constants: if id(c) in self.seen_atomics:
return return
sig = c.merge_signature() sig = c.merge_signature()
other_c = self.const_sig_inv.get(sig, None) other_c = self.atomic_sig_inv.get(sig, None)
if other_c is not None: if other_c is not None:
# multiple names will clobber each other.. # multiple names will clobber each other..
# we adopt convention to keep the last name # we adopt convention to keep the last name
...@@ -579,9 +571,9 @@ class MergeFeature(Feature): ...@@ -579,9 +571,9 @@ class MergeFeature(Feature):
self.scheduled.append([[(c, other_c, "merge")]]) self.scheduled.append([[(c, other_c, "merge")]])
else: else:
# this is a new constant # this is a new constant
self.const_sig[c] = sig self.atomic_sig[c] = sig
self.const_sig_inv[sig] = c self.atomic_sig_inv[sig] = c
self.seen_constants.add(id(c)) self.seen_atomics.add(id(c))
def process_node(self, fgraph, node): def process_node(self, fgraph, node):
r"""Check if a `node` can be merged, and queue that replacement. r"""Check if a `node` can be merged, and queue that replacement.
...@@ -602,6 +594,7 @@ class MergeFeature(Feature): ...@@ -602,6 +594,7 @@ class MergeFeature(Feature):
# using `node.inputs[0]` will make us look at more nodes on # using `node.inputs[0]` will make us look at more nodes on
# average, so by picking the smallest clients list, we might speed # average, so by picking the smallest clients list, we might speed
# things up? # things up?
clients = sorted( clients = sorted(
(fgraph.clients[inp] for inp in node.inputs), key=lambda x: len(x) (fgraph.clients[inp] for inp in node.inputs), key=lambda x: len(x)
)[0] )[0]
...@@ -616,6 +609,7 @@ class MergeFeature(Feature): ...@@ -616,6 +609,7 @@ class MergeFeature(Feature):
replacement_candidates = [] replacement_candidates = []
for candidate in merge_candidates: for candidate in merge_candidates:
if candidate is node: if candidate is node:
continue continue
if len(node.inputs) != len(candidate.inputs): if len(node.inputs) != len(candidate.inputs):
...@@ -658,9 +652,10 @@ class MergeOptimizer(GlobalOptimizer): ...@@ -658,9 +652,10 @@ class MergeOptimizer(GlobalOptimizer):
one are transferred to the other and one of them is removed from the graph. one are transferred to the other and one of them is removed from the graph.
This procedure is carried out in input-to-output order throughout the graph. This procedure is carried out in input-to-output order throughout the graph.
The first step of merging is constant-merging, so that all clients of an The first step of merging is atomic variable-merging, so that all clients of a
``int(1)`` for example, are transferred to just one particular instance of :class:`Constant` like ``int(1)``, are transferred to just one particular
``int(1)``. instance of ``int(1)``. :class:`NominalVariable`\s are not merged individually
like this; only the nodes that use them are.
""" """
...@@ -678,7 +673,7 @@ class MergeOptimizer(GlobalOptimizer): ...@@ -678,7 +673,7 @@ class MergeOptimizer(GlobalOptimizer):
callbacks_before = fgraph.execute_callbacks_times.copy() callbacks_before = fgraph.execute_callbacks_times.copy()
nb_merged = 0 nb_merged = 0
nb_constant = 0 nb_atomic = 0
while sched: while sched:
pairs_list = sched.pop() pairs_list = sched.pop()
success = True success = True
...@@ -739,8 +734,8 @@ class MergeOptimizer(GlobalOptimizer): ...@@ -739,8 +734,8 @@ class MergeOptimizer(GlobalOptimizer):
pairs = [(pairs[0][1], pairs[0][0])] pairs = [(pairs[0][1], pairs[0][0])]
try: try:
# If they're all `Constant`s, there's no need to call validate. # If they're all `AtomicVariable`s, there's no need to call validate.
if all(isinstance(old, Constant) for old, _ in pairs): if all(isinstance(old, AtomicVariable) for old, _ in pairs):
fgraph.replace_all(pairs, reason="MergeOptimizer") fgraph.replace_all(pairs, reason="MergeOptimizer")
else: else:
fgraph.replace_all_validate(pairs, reason="MergeOptimizer") fgraph.replace_all_validate(pairs, reason="MergeOptimizer")
...@@ -753,8 +748,8 @@ class MergeOptimizer(GlobalOptimizer): ...@@ -753,8 +748,8 @@ class MergeOptimizer(GlobalOptimizer):
if success: if success:
nb_merged += len(pairs) nb_merged += len(pairs)
if isinstance(pairs[0][0], Constant): if isinstance(pairs[0][0], AtomicVariable):
nb_constant += 1 nb_atomic += 1
break break
if fgraph.profile: if fgraph.profile:
...@@ -782,7 +777,7 @@ class MergeOptimizer(GlobalOptimizer): ...@@ -782,7 +777,7 @@ class MergeOptimizer(GlobalOptimizer):
callback_time, callback_time,
callbacks_time, callbacks_time,
nb_merged, nb_merged,
nb_constant, nb_atomic,
) )
def __str__(self): def __str__(self):
...@@ -798,14 +793,14 @@ class MergeOptimizer(GlobalOptimizer): ...@@ -798,14 +793,14 @@ class MergeOptimizer(GlobalOptimizer):
callback_time, callback_time,
callbacks_time, callbacks_time,
nb_merged, nb_merged,
nb_constant, nb_atomic,
) = prof ) = prof
blanc = " " * level blanc = " " * level
print(blanc, "MergeOptimizer", file=stream) print(blanc, "MergeOptimizer", file=stream)
print( print(
blanc, blanc,
f" nb fail={nb_fail:5d} merged={nb_merged:5d} constant={nb_constant:5d}", f" nb fail={nb_fail:5d} merged={nb_merged:5d} atomic={nb_atomic:5d}",
file=stream, file=stream,
) )
print( print(
...@@ -836,7 +831,7 @@ class MergeOptimizer(GlobalOptimizer): ...@@ -836,7 +831,7 @@ class MergeOptimizer(GlobalOptimizer):
callback_time = merge_none_number(prof1[3], prof2[3]) callback_time = merge_none_number(prof1[3], prof2[3])
callbacks_time = merge_dict(prof1[4], prof2[4]) callbacks_time = merge_dict(prof1[4], prof2[4])
nb_merged = prof1[5] + prof2[5] nb_merged = prof1[5] + prof2[5]
nb_constant = prof1[6] + prof2[6] nb_atomic = prof1[6] + prof2[6]
return ( return (
nb_fail, nb_fail,
replace_time, replace_time,
...@@ -844,7 +839,7 @@ class MergeOptimizer(GlobalOptimizer): ...@@ -844,7 +839,7 @@ class MergeOptimizer(GlobalOptimizer):
callback_time, callback_time,
callbacks_time, callbacks_time,
nb_merged, nb_merged,
nb_constant, nb_atomic,
) )
...@@ -1127,7 +1122,7 @@ class LocalOptTracker: ...@@ -1127,7 +1122,7 @@ class LocalOptTracker:
def __init__(self): def __init__(self):
self.tracked_instances: Dict[Op, List[LocalOptimizer]] = {} self.tracked_instances: Dict[Op, List[LocalOptimizer]] = {}
self.tracked_types: Dict[TypeAlias, List[LocalOptimizer]] = {} self.tracked_types: Dict[type, List[LocalOptimizer]] = {}
self.untracked_opts: List[LocalOptimizer] = [] self.untracked_opts: List[LocalOptimizer] = []
def add_tracker(self, rw: LocalOptimizer): def add_tracker(self, rw: LocalOptimizer):
......
...@@ -14,7 +14,13 @@ import numpy as np ...@@ -14,7 +14,13 @@ import numpy as np
from aesara.compile.compilelock import lock_ctx from aesara.compile.compilelock import lock_ctx
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Constant, NoParams, io_toposort, vars_between from aesara.graph.basic import (
AtomicVariable,
Constant,
NoParams,
io_toposort,
vars_between,
)
from aesara.graph.callcache import CallCache from aesara.graph.callcache import CallCache
from aesara.link.basic import Container, Linker, LocalLinker, PerformLinker from aesara.link.basic import Container, Linker, LocalLinker, PerformLinker
from aesara.link.c.cmodule import ( from aesara.link.c.cmodule import (
...@@ -641,7 +647,7 @@ class CLinker(Linker): ...@@ -641,7 +647,7 @@ class CLinker(Linker):
self.orphans = list( self.orphans = list(
r r
for r in self.variables for r in self.variables
if isinstance(r, Constant) and r not in self.inputs if isinstance(r, AtomicVariable) and r not in self.inputs
) )
# C type constants (aesara.scalar.ScalarType). They don't request an object # C type constants (aesara.scalar.ScalarType). They don't request an object
self.consts = [] self.consts = []
...@@ -730,7 +736,7 @@ class CLinker(Linker): ...@@ -730,7 +736,7 @@ class CLinker(Linker):
[get_c_declare, get_c_extract, get_c_cleanup], [get_c_declare, get_c_extract, get_c_cleanup],
] ]
elif variable in self.orphans: elif variable in self.orphans:
if not isinstance(variable, Constant): if not isinstance(variable, AtomicVariable):
raise TypeError( raise TypeError(
"All orphans to CLinker must be Constant instances. " "All orphans to CLinker must be Constant instances. "
f"Got {variable}" f"Got {variable}"
...@@ -1404,7 +1410,7 @@ class CLinker(Linker): ...@@ -1404,7 +1410,7 @@ class CLinker(Linker):
# It is important that a variable (i) # It is important that a variable (i)
# yield a 'position' that reflects its role in code_gen() # yield a 'position' that reflects its role in code_gen()
if isinstance(i, Constant): # orphans if isinstance(i, AtomicVariable): # orphans
if id(i) not in constant_ids: if id(i) not in constant_ids:
isig = (i.signature(), topological_pos, i_idx) isig = (i.signature(), topological_pos, i_idx)
# If the Aesara constant provides a strong hash # If the Aesara constant provides a strong hash
...@@ -1634,7 +1640,10 @@ class CLinker(Linker): ...@@ -1634,7 +1640,10 @@ class CLinker(Linker):
] ]
in_storage = [x for i, x in enumerate(in_storage) if i not in dupidx] in_storage = [x for i, x in enumerate(in_storage) if i not in dupidx]
if storage_map is None: if storage_map is None:
orphd = [[orphan.data] for orphan in self.orphans] orphd = [
[orphan.data] if isinstance(orphan, Constant) else []
for orphan in self.orphans
]
else: else:
orphd = [storage_map[orphan] for orphan in self.orphans] orphd = [storage_map[orphan] for orphan in self.orphans]
......
...@@ -8,6 +8,7 @@ from aesara import config, function, shared ...@@ -8,6 +8,7 @@ from aesara import config, function, shared
from aesara import tensor as at from aesara import tensor as at
from aesara.graph.basic import ( from aesara.graph.basic import (
Apply, Apply,
NominalVariable,
Variable, Variable,
ancestors, ancestors,
applys_between, applys_between,
...@@ -52,6 +53,9 @@ class MyType(Type): ...@@ -52,6 +53,9 @@ class MyType(Type):
def __eq__(self, other): def __eq__(self, other):
return isinstance(other, MyType) and other.thingy == self.thingy return isinstance(other, MyType) and other.thingy == self.thingy
def __hash__(self):
return hash((type(self), self.thingy))
def __str__(self): def __str__(self):
return f"R{self.thingy}" return f"R{self.thingy}"
...@@ -694,3 +698,76 @@ def test_clone_new_inputs(): ...@@ -694,3 +698,76 @@ def test_clone_new_inputs():
assert z_node_new.outputs[0].type.shape == (1,) assert z_node_new.outputs[0].type.shape == (1,)
assert z_node_new.inputs[0].type.shape == (1,) assert z_node_new.inputs[0].type.shape == (1,)
assert z_node_new.inputs[1].type.shape == (1,) assert z_node_new.inputs[1].type.shape == (1,)
def test_NominalVariable():
type1 = MyType(1)
nv1 = NominalVariable(1, type1)
nv2 = NominalVariable(1, type1)
assert nv1 is nv2
assert nv1.equals(nv2)
assert hash(nv1) == hash(nv2)
type2 = MyType(2)
nv3 = NominalVariable(1, type2)
assert not nv1.equals(nv3)
assert hash(nv1) != hash(nv3)
type3 = MyType(1)
assert type3 == type1
nv4 = NominalVariable(1, type3)
assert nv1 is nv4
assert nv1.equals(nv4)
assert hash(nv1) == hash(nv4)
nv5 = NominalVariable(2, type3)
assert not nv4.equals(nv5)
assert hash(nv4) != hash(nv5)
assert repr(nv5) == f"NominalVariable(2, {repr(type3)})"
assert nv5.signature() == (type3, 2)
nv5_pkld = pickle.dumps(nv5)
nv5_unpkld = pickle.loads(nv5_pkld)
assert type(nv5_unpkld) is type(nv5)
assert nv5_unpkld.equals(nv5)
assert nv5_unpkld is nv5
nv5_clone = nv5.clone()
assert type(nv5_clone) is type(nv5)
assert nv5_clone.equals(nv5)
assert nv5_clone is nv5
def test_NominalVariable_create_variable_type():
ttype = TensorType("float64", (None, None))
ntv = NominalVariable(0, ttype)
assert isinstance(ntv, TensorVariable)
assert isinstance(ntv, NominalVariable)
assert ntv.ndim == 2
assert ntv.broadcastable == (False, False)
assert ntv.dtype == "float64"
ntv2 = NominalVariable(0, ttype)
assert type(ntv2) is type(ntv)
assert ntv2.equals(ntv)
assert ntv2 is ntv
ntv_pkld = pickle.dumps(ntv)
ntv_unpkld = pickle.loads(ntv_pkld)
assert type(ntv_unpkld) is type(ntv)
assert ntv_unpkld.equals(ntv)
assert ntv_unpkld is ntv
...@@ -4,9 +4,19 @@ import numpy as np ...@@ -4,9 +4,19 @@ import numpy as np
import pytest import pytest
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import NominalVariable
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.utils import MissingInputError from aesara.graph.utils import MissingInputError
from tests.graph.utils import MyConstant, MyOp, MyVariable, MyVariable2, op1, op2, op3 from tests.graph.utils import (
MyConstant,
MyOp,
MyType,
MyVariable,
MyVariable2,
op1,
op2,
op3,
)
class TestFunctionGraph: class TestFunctionGraph:
...@@ -683,3 +693,18 @@ class TestFunctionGraph: ...@@ -683,3 +693,18 @@ class TestFunctionGraph:
assert not fg.variables assert not fg.variables
assert not fg.apply_nodes assert not fg.apply_nodes
assert fg.clients == {var1: [], var2: []} assert fg.clients == {var1: [], var2: []}
def test_nominals(self):
t1 = MyType()
nm = NominalVariable(1, t1)
nm2 = NominalVariable(2, t1)
v1 = op1(nm, nm2)
fg = FunctionGraph(outputs=[v1], clone=False)
assert nm not in fg.inputs
assert nm2 not in fg.inputs
assert nm in fg.variables
assert nm2 in fg.variables
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论