提交 4a539e47 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Parameterize Type by the underlying data's type

上级 01d20497
from abc import abstractmethod
from typing import Any, Optional, Text, Tuple, TypeVar, Union
from typing import Any, Generic, Optional, Text, Tuple, TypeVar, Union
from typing_extensions import TypeAlias
......@@ -11,7 +11,7 @@ from aesara.graph.utils import MetaObject
D = TypeVar("D")
class Type(MetaObject):
class Type(MetaObject, Generic[D]):
"""
Interface specification for variable type instances.
......@@ -77,8 +77,8 @@ class Type(MetaObject):
@abstractmethod
def filter(
self, data: D, strict: bool = False, allow_downcast: Optional[bool] = None
) -> Union[D, Any]:
self, data: Any, strict: bool = False, allow_downcast: Optional[bool] = None
) -> D:
"""Return data or an appropriately wrapped/converted data.
Subclass implementations should raise a TypeError exception if
......@@ -103,7 +103,7 @@ class Type(MetaObject):
def filter_inplace(
self,
value: D,
value: Any,
storage: Any,
strict: bool = False,
allow_downcast: Optional[bool] = None,
......@@ -212,7 +212,7 @@ class Type(MetaObject):
"""
return self.constant_type(type=self, data=value, name=name)
def clone(self, *args, **kwargs):
def clone(self, *args, **kwargs) -> "Type":
"""Clone a copy of this type with the given arguments/keyword values, if any."""
return type(self)(*args, **kwargs)
......@@ -228,7 +228,7 @@ class Type(MetaObject):
return utils.add_tag_trace(self.make_variable(name))
@classmethod
def values_eq(cls, a: "Type", b: "Type") -> bool:
def values_eq(cls, a: D, b: D) -> bool:
"""Return ``True`` if `a` and `b` can be considered exactly equal.
`a` and `b` are assumed to be valid values of this `Type`.
......@@ -237,7 +237,7 @@ class Type(MetaObject):
return a == b
@classmethod
def values_eq_approx(cls, a: Any, b: Any):
def values_eq_approx(cls, a: D, b: D) -> bool:
"""Return ``True`` if `a` and `b` can be considered approximately equal.
This function is used by Aesara debugging tools to decide
......
import ctypes
import platform
import re
from typing import TypeVar
from aesara.graph.basic import Constant
from aesara.graph.type import Type
......@@ -8,7 +9,11 @@ from aesara.link.c.interface import CLinkerType
from aesara.utils import Singleton
class CType(Type, CLinkerType):
D = TypeVar("D")
T = TypeVar("T", bound=Type)
class CType(Type[D], CLinkerType):
"""Convenience wrapper combining `Type` and `CLinkerType`.
Aesara comes with several subclasses of such as:
......@@ -120,7 +125,7 @@ if platform.python_implementation() != "PyPy":
).value
class CDataType(CType):
class CDataType(CType[D]):
"""
Represents opaque C data to be passed around. The intent is to
ease passing arbitrary data between ops C code.
......@@ -286,7 +291,7 @@ void _capsule_destructor(PyObject *o) {
self.version = None
class CDataTypeConstant(Constant):
class CDataTypeConstant(Constant[T]):
def merge_signature(self):
# We don't want to merge constants that don't point to the
# same object.
......
from typing import Generic, TypeVar
from typing import TypeVar
import numpy as np
......@@ -23,7 +23,7 @@ gen_states_keys = {
numpy_bit_gens = {0: "MT19937", 1: "PCG64", 2: "Philox", 3: "SFC64"}
class RandomType(Type, Generic[T]):
class RandomType(Type[T]):
r"""A Type wrapper for `numpy.random.Generator` and `numpy.random.RandomState`."""
@staticmethod
......
......@@ -48,7 +48,7 @@ dtype_specs_map = {
}
class TensorType(CType, HasDataType, HasShape):
class TensorType(CType[np.ndarray], HasDataType, HasShape):
r"""Symbolic `Type` representing `numpy.ndarray`\s."""
__props__: Tuple[str, ...] = ("dtype", "shape")
......@@ -108,7 +108,9 @@ class TensorType(CType, HasDataType, HasShape):
self.name = name
self.numpy_dtype = np.dtype(self.dtype)
def clone(self, dtype=None, shape=None, broadcastable=None, **kwargs):
def clone(
self, dtype=None, shape=None, broadcastable=None, **kwargs
) -> "TensorType":
if broadcastable is not None:
warnings.warn(
"The `broadcastable` keyword is deprecated; use `shape`.",
......
......@@ -51,7 +51,7 @@ class MakeSlice(Op):
make_slice = MakeSlice()
class SliceType(Type):
class SliceType(Type[slice]):
def clone(self, **kwargs):
return type(self)()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论