提交 01d20497 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Parameterize Variable type by Type and Apply

上级 de9ad202
......@@ -6,6 +6,7 @@ from copy import copy
from itertools import count
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
Deque,
......@@ -47,6 +48,9 @@ if TYPE_CHECKING:
OpType = TypeVar("OpType", bound="Op")
OptionalApplyType = TypeVar("OptionalApplyType", None, "Apply", covariant=True)
_TypeType = TypeVar("_TypeType", bound="Type")
_IdType = TypeVar("_IdType", bound=Hashable)
T = TypeVar("T", bound="Node")
NoParams = object()
......@@ -61,7 +65,6 @@ class Node(MetaObject):
keeps track of its parents via `Variable.owner` / `Apply.inputs`.
"""
type: "Type"
name: Optional[str]
def get_parents(self):
......@@ -110,7 +113,10 @@ class Apply(Node, Generic[OpType]):
"""
def __init__(
self, op: OpType, inputs: Sequence["Variable"], outputs: Sequence["Variable"]
self,
op: OpType,
inputs: Sequence["Variable"],
outputs: Sequence["Variable"],
):
if not isinstance(inputs, Sequence):
raise TypeError("The inputs of an Apply must be a sequence type")
......@@ -309,7 +315,7 @@ class Apply(Node, Generic[OpType]):
return self.op.params_type
class Variable(Node):
class Variable(Node, Generic[_TypeType, OptionalApplyType]):
r"""
A :term:`Variable` is a node in an expression graph that represents a
variable.
......@@ -407,10 +413,10 @@ class Variable(Node):
# __slots__ = ['type', 'owner', 'index', 'name']
__count__ = count(0)
_owner: Optional[Apply]
_owner: OptionalApplyType
@property
def owner(self) -> Optional[Apply]:
def owner(self) -> OptionalApplyType:
return self._owner
@owner.setter
......@@ -427,30 +433,31 @@ class Variable(Node):
def __init__(
self,
type,
owner: Optional[Apply] = None,
type: _TypeType,
owner: OptionalApplyType,
index: Optional[int] = None,
name: Optional[str] = None,
):
) -> None:
super().__init__()
self.tag = ValidatingScratchpad("test_value", type.filter)
self.type = type
self._owner = owner
if owner is not None and not isinstance(owner, Apply):
raise TypeError("owner must be an Apply instance", owner)
self.owner = owner
raise TypeError("owner must be an Apply instance")
if index is not None and not isinstance(index, int):
raise TypeError("index must be an int", index)
raise TypeError("index must be an int")
self.index = index
if name is not None and not isinstance(name, str):
raise TypeError("name must be a string", name)
raise TypeError("name must be a string")
self.name = name
self.auto_name = "auto_" + str(next(self.__count__))
self.auto_name = f"auto_{next(self.__count__)}"
def get_test_value(self):
"""Get the test value.
......@@ -516,7 +523,6 @@ class Variable(Node):
Tags and names are copied to the returned instance.
"""
# return copy(self)
cp = self.__class__(self.type, None, None, self.name)
cp.tag = copy(self.tag)
return cp
......@@ -612,11 +618,11 @@ class Variable(Node):
return d
class AtomicVariable(Variable):
class AtomicVariable(Variable[_TypeType, None]):
"""A node type that has no ancestors and should never be considered an input to a graph."""
def __init__(self, type, **kwargs):
super().__init__(type, **kwargs)
def __init__(self, type: _TypeType, **kwargs):
super().__init__(type, None, None, **kwargs)
@abc.abstractmethod
def signature(self):
......@@ -651,13 +657,13 @@ class AtomicVariable(Variable):
raise ValueError("AtomicVariable instances cannot have an index.")
class NominalVariable(AtomicVariable):
class NominalVariable(AtomicVariable[_TypeType]):
"""A variable that enables alpha-equivalent comparisons."""
__instances__: Dict[Hashable, type] = {}
__instances__: Dict[Tuple["Type", Hashable], "NominalVariable"] = {}
def __new__(cls, id, typ, **kwargs):
if (id, typ) not in cls.__instances__:
def __new__(cls, id: _IdType, typ: _TypeType, **kwargs):
if (typ, id) not in cls.__instances__:
var_type = typ.variable_type
type_name = f"Nominal{var_type.__name__}"
......@@ -670,13 +676,13 @@ class NominalVariable(AtomicVariable):
new_type = type(
type_name, (cls, var_type), {"__reduce__": _reduce, "__str__": _str}
)
res = super().__new__(new_type)
res: NominalVariable = super().__new__(new_type)
cls.__instances__[(id, typ)] = res
cls.__instances__[(typ, id)] = res
return cls.__instances__[(id, typ)]
return cls.__instances__[(typ, id)]
def __init__(self, id, typ, **kwargs):
def __init__(self, id: _IdType, typ: _TypeType, **kwargs):
self.id = id
super().__init__(typ, **kwargs)
......@@ -699,11 +705,11 @@ class NominalVariable(AtomicVariable):
def __repr__(self):
return f"{type(self).__name__}({repr(self.id)}, {repr(self.type)})"
def signature(self):
def signature(self) -> Tuple[_TypeType, _IdType]:
return (self.type, self.id)
class Constant(AtomicVariable):
class Constant(AtomicVariable[_TypeType]):
"""A `Variable` with a fixed `data` field.
`Constant` nodes make numerous optimizations possible (e.g. constant
......@@ -718,7 +724,7 @@ class Constant(AtomicVariable):
# __slots__ = ['data']
def __init__(self, type, data, name=None):
def __init__(self, type: _TypeType, data: Any, name: Optional[str] = None):
super().__init__(type, name=name)
self.data = type.filter(data)
add_tag_trace(self)
......
......@@ -197,7 +197,7 @@ class Type(MetaObject):
A pretty string for printing and debugging.
"""
return self.variable_type(self, name=name)
return self.variable_type(self, None, name=name)
def make_constant(self, value: D, name: Optional[Text] = None) -> constant_type:
"""Return a new `Constant` instance of this `Type`.
......
......@@ -207,7 +207,7 @@ class Linker(ABC):
Examples
--------
x, y = Variable(Double), Variable(Double)
x, y = Variable(Double, None), Variable(Double, None)
e = x + y
fgraph = FunctionGraph([x, y], [e])
fn, (new_x, new_y), (new_e, ) = MyLinker(fgraph).make_thunk(inplace)
......
......@@ -415,7 +415,7 @@ class ScalarType(CType, HasDataType, HasShape):
return upcast(*[x.dtype for x in [self] + list(others)])
def make_variable(self, name=None):
return ScalarVariable(self, name=name)
return ScalarVariable(self, None, name=name)
def __str__(self):
return str(self.dtype)
......
......@@ -1483,7 +1483,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
def inner_outputs(self):
return self.fgraph.outputs
def clone(self):
def clone(self) -> "Scan":
res = copy(self)
res.fgraph = res.fgraph.clone()
return res
......
......@@ -939,7 +939,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
fgraph.attach_feature(DestroyHandler())
def attempt_scan_inplace(
self, fgraph: FunctionGraph, node: Apply, output_indices: List[int]
self, fgraph: FunctionGraph, node: Apply[Scan], output_indices: List[int]
) -> Optional[Apply]:
"""Attempt to replace a `Scan` node by one which computes the specified outputs inplace.
......@@ -953,7 +953,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
Indices of the outputs to attempt to compute inplace
"""
op: Scan = cast(Scan, node.op)
op = node.op
# inputs corresponding to sequences and n_steps
ls_begin = node.inputs[: 1 + op.info.n_seqs]
......@@ -1001,7 +1001,10 @@ class ScanInplaceOptimizer(GlobalOptimizer):
new_op.destroy_map = destroy_map
# Do not call make_node for test_value
new_outs: List[Variable] = new_op(*inputs, return_list=True)
new_outs = new_op(*inputs, return_list=True)
assert isinstance(new_outs, list)
try:
# TODO FIXME: We need to stop using this approach (i.e. attempt
# in-place replacements and wait for downstream failures to revert
......@@ -1015,7 +1018,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
remove=[node],
reason="scan_make_inplace",
)
return new_outs[0].owner
return cast(Apply[Scan], new_outs[0].owner)
except InconsistencyError:
# Failed moving output to be computed inplace
return None
......
......@@ -82,7 +82,7 @@ def load(path, dtype, broadcastable, mmap_mode=None):
Examples
--------
>>> from aesara import *
>>> path = Variable(Generic())
>>> path = Variable(Generic(), None)
>>> x = tensor.load(path, 'int64', (False,))
>>> y = x*2
>>> fn = function([path], y)
......@@ -136,7 +136,7 @@ class MPIRecv(Op):
self,
[],
[
Variable(Generic()),
Variable(Generic(), None),
tensor(self.dtype, shape=self.broadcastable),
],
)
......@@ -222,7 +222,7 @@ class MPISend(Op):
self.tag = tag
def make_node(self, data):
return Apply(self, [data], [Variable(Generic()), data.type()])
return Apply(self, [data], [Variable(Generic(), None), data.type()])
view_map = {1: [0]}
......@@ -259,7 +259,7 @@ class MPISendWait(Op):
self.tag = tag
def make_node(self, request, data):
return Apply(self, [request, data], [Variable(Generic())])
return Apply(self, [request, data], [Variable(Generic(), None)])
def perform(self, node, inp, out):
request = inp[0]
......
......@@ -3,13 +3,13 @@ import traceback as tb
import warnings
from collections.abc import Iterable
from numbers import Number
from typing import Optional
from typing import Optional, TypeVar
import numpy as np
from aesara import tensor as at
from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable
from aesara.graph.basic import Constant, OptionalApplyType, Variable
from aesara.graph.utils import MetaType
from aesara.scalar import ComplexError, IntegerDivisionError
from aesara.tensor import _get_vector_length, as_tensor_variable
......@@ -18,6 +18,9 @@ from aesara.tensor.type import TensorType
from aesara.tensor.utils import hash_from_ndarray
_TensorTypeType = TypeVar("_TensorTypeType", bound=TensorType)
class _tensor_py_operators:
def __abs__(self):
return at.math.abs(self)
......@@ -811,14 +814,22 @@ class _tensor_py_operators:
return at.extra_ops.compress(self, a, axis=axis)
class TensorVariable(_tensor_py_operators, Variable):
class TensorVariable(
_tensor_py_operators, Variable[_TensorTypeType, OptionalApplyType]
):
"""
Subclass to add the tensor operators to the basic `Variable` class.
"""
def __init__(self, type, owner=None, index=None, name=None):
super().__init__(type, owner=owner, index=index, name=name)
def __init__(
self,
type: _TensorTypeType,
owner: OptionalApplyType,
index=None,
name=None,
):
super().__init__(type, owner, index=index, name=name)
if config.warn_float64 != "ignore" and type.dtype == "float64":
msg = (
"You are creating a TensorVariable "
......@@ -979,10 +990,10 @@ def get_unique_value(x: TensorVariable) -> Optional[Number]:
return None
class TensorConstant(TensorVariable, Constant):
class TensorConstant(TensorVariable, Constant[_TensorTypeType]):
"""Subclass to add the tensor operators to the basic `Constant` class."""
def __init__(self, type, data, name=None):
def __init__(self, type: _TensorTypeType, data, name=None):
data_shape = np.shape(data)
if len(data_shape) != type.ndim or any(
......
......@@ -65,8 +65,8 @@ Example:
#...
def make_node(self, x, y):
# note 1: constant, int64 and ScalarType are defined in aesara.scalar
# note 2: constant(x) is equivalent to Constant(type = int64, data = x)
# note 3: the call int64() is equivalent to Variable(type = int64) or Variable(type = ScalarType(dtype = 'int64'))
# note 2: constant(x) is equivalent to Constant(type=int64, data=x)
# note 3: the call int64() is equivalent to Variable(int64, None) or Variable(ScalarType(dtype='int64'), None)
if isinstance(x, int):
x = constant(x)
elif not isinstance(x, Variable) or not x.type == int64:
......
......@@ -339,7 +339,7 @@ class TestAutoName:
autoname_id = next(Variable.__count__)
Variable.__count__ = count(autoname_id)
r1 = TensorType(dtype="int32", shape=())("myvar")
r2 = TensorVariable(TensorType(dtype="int32", shape=()))
r2 = TensorVariable(TensorType(dtype="int32", shape=()), None)
r3 = shared(np.random.standard_normal((3, 4)))
assert r1.auto_name == "auto_" + str(autoname_id)
assert r2.auto_name == "auto_" + str(autoname_id + 1)
......
......@@ -17,7 +17,7 @@ class TestLoadTensor:
np.save(self.filename, self.data)
def test_basic(self):
path = Variable(Generic())
path = Variable(Generic(), None)
# Not specifying mmap_mode defaults to None, and the data is
# copied into main memory
x = load(path, "int32", (False,))
......@@ -29,13 +29,13 @@ class TestLoadTensor:
# Modes 'r+', 'r', and 'w+' cannot work with Aesara, because
# the output array may be modified inplace, and that should not
# modify the original file.
path = Variable(Generic())
path = Variable(Generic(), None)
for mmap_mode in ("r+", "r", "w+", "toto"):
with pytest.raises(ValueError):
load(path, "int32", (False,), mmap_mode)
def test1(self):
path = Variable(Generic())
path = Variable(Generic(), None)
# 'c' means "copy-on-write", which allow the array to be overwritten
# by an inplace Op in the graph, without modifying the underlying
# file.
......@@ -48,7 +48,7 @@ class TestLoadTensor:
assert (fn(self.filename) == (self.data**2).sum()).all()
def test_memmap(self):
path = Variable(Generic())
path = Variable(Generic(), None)
x = load(path, "int32", (False,), mmap_mode="c")
fn = function([path], x)
assert type(fn(self.filename)) == np.core.memmap
......
......@@ -63,7 +63,7 @@ def test_shape_basic():
def __eq__(self, other):
return isinstance(other, MyType) and other.thingy == self.thingy
s = shape(Variable(MyType()))
s = shape(Variable(MyType(), None))
assert s.type.broadcastable == (False,)
s = shape(np.array(1))
......@@ -475,7 +475,7 @@ class TestSpecifyShape(utt.InferShapeTester):
def test_infer_shape(self):
rng = np.random.default_rng(3453)
adtens4 = dtensor4()
aivec = TensorVariable(TensorType("int64", (4,)))
aivec = TensorVariable(TensorType("int64", (4,)), None)
aivec_val = [3, 4, 2, 5]
adtens4_val = rng.random(aivec_val)
self._compile_and_check(
......
......@@ -234,7 +234,7 @@ def test__getitem__newaxis(x, indices, new_order):
def test_fixed_shape_variable_basic():
x = TensorVariable(TensorType("int64", (4,)))
x = TensorVariable(TensorType("int64", (4,)), None)
assert isinstance(x.shape, Constant)
assert np.array_equal(x.shape.data, (4,))
......@@ -246,11 +246,11 @@ def test_fixed_shape_variable_basic():
def test_get_vector_length():
x = TensorVariable(TensorType("int64", (4,)))
x = TensorVariable(TensorType("int64", (4,)), None)
res = get_vector_length(x)
assert res == 4
x = TensorVariable(TensorType("int64", (None,)))
x = TensorVariable(TensorType("int64", (None,)), None)
with pytest.raises(ValueError):
get_vector_length(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论