提交 01d20497 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Parameterize Variable type by Type and Apply

上级 de9ad202
...@@ -6,6 +6,7 @@ from copy import copy ...@@ -6,6 +6,7 @@ from copy import copy
from itertools import count from itertools import count
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any,
Callable, Callable,
Collection, Collection,
Deque, Deque,
...@@ -47,6 +48,9 @@ if TYPE_CHECKING: ...@@ -47,6 +48,9 @@ if TYPE_CHECKING:
OpType = TypeVar("OpType", bound="Op") OpType = TypeVar("OpType", bound="Op")
OptionalApplyType = TypeVar("OptionalApplyType", None, "Apply", covariant=True)
_TypeType = TypeVar("_TypeType", bound="Type")
_IdType = TypeVar("_IdType", bound=Hashable)
T = TypeVar("T", bound="Node") T = TypeVar("T", bound="Node")
NoParams = object() NoParams = object()
...@@ -61,7 +65,6 @@ class Node(MetaObject): ...@@ -61,7 +65,6 @@ class Node(MetaObject):
keeps track of its parents via `Variable.owner` / `Apply.inputs`. keeps track of its parents via `Variable.owner` / `Apply.inputs`.
""" """
type: "Type"
name: Optional[str] name: Optional[str]
def get_parents(self): def get_parents(self):
...@@ -110,7 +113,10 @@ class Apply(Node, Generic[OpType]): ...@@ -110,7 +113,10 @@ class Apply(Node, Generic[OpType]):
""" """
def __init__( def __init__(
self, op: OpType, inputs: Sequence["Variable"], outputs: Sequence["Variable"] self,
op: OpType,
inputs: Sequence["Variable"],
outputs: Sequence["Variable"],
): ):
if not isinstance(inputs, Sequence): if not isinstance(inputs, Sequence):
raise TypeError("The inputs of an Apply must be a sequence type") raise TypeError("The inputs of an Apply must be a sequence type")
...@@ -309,7 +315,7 @@ class Apply(Node, Generic[OpType]): ...@@ -309,7 +315,7 @@ class Apply(Node, Generic[OpType]):
return self.op.params_type return self.op.params_type
class Variable(Node): class Variable(Node, Generic[_TypeType, OptionalApplyType]):
r""" r"""
A :term:`Variable` is a node in an expression graph that represents a A :term:`Variable` is a node in an expression graph that represents a
variable. variable.
...@@ -407,10 +413,10 @@ class Variable(Node): ...@@ -407,10 +413,10 @@ class Variable(Node):
# __slots__ = ['type', 'owner', 'index', 'name'] # __slots__ = ['type', 'owner', 'index', 'name']
__count__ = count(0) __count__ = count(0)
_owner: Optional[Apply] _owner: OptionalApplyType
@property @property
def owner(self) -> Optional[Apply]: def owner(self) -> OptionalApplyType:
return self._owner return self._owner
@owner.setter @owner.setter
...@@ -427,30 +433,31 @@ class Variable(Node): ...@@ -427,30 +433,31 @@ class Variable(Node):
def __init__( def __init__(
self, self,
type, type: _TypeType,
owner: Optional[Apply] = None, owner: OptionalApplyType,
index: Optional[int] = None, index: Optional[int] = None,
name: Optional[str] = None, name: Optional[str] = None,
): ) -> None:
super().__init__() super().__init__()
self.tag = ValidatingScratchpad("test_value", type.filter) self.tag = ValidatingScratchpad("test_value", type.filter)
self.type = type self.type = type
self._owner = owner
if owner is not None and not isinstance(owner, Apply): if owner is not None and not isinstance(owner, Apply):
raise TypeError("owner must be an Apply instance", owner) raise TypeError("owner must be an Apply instance")
self.owner = owner
if index is not None and not isinstance(index, int): if index is not None and not isinstance(index, int):
raise TypeError("index must be an int", index) raise TypeError("index must be an int")
self.index = index self.index = index
if name is not None and not isinstance(name, str): if name is not None and not isinstance(name, str):
raise TypeError("name must be a string", name) raise TypeError("name must be a string")
self.name = name self.name = name
self.auto_name = "auto_" + str(next(self.__count__)) self.auto_name = f"auto_{next(self.__count__)}"
def get_test_value(self): def get_test_value(self):
"""Get the test value. """Get the test value.
...@@ -516,7 +523,6 @@ class Variable(Node): ...@@ -516,7 +523,6 @@ class Variable(Node):
Tags and names are copied to the returned instance. Tags and names are copied to the returned instance.
""" """
# return copy(self)
cp = self.__class__(self.type, None, None, self.name) cp = self.__class__(self.type, None, None, self.name)
cp.tag = copy(self.tag) cp.tag = copy(self.tag)
return cp return cp
...@@ -612,11 +618,11 @@ class Variable(Node): ...@@ -612,11 +618,11 @@ class Variable(Node):
return d return d
class AtomicVariable(Variable): class AtomicVariable(Variable[_TypeType, None]):
"""A node type that has no ancestors and should never be considered an input to a graph.""" """A node type that has no ancestors and should never be considered an input to a graph."""
def __init__(self, type, **kwargs): def __init__(self, type: _TypeType, **kwargs):
super().__init__(type, **kwargs) super().__init__(type, None, None, **kwargs)
@abc.abstractmethod @abc.abstractmethod
def signature(self): def signature(self):
...@@ -651,13 +657,13 @@ class AtomicVariable(Variable): ...@@ -651,13 +657,13 @@ class AtomicVariable(Variable):
raise ValueError("AtomicVariable instances cannot have an index.") raise ValueError("AtomicVariable instances cannot have an index.")
class NominalVariable(AtomicVariable): class NominalVariable(AtomicVariable[_TypeType]):
"""A variable that enables alpha-equivalent comparisons.""" """A variable that enables alpha-equivalent comparisons."""
__instances__: Dict[Hashable, type] = {} __instances__: Dict[Tuple["Type", Hashable], "NominalVariable"] = {}
def __new__(cls, id, typ, **kwargs): def __new__(cls, id: _IdType, typ: _TypeType, **kwargs):
if (id, typ) not in cls.__instances__: if (typ, id) not in cls.__instances__:
var_type = typ.variable_type var_type = typ.variable_type
type_name = f"Nominal{var_type.__name__}" type_name = f"Nominal{var_type.__name__}"
...@@ -670,13 +676,13 @@ class NominalVariable(AtomicVariable): ...@@ -670,13 +676,13 @@ class NominalVariable(AtomicVariable):
new_type = type( new_type = type(
type_name, (cls, var_type), {"__reduce__": _reduce, "__str__": _str} type_name, (cls, var_type), {"__reduce__": _reduce, "__str__": _str}
) )
res = super().__new__(new_type) res: NominalVariable = super().__new__(new_type)
cls.__instances__[(id, typ)] = res cls.__instances__[(typ, id)] = res
return cls.__instances__[(id, typ)] return cls.__instances__[(typ, id)]
def __init__(self, id, typ, **kwargs): def __init__(self, id: _IdType, typ: _TypeType, **kwargs):
self.id = id self.id = id
super().__init__(typ, **kwargs) super().__init__(typ, **kwargs)
...@@ -699,11 +705,11 @@ class NominalVariable(AtomicVariable): ...@@ -699,11 +705,11 @@ class NominalVariable(AtomicVariable):
def __repr__(self): def __repr__(self):
return f"{type(self).__name__}({repr(self.id)}, {repr(self.type)})" return f"{type(self).__name__}({repr(self.id)}, {repr(self.type)})"
def signature(self): def signature(self) -> Tuple[_TypeType, _IdType]:
return (self.type, self.id) return (self.type, self.id)
class Constant(AtomicVariable): class Constant(AtomicVariable[_TypeType]):
"""A `Variable` with a fixed `data` field. """A `Variable` with a fixed `data` field.
`Constant` nodes make numerous optimizations possible (e.g. constant `Constant` nodes make numerous optimizations possible (e.g. constant
...@@ -718,7 +724,7 @@ class Constant(AtomicVariable): ...@@ -718,7 +724,7 @@ class Constant(AtomicVariable):
# __slots__ = ['data'] # __slots__ = ['data']
def __init__(self, type, data, name=None): def __init__(self, type: _TypeType, data: Any, name: Optional[str] = None):
super().__init__(type, name=name) super().__init__(type, name=name)
self.data = type.filter(data) self.data = type.filter(data)
add_tag_trace(self) add_tag_trace(self)
......
...@@ -197,7 +197,7 @@ class Type(MetaObject): ...@@ -197,7 +197,7 @@ class Type(MetaObject):
A pretty string for printing and debugging. A pretty string for printing and debugging.
""" """
return self.variable_type(self, name=name) return self.variable_type(self, None, name=name)
def make_constant(self, value: D, name: Optional[Text] = None) -> constant_type: def make_constant(self, value: D, name: Optional[Text] = None) -> constant_type:
"""Return a new `Constant` instance of this `Type`. """Return a new `Constant` instance of this `Type`.
......
...@@ -207,7 +207,7 @@ class Linker(ABC): ...@@ -207,7 +207,7 @@ class Linker(ABC):
Examples Examples
-------- --------
x, y = Variable(Double), Variable(Double) x, y = Variable(Double, None), Variable(Double, None)
e = x + y e = x + y
fgraph = FunctionGraph([x, y], [e]) fgraph = FunctionGraph([x, y], [e])
fn, (new_x, new_y), (new_e, ) = MyLinker(fgraph).make_thunk(inplace) fn, (new_x, new_y), (new_e, ) = MyLinker(fgraph).make_thunk(inplace)
......
...@@ -415,7 +415,7 @@ class ScalarType(CType, HasDataType, HasShape): ...@@ -415,7 +415,7 @@ class ScalarType(CType, HasDataType, HasShape):
return upcast(*[x.dtype for x in [self] + list(others)]) return upcast(*[x.dtype for x in [self] + list(others)])
def make_variable(self, name=None): def make_variable(self, name=None):
return ScalarVariable(self, name=name) return ScalarVariable(self, None, name=name)
def __str__(self): def __str__(self):
return str(self.dtype) return str(self.dtype)
......
...@@ -1483,7 +1483,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1483,7 +1483,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
def inner_outputs(self): def inner_outputs(self):
return self.fgraph.outputs return self.fgraph.outputs
def clone(self): def clone(self) -> "Scan":
res = copy(self) res = copy(self)
res.fgraph = res.fgraph.clone() res.fgraph = res.fgraph.clone()
return res return res
......
...@@ -939,7 +939,7 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -939,7 +939,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
fgraph.attach_feature(DestroyHandler()) fgraph.attach_feature(DestroyHandler())
def attempt_scan_inplace( def attempt_scan_inplace(
self, fgraph: FunctionGraph, node: Apply, output_indices: List[int] self, fgraph: FunctionGraph, node: Apply[Scan], output_indices: List[int]
) -> Optional[Apply]: ) -> Optional[Apply]:
"""Attempt to replace a `Scan` node by one which computes the specified outputs inplace. """Attempt to replace a `Scan` node by one which computes the specified outputs inplace.
...@@ -953,7 +953,7 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -953,7 +953,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
Indices of the outputs to attempt to compute inplace Indices of the outputs to attempt to compute inplace
""" """
op: Scan = cast(Scan, node.op) op = node.op
# inputs corresponding to sequences and n_steps # inputs corresponding to sequences and n_steps
ls_begin = node.inputs[: 1 + op.info.n_seqs] ls_begin = node.inputs[: 1 + op.info.n_seqs]
...@@ -1001,7 +1001,10 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -1001,7 +1001,10 @@ class ScanInplaceOptimizer(GlobalOptimizer):
new_op.destroy_map = destroy_map new_op.destroy_map = destroy_map
# Do not call make_node for test_value # Do not call make_node for test_value
new_outs: List[Variable] = new_op(*inputs, return_list=True) new_outs = new_op(*inputs, return_list=True)
assert isinstance(new_outs, list)
try: try:
# TODO FIXME: We need to stop using this approach (i.e. attempt # TODO FIXME: We need to stop using this approach (i.e. attempt
# in-place replacements and wait for downstream failures to revert # in-place replacements and wait for downstream failures to revert
...@@ -1015,7 +1018,7 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -1015,7 +1018,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
remove=[node], remove=[node],
reason="scan_make_inplace", reason="scan_make_inplace",
) )
return new_outs[0].owner return cast(Apply[Scan], new_outs[0].owner)
except InconsistencyError: except InconsistencyError:
# Failed moving output to be computed inplace # Failed moving output to be computed inplace
return None return None
......
...@@ -82,7 +82,7 @@ def load(path, dtype, broadcastable, mmap_mode=None): ...@@ -82,7 +82,7 @@ def load(path, dtype, broadcastable, mmap_mode=None):
Examples Examples
-------- --------
>>> from aesara import * >>> from aesara import *
>>> path = Variable(Generic()) >>> path = Variable(Generic(), None)
>>> x = tensor.load(path, 'int64', (False,)) >>> x = tensor.load(path, 'int64', (False,))
>>> y = x*2 >>> y = x*2
>>> fn = function([path], y) >>> fn = function([path], y)
...@@ -136,7 +136,7 @@ class MPIRecv(Op): ...@@ -136,7 +136,7 @@ class MPIRecv(Op):
self, self,
[], [],
[ [
Variable(Generic()), Variable(Generic(), None),
tensor(self.dtype, shape=self.broadcastable), tensor(self.dtype, shape=self.broadcastable),
], ],
) )
...@@ -222,7 +222,7 @@ class MPISend(Op): ...@@ -222,7 +222,7 @@ class MPISend(Op):
self.tag = tag self.tag = tag
def make_node(self, data): def make_node(self, data):
return Apply(self, [data], [Variable(Generic()), data.type()]) return Apply(self, [data], [Variable(Generic(), None), data.type()])
view_map = {1: [0]} view_map = {1: [0]}
...@@ -259,7 +259,7 @@ class MPISendWait(Op): ...@@ -259,7 +259,7 @@ class MPISendWait(Op):
self.tag = tag self.tag = tag
def make_node(self, request, data): def make_node(self, request, data):
return Apply(self, [request, data], [Variable(Generic())]) return Apply(self, [request, data], [Variable(Generic(), None)])
def perform(self, node, inp, out): def perform(self, node, inp, out):
request = inp[0] request = inp[0]
......
...@@ -3,13 +3,13 @@ import traceback as tb ...@@ -3,13 +3,13 @@ import traceback as tb
import warnings import warnings
from collections.abc import Iterable from collections.abc import Iterable
from numbers import Number from numbers import Number
from typing import Optional from typing import Optional, TypeVar
import numpy as np import numpy as np
from aesara import tensor as at from aesara import tensor as at
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable from aesara.graph.basic import Constant, OptionalApplyType, Variable
from aesara.graph.utils import MetaType from aesara.graph.utils import MetaType
from aesara.scalar import ComplexError, IntegerDivisionError from aesara.scalar import ComplexError, IntegerDivisionError
from aesara.tensor import _get_vector_length, as_tensor_variable from aesara.tensor import _get_vector_length, as_tensor_variable
...@@ -18,6 +18,9 @@ from aesara.tensor.type import TensorType ...@@ -18,6 +18,9 @@ from aesara.tensor.type import TensorType
from aesara.tensor.utils import hash_from_ndarray from aesara.tensor.utils import hash_from_ndarray
_TensorTypeType = TypeVar("_TensorTypeType", bound=TensorType)
class _tensor_py_operators: class _tensor_py_operators:
def __abs__(self): def __abs__(self):
return at.math.abs(self) return at.math.abs(self)
...@@ -811,14 +814,22 @@ class _tensor_py_operators: ...@@ -811,14 +814,22 @@ class _tensor_py_operators:
return at.extra_ops.compress(self, a, axis=axis) return at.extra_ops.compress(self, a, axis=axis)
class TensorVariable(_tensor_py_operators, Variable): class TensorVariable(
_tensor_py_operators, Variable[_TensorTypeType, OptionalApplyType]
):
""" """
Subclass to add the tensor operators to the basic `Variable` class. Subclass to add the tensor operators to the basic `Variable` class.
""" """
def __init__(self, type, owner=None, index=None, name=None): def __init__(
super().__init__(type, owner=owner, index=index, name=name) self,
type: _TensorTypeType,
owner: OptionalApplyType,
index=None,
name=None,
):
super().__init__(type, owner, index=index, name=name)
if config.warn_float64 != "ignore" and type.dtype == "float64": if config.warn_float64 != "ignore" and type.dtype == "float64":
msg = ( msg = (
"You are creating a TensorVariable " "You are creating a TensorVariable "
...@@ -979,10 +990,10 @@ def get_unique_value(x: TensorVariable) -> Optional[Number]: ...@@ -979,10 +990,10 @@ def get_unique_value(x: TensorVariable) -> Optional[Number]:
return None return None
class TensorConstant(TensorVariable, Constant): class TensorConstant(TensorVariable, Constant[_TensorTypeType]):
"""Subclass to add the tensor operators to the basic `Constant` class.""" """Subclass to add the tensor operators to the basic `Constant` class."""
def __init__(self, type, data, name=None): def __init__(self, type: _TensorTypeType, data, name=None):
data_shape = np.shape(data) data_shape = np.shape(data)
if len(data_shape) != type.ndim or any( if len(data_shape) != type.ndim or any(
......
...@@ -65,8 +65,8 @@ Example: ...@@ -65,8 +65,8 @@ Example:
#... #...
def make_node(self, x, y): def make_node(self, x, y):
# note 1: constant, int64 and ScalarType are defined in aesara.scalar # note 1: constant, int64 and ScalarType are defined in aesara.scalar
# note 2: constant(x) is equivalent to Constant(type = int64, data = x) # note 2: constant(x) is equivalent to Constant(type=int64, data=x)
# note 3: the call int64() is equivalent to Variable(type = int64) or Variable(type = ScalarType(dtype = 'int64')) # note 3: the call int64() is equivalent to Variable(type=int64, None) or Variable(type=ScalarType(dtype = 'int64'), None)
if isinstance(x, int): if isinstance(x, int):
x = constant(x) x = constant(x)
elif not isinstance(x, Variable) or not x.type == int64: elif not isinstance(x, Variable) or not x.type == int64:
......
...@@ -339,7 +339,7 @@ class TestAutoName: ...@@ -339,7 +339,7 @@ class TestAutoName:
autoname_id = next(Variable.__count__) autoname_id = next(Variable.__count__)
Variable.__count__ = count(autoname_id) Variable.__count__ = count(autoname_id)
r1 = TensorType(dtype="int32", shape=())("myvar") r1 = TensorType(dtype="int32", shape=())("myvar")
r2 = TensorVariable(TensorType(dtype="int32", shape=())) r2 = TensorVariable(TensorType(dtype="int32", shape=()), None)
r3 = shared(np.random.standard_normal((3, 4))) r3 = shared(np.random.standard_normal((3, 4)))
assert r1.auto_name == "auto_" + str(autoname_id) assert r1.auto_name == "auto_" + str(autoname_id)
assert r2.auto_name == "auto_" + str(autoname_id + 1) assert r2.auto_name == "auto_" + str(autoname_id + 1)
......
...@@ -17,7 +17,7 @@ class TestLoadTensor: ...@@ -17,7 +17,7 @@ class TestLoadTensor:
np.save(self.filename, self.data) np.save(self.filename, self.data)
def test_basic(self): def test_basic(self):
path = Variable(Generic()) path = Variable(Generic(), None)
# Not specifying mmap_mode defaults to None, and the data is # Not specifying mmap_mode defaults to None, and the data is
# copied into main memory # copied into main memory
x = load(path, "int32", (False,)) x = load(path, "int32", (False,))
...@@ -29,13 +29,13 @@ class TestLoadTensor: ...@@ -29,13 +29,13 @@ class TestLoadTensor:
# Modes 'r+', 'r', and 'w+' cannot work with Aesara, becausei # Modes 'r+', 'r', and 'w+' cannot work with Aesara, becausei
# the output array may be modified inplace, and that should not # the output array may be modified inplace, and that should not
# modify the original file. # modify the original file.
path = Variable(Generic()) path = Variable(Generic(), None)
for mmap_mode in ("r+", "r", "w+", "toto"): for mmap_mode in ("r+", "r", "w+", "toto"):
with pytest.raises(ValueError): with pytest.raises(ValueError):
load(path, "int32", (False,), mmap_mode) load(path, "int32", (False,), mmap_mode)
def test1(self): def test1(self):
path = Variable(Generic()) path = Variable(Generic(), None)
# 'c' means "copy-on-write", which allow the array to be overwritten # 'c' means "copy-on-write", which allow the array to be overwritten
# by an inplace Op in the graph, without modifying the underlying # by an inplace Op in the graph, without modifying the underlying
# file. # file.
...@@ -48,7 +48,7 @@ class TestLoadTensor: ...@@ -48,7 +48,7 @@ class TestLoadTensor:
assert (fn(self.filename) == (self.data**2).sum()).all() assert (fn(self.filename) == (self.data**2).sum()).all()
def test_memmap(self): def test_memmap(self):
path = Variable(Generic()) path = Variable(Generic(), None)
x = load(path, "int32", (False,), mmap_mode="c") x = load(path, "int32", (False,), mmap_mode="c")
fn = function([path], x) fn = function([path], x)
assert type(fn(self.filename)) == np.core.memmap assert type(fn(self.filename)) == np.core.memmap
......
...@@ -63,7 +63,7 @@ def test_shape_basic(): ...@@ -63,7 +63,7 @@ def test_shape_basic():
def __eq__(self, other): def __eq__(self, other):
return isinstance(other, MyType) and other.thingy == self.thingy return isinstance(other, MyType) and other.thingy == self.thingy
s = shape(Variable(MyType())) s = shape(Variable(MyType(), None))
assert s.type.broadcastable == (False,) assert s.type.broadcastable == (False,)
s = shape(np.array(1)) s = shape(np.array(1))
...@@ -475,7 +475,7 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -475,7 +475,7 @@ class TestSpecifyShape(utt.InferShapeTester):
def test_infer_shape(self): def test_infer_shape(self):
rng = np.random.default_rng(3453) rng = np.random.default_rng(3453)
adtens4 = dtensor4() adtens4 = dtensor4()
aivec = TensorVariable(TensorType("int64", (4,))) aivec = TensorVariable(TensorType("int64", (4,)), None)
aivec_val = [3, 4, 2, 5] aivec_val = [3, 4, 2, 5]
adtens4_val = rng.random(aivec_val) adtens4_val = rng.random(aivec_val)
self._compile_and_check( self._compile_and_check(
......
...@@ -234,7 +234,7 @@ def test__getitem__newaxis(x, indices, new_order): ...@@ -234,7 +234,7 @@ def test__getitem__newaxis(x, indices, new_order):
def test_fixed_shape_variable_basic(): def test_fixed_shape_variable_basic():
x = TensorVariable(TensorType("int64", (4,))) x = TensorVariable(TensorType("int64", (4,)), None)
assert isinstance(x.shape, Constant) assert isinstance(x.shape, Constant)
assert np.array_equal(x.shape.data, (4,)) assert np.array_equal(x.shape.data, (4,))
...@@ -246,11 +246,11 @@ def test_fixed_shape_variable_basic(): ...@@ -246,11 +246,11 @@ def test_fixed_shape_variable_basic():
def test_get_vector_length(): def test_get_vector_length():
x = TensorVariable(TensorType("int64", (4,))) x = TensorVariable(TensorType("int64", (4,)), None)
res = get_vector_length(x) res = get_vector_length(x)
assert res == 4 assert res == 4
x = TensorVariable(TensorType("int64", (None,))) x = TensorVariable(TensorType("int64", (None,)), None)
with pytest.raises(ValueError): with pytest.raises(ValueError):
get_vector_length(x) get_vector_length(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论