提交 4a539e47 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Parameterize Type by the underlying data's type

上级 01d20497
from abc import abstractmethod from abc import abstractmethod
from typing import Any, Optional, Text, Tuple, TypeVar, Union from typing import Any, Generic, Optional, Text, Tuple, TypeVar, Union
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
...@@ -11,7 +11,7 @@ from aesara.graph.utils import MetaObject ...@@ -11,7 +11,7 @@ from aesara.graph.utils import MetaObject
D = TypeVar("D") D = TypeVar("D")
class Type(MetaObject): class Type(MetaObject, Generic[D]):
""" """
Interface specification for variable type instances. Interface specification for variable type instances.
...@@ -77,8 +77,8 @@ class Type(MetaObject): ...@@ -77,8 +77,8 @@ class Type(MetaObject):
@abstractmethod @abstractmethod
def filter( def filter(
self, data: D, strict: bool = False, allow_downcast: Optional[bool] = None self, data: Any, strict: bool = False, allow_downcast: Optional[bool] = None
) -> Union[D, Any]: ) -> D:
"""Return data or an appropriately wrapped/converted data. """Return data or an appropriately wrapped/converted data.
Subclass implementations should raise a TypeError exception if Subclass implementations should raise a TypeError exception if
...@@ -103,7 +103,7 @@ class Type(MetaObject): ...@@ -103,7 +103,7 @@ class Type(MetaObject):
def filter_inplace( def filter_inplace(
self, self,
value: D, value: Any,
storage: Any, storage: Any,
strict: bool = False, strict: bool = False,
allow_downcast: Optional[bool] = None, allow_downcast: Optional[bool] = None,
...@@ -212,7 +212,7 @@ class Type(MetaObject): ...@@ -212,7 +212,7 @@ class Type(MetaObject):
""" """
return self.constant_type(type=self, data=value, name=name) return self.constant_type(type=self, data=value, name=name)
def clone(self, *args, **kwargs): def clone(self, *args, **kwargs) -> "Type":
"""Clone a copy of this type with the given arguments/keyword values, if any.""" """Clone a copy of this type with the given arguments/keyword values, if any."""
return type(self)(*args, **kwargs) return type(self)(*args, **kwargs)
...@@ -228,7 +228,7 @@ class Type(MetaObject): ...@@ -228,7 +228,7 @@ class Type(MetaObject):
return utils.add_tag_trace(self.make_variable(name)) return utils.add_tag_trace(self.make_variable(name))
@classmethod @classmethod
def values_eq(cls, a: "Type", b: "Type") -> bool: def values_eq(cls, a: D, b: D) -> bool:
"""Return ``True`` if `a` and `b` can be considered exactly equal. """Return ``True`` if `a` and `b` can be considered exactly equal.
`a` and `b` are assumed to be valid values of this `Type`. `a` and `b` are assumed to be valid values of this `Type`.
...@@ -237,7 +237,7 @@ class Type(MetaObject): ...@@ -237,7 +237,7 @@ class Type(MetaObject):
return a == b return a == b
@classmethod @classmethod
def values_eq_approx(cls, a: Any, b: Any): def values_eq_approx(cls, a: D, b: D) -> bool:
"""Return ``True`` if `a` and `b` can be considered approximately equal. """Return ``True`` if `a` and `b` can be considered approximately equal.
This function is used by Aesara debugging tools to decide This function is used by Aesara debugging tools to decide
......
import ctypes import ctypes
import platform import platform
import re import re
from typing import TypeVar
from aesara.graph.basic import Constant from aesara.graph.basic import Constant
from aesara.graph.type import Type from aesara.graph.type import Type
...@@ -8,7 +9,11 @@ from aesara.link.c.interface import CLinkerType ...@@ -8,7 +9,11 @@ from aesara.link.c.interface import CLinkerType
from aesara.utils import Singleton from aesara.utils import Singleton
class CType(Type, CLinkerType): D = TypeVar("D")
T = TypeVar("T", bound=Type)
class CType(Type[D], CLinkerType):
"""Convenience wrapper combining `Type` and `CLinkerType`. """Convenience wrapper combining `Type` and `CLinkerType`.
Aesara comes with several subclasses of such as: Aesara comes with several subclasses of such as:
...@@ -120,7 +125,7 @@ if platform.python_implementation() != "PyPy": ...@@ -120,7 +125,7 @@ if platform.python_implementation() != "PyPy":
).value ).value
class CDataType(CType): class CDataType(CType[D]):
""" """
Represents opaque C data to be passed around. The intent is to Represents opaque C data to be passed around. The intent is to
ease passing arbitrary data between ops C code. ease passing arbitrary data between ops C code.
...@@ -286,7 +291,7 @@ void _capsule_destructor(PyObject *o) { ...@@ -286,7 +291,7 @@ void _capsule_destructor(PyObject *o) {
self.version = None self.version = None
class CDataTypeConstant(Constant): class CDataTypeConstant(Constant[T]):
def merge_signature(self): def merge_signature(self):
# We don't want to merge constants that don't point to the # We don't want to merge constants that don't point to the
# same object. # same object.
......
from typing import Generic, TypeVar from typing import TypeVar
import numpy as np import numpy as np
...@@ -23,7 +23,7 @@ gen_states_keys = { ...@@ -23,7 +23,7 @@ gen_states_keys = {
numpy_bit_gens = {0: "MT19937", 1: "PCG64", 2: "Philox", 3: "SFC64"} numpy_bit_gens = {0: "MT19937", 1: "PCG64", 2: "Philox", 3: "SFC64"}
class RandomType(Type, Generic[T]): class RandomType(Type[T]):
r"""A Type wrapper for `numpy.random.Generator` and `numpy.random.RandomState`.""" r"""A Type wrapper for `numpy.random.Generator` and `numpy.random.RandomState`."""
@staticmethod @staticmethod
......
...@@ -48,7 +48,7 @@ dtype_specs_map = { ...@@ -48,7 +48,7 @@ dtype_specs_map = {
} }
class TensorType(CType, HasDataType, HasShape): class TensorType(CType[np.ndarray], HasDataType, HasShape):
r"""Symbolic `Type` representing `numpy.ndarray`\s.""" r"""Symbolic `Type` representing `numpy.ndarray`\s."""
__props__: Tuple[str, ...] = ("dtype", "shape") __props__: Tuple[str, ...] = ("dtype", "shape")
...@@ -108,7 +108,9 @@ class TensorType(CType, HasDataType, HasShape): ...@@ -108,7 +108,9 @@ class TensorType(CType, HasDataType, HasShape):
self.name = name self.name = name
self.numpy_dtype = np.dtype(self.dtype) self.numpy_dtype = np.dtype(self.dtype)
def clone(self, dtype=None, shape=None, broadcastable=None, **kwargs): def clone(
self, dtype=None, shape=None, broadcastable=None, **kwargs
) -> "TensorType":
if broadcastable is not None: if broadcastable is not None:
warnings.warn( warnings.warn(
"The `broadcastable` keyword is deprecated; use `shape`.", "The `broadcastable` keyword is deprecated; use `shape`.",
......
...@@ -51,7 +51,7 @@ class MakeSlice(Op): ...@@ -51,7 +51,7 @@ class MakeSlice(Op):
make_slice = MakeSlice() make_slice = MakeSlice()
class SliceType(Type): class SliceType(Type[slice]):
def clone(self, **kwargs): def clone(self, **kwargs):
return type(self)() return type(self)()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论