提交 9665120e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Enable type checking for NumPy types

上级 bb40791b
...@@ -51,5 +51,6 @@ repos: ...@@ -51,5 +51,6 @@ repos:
hooks: hooks:
- id: mypy - id: mypy
additional_dependencies: additional_dependencies:
- numpy>=1.20
- types-filelock - types-filelock
- types-setuptools - types-setuptools
...@@ -1671,14 +1671,14 @@ def equal_computations( ...@@ -1671,14 +1671,14 @@ def equal_computations(
for x, y in zip(xs, ys): for x, y in zip(xs, ys):
if not isinstance(x, Variable) and not isinstance(y, Variable): if not isinstance(x, Variable) and not isinstance(y, Variable):
return cast(bool, np.array_equal(x, y)) return np.array_equal(x, y)
if not isinstance(x, Variable): if not isinstance(x, Variable):
if isinstance(y, Constant): if isinstance(y, Constant):
return cast(bool, np.array_equal(y.data, x)) return np.array_equal(y.data, x)
return False return False
if not isinstance(y, Variable): if not isinstance(y, Variable):
if isinstance(x, Constant): if isinstance(x, Constant):
return cast(bool, np.array_equal(x.data, y)) return np.array_equal(x.data, y)
return False return False
if x.owner and not y.owner: if x.owner and not y.owner:
return False return False
......
...@@ -544,25 +544,19 @@ class ExternalCOp(COp): ...@@ -544,25 +544,19 @@ class ExternalCOp(COp):
vname = variable_names[i] vname = variable_names[i]
macro_name = "DTYPE_" + vname macro_items = (f"DTYPE_{vname}", f"npy_{v.type.dtype}")
macro_value = "npy_" + v.type.dtype define_macros.append(define_template % macro_items)
undef_macros.append(undef_template % macro_items[0])
define_macros.append(define_template % (macro_name, macro_value))
undef_macros.append(undef_template % macro_name)
d = np.dtype(v.type.dtype) d = np.dtype(v.type.dtype)
macro_name = "TYPENUM_" + vname macro_items_2 = (f"TYPENUM_{vname}", d.num)
macro_value = d.num define_macros.append(define_template % macro_items_2)
undef_macros.append(undef_template % macro_items_2[0])
define_macros.append(define_template % (macro_name, macro_value))
undef_macros.append(undef_template % macro_name)
macro_name = "ITEMSIZE_" + vname
macro_value = d.itemsize
define_macros.append(define_template % (macro_name, macro_value)) macro_items_3 = (f"ITEMSIZE_{vname}", d.itemsize)
undef_macros.append(undef_template % macro_name) define_macros.append(define_template % macro_items_3)
undef_macros.append(undef_template % macro_items_3[0])
# Generate a macro to mark code as being apply-specific # Generate a macro to mark code as being apply-specific
define_macros.append(define_template % ("APPLY_SPECIFIC(str)", f"str##_{name}")) define_macros.append(define_template % ("APPLY_SPECIFIC(str)", f"str##_{name}"))
......
...@@ -104,7 +104,7 @@ def op_debug_information(op: Op, node: Apply) -> Dict[Apply, Dict[Variable, str] ...@@ -104,7 +104,7 @@ def op_debug_information(op: Op, node: Apply) -> Dict[Apply, Dict[Variable, str]
def debugprint( def debugprint(
obj: Union[ graph_like: Union[
Union[Variable, Apply, Function, FunctionGraph], Union[Variable, Apply, Function, FunctionGraph],
Sequence[Union[Variable, Apply, Function, FunctionGraph]], Sequence[Union[Variable, Apply, Function, FunctionGraph]],
], ],
...@@ -139,7 +139,7 @@ def debugprint( ...@@ -139,7 +139,7 @@ def debugprint(
Parameters Parameters
---------- ----------
obj graph_like
The object(s) to be printed. The object(s) to be printed.
depth depth
Print graph to this depth (``-1`` for unlimited). Print graph to this depth (``-1`` for unlimited).
...@@ -149,7 +149,7 @@ def debugprint( ...@@ -149,7 +149,7 @@ def debugprint(
When `file` extends `TextIO`, print to it; when `file` is When `file` extends `TextIO`, print to it; when `file` is
equal to ``"str"``, return a string; when `file` is ``None``, print to equal to ``"str"``, return a string; when `file` is ``None``, print to
`sys.stdout`. `sys.stdout`.
ids id_type
Determines the type of identifier used for `Variable`\s: Determines the type of identifier used for `Variable`\s:
- ``"id"``: print the python id value, - ``"id"``: print the python id value,
- ``"int"``: print integer character, - ``"int"``: print integer character,
...@@ -213,12 +213,12 @@ def debugprint( ...@@ -213,12 +213,12 @@ def debugprint(
topo_orders: List[Optional[List[Apply]]] = [] topo_orders: List[Optional[List[Apply]]] = []
storage_maps: List[Optional[StorageMapType]] = [] storage_maps: List[Optional[StorageMapType]] = []
if isinstance(obj, (list, tuple, set)): if isinstance(graph_like, (list, tuple, set)):
lobj = obj graphs = graph_like
else: else:
lobj = [obj] graphs = (graph_like,)
for obj in lobj: for obj in graphs:
if isinstance(obj, Variable): if isinstance(obj, Variable):
outputs_to_print.append(obj) outputs_to_print.append(obj)
profile_list.append(None) profile_list.append(None)
......
"""Symbolic tensor types and constructor functions.""" """Symbolic tensor types and constructor functions."""
from functools import singledispatch from functools import singledispatch
from typing import Any, Callable, NoReturn, Optional, Union from typing import TYPE_CHECKING, Any, Callable, NoReturn, Optional, Sequence, Union
from aesara.graph.basic import Constant, Variable from aesara.graph.basic import Constant, Variable
from aesara.graph.op import Op from aesara.graph.op import Op
if TYPE_CHECKING:
from numpy.typing import ArrayLike, NDArray
TensorLike = Union[Variable, Sequence[Variable], "ArrayLike"]
def as_tensor_variable( def as_tensor_variable(
x: Any, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs x: TensorLike, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs
) -> "TensorVariable": ) -> "TensorVariable":
"""Convert `x` into an equivalent `TensorVariable`. """Convert `x` into an equivalent `TensorVariable`.
...@@ -44,12 +51,12 @@ def as_tensor_variable( ...@@ -44,12 +51,12 @@ def as_tensor_variable(
@singledispatch @singledispatch
def _as_tensor_variable( def _as_tensor_variable(
x, name: Optional[str], ndim: Optional[int], **kwargs x: TensorLike, name: Optional[str], ndim: Optional[int], **kwargs
) -> "TensorVariable": ) -> "TensorVariable":
raise NotImplementedError(f"Cannot convert {x} to a tensor variable.") raise NotImplementedError(f"Cannot convert {x!r} to a tensor variable.")
def get_vector_length(v: Any): def get_vector_length(v: TensorLike) -> int:
"""Return the run-time length of a symbolic vector, when possible. """Return the run-time length of a symbolic vector, when possible.
Parameters Parameters
...@@ -80,13 +87,13 @@ def get_vector_length(v: Any): ...@@ -80,13 +87,13 @@ def get_vector_length(v: Any):
@singledispatch @singledispatch
def _get_vector_length(op: Union[Op, Variable], var: Variable): def _get_vector_length(op: Union[Op, Variable], var: Variable) -> int:
"""`Op`-based dispatch for `get_vector_length`.""" """`Op`-based dispatch for `get_vector_length`."""
raise ValueError(f"Length of {var} cannot be determined") raise ValueError(f"Length of {var} cannot be determined")
@_get_vector_length.register(Constant) @_get_vector_length.register(Constant)
def _get_vector_length_Constant(var_inst, var): def _get_vector_length_Constant(op: Union[Op, Variable], var: Constant) -> int:
return len(var.data) return len(var.data)
......
...@@ -28,7 +28,7 @@ class RandomType(Type[T]): ...@@ -28,7 +28,7 @@ class RandomType(Type[T]):
@staticmethod @staticmethod
def may_share_memory(a: T, b: T): def may_share_memory(a: T, b: T):
return a._bit_generator is b._bit_generator return a._bit_generator is b._bit_generator # type: ignore[attr-defined]
class RandomStateType(RandomType[np.random.RandomState]): class RandomStateType(RandomType[np.random.RandomState]):
......
...@@ -83,6 +83,7 @@ warn_unreachable = True ...@@ -83,6 +83,7 @@ warn_unreachable = True
show_error_codes = True show_error_codes = True
allow_redefinition = False allow_redefinition = False
files = aesara,tests files = aesara,tests
plugins = numpy.typing.mypy_plugin
[mypy-versioneer] [mypy-versioneer]
check_untyped_defs = False check_untyped_defs = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论