提交 9665120e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Enable type checking for NumPy types

上级 bb40791b
......@@ -51,5 +51,6 @@ repos:
hooks:
- id: mypy
additional_dependencies:
- numpy>=1.20
- types-filelock
- types-setuptools
......@@ -1671,14 +1671,14 @@ def equal_computations(
for x, y in zip(xs, ys):
if not isinstance(x, Variable) and not isinstance(y, Variable):
return cast(bool, np.array_equal(x, y))
return np.array_equal(x, y)
if not isinstance(x, Variable):
if isinstance(y, Constant):
return cast(bool, np.array_equal(y.data, x))
return np.array_equal(y.data, x)
return False
if not isinstance(y, Variable):
if isinstance(x, Constant):
return cast(bool, np.array_equal(x.data, y))
return np.array_equal(x.data, y)
return False
if x.owner and not y.owner:
return False
......
......@@ -544,25 +544,19 @@ class ExternalCOp(COp):
vname = variable_names[i]
macro_name = "DTYPE_" + vname
macro_value = "npy_" + v.type.dtype
define_macros.append(define_template % (macro_name, macro_value))
undef_macros.append(undef_template % macro_name)
macro_items = (f"DTYPE_{vname}", f"npy_{v.type.dtype}")
define_macros.append(define_template % macro_items)
undef_macros.append(undef_template % macro_items[0])
d = np.dtype(v.type.dtype)
macro_name = "TYPENUM_" + vname
macro_value = d.num
define_macros.append(define_template % (macro_name, macro_value))
undef_macros.append(undef_template % macro_name)
macro_name = "ITEMSIZE_" + vname
macro_value = d.itemsize
macro_items_2 = (f"TYPENUM_{vname}", d.num)
define_macros.append(define_template % macro_items_2)
undef_macros.append(undef_template % macro_items_2[0])
define_macros.append(define_template % (macro_name, macro_value))
undef_macros.append(undef_template % macro_name)
macro_items_3 = (f"ITEMSIZE_{vname}", d.itemsize)
define_macros.append(define_template % macro_items_3)
undef_macros.append(undef_template % macro_items_3[0])
# Generate a macro to mark code as being apply-specific
define_macros.append(define_template % ("APPLY_SPECIFIC(str)", f"str##_{name}"))
......
......@@ -104,7 +104,7 @@ def op_debug_information(op: Op, node: Apply) -> Dict[Apply, Dict[Variable, str]
def debugprint(
obj: Union[
graph_like: Union[
Union[Variable, Apply, Function, FunctionGraph],
Sequence[Union[Variable, Apply, Function, FunctionGraph]],
],
......@@ -139,7 +139,7 @@ def debugprint(
Parameters
----------
obj
graph_like
The object(s) to be printed.
depth
Print graph to this depth (``-1`` for unlimited).
......@@ -149,7 +149,7 @@ def debugprint(
When `file` extends `TextIO`, print to it; when `file` is
equal to ``"str"``, return a string; when `file` is ``None``, print to
`sys.stdout`.
ids
id_type
Determines the type of identifier used for `Variable`\s:
- ``"id"``: print the python id value,
- ``"int"``: print integer character,
......@@ -213,12 +213,12 @@ def debugprint(
topo_orders: List[Optional[List[Apply]]] = []
storage_maps: List[Optional[StorageMapType]] = []
if isinstance(obj, (list, tuple, set)):
lobj = obj
if isinstance(graph_like, (list, tuple, set)):
graphs = graph_like
else:
lobj = [obj]
graphs = (graph_like,)
for obj in lobj:
for obj in graphs:
if isinstance(obj, Variable):
outputs_to_print.append(obj)
profile_list.append(None)
......
"""Symbolic tensor types and constructor functions."""
from functools import singledispatch
from typing import Any, Callable, NoReturn, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, NoReturn, Optional, Sequence, Union
from aesara.graph.basic import Constant, Variable
from aesara.graph.op import Op
if TYPE_CHECKING:
from numpy.typing import ArrayLike, NDArray
TensorLike = Union[Variable, Sequence[Variable], "ArrayLike"]
def as_tensor_variable(
x: Any, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs
x: TensorLike, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs
) -> "TensorVariable":
"""Convert `x` into an equivalent `TensorVariable`.
......@@ -44,12 +51,12 @@ def as_tensor_variable(
@singledispatch
def _as_tensor_variable(
x, name: Optional[str], ndim: Optional[int], **kwargs
x: TensorLike, name: Optional[str], ndim: Optional[int], **kwargs
) -> "TensorVariable":
raise NotImplementedError(f"Cannot convert {x} to a tensor variable.")
raise NotImplementedError(f"Cannot convert {x!r} to a tensor variable.")
def get_vector_length(v: Any):
def get_vector_length(v: TensorLike) -> int:
"""Return the run-time length of a symbolic vector, when possible.
Parameters
......@@ -80,13 +87,13 @@ def get_vector_length(v: Any):
@singledispatch
def _get_vector_length(op: Union[Op, Variable], var: Variable):
def _get_vector_length(op: Union[Op, Variable], var: Variable) -> int:
"""`Op`-based dispatch for `get_vector_length`."""
raise ValueError(f"Length of {var} cannot be determined")
@_get_vector_length.register(Constant)
def _get_vector_length_Constant(var_inst, var):
def _get_vector_length_Constant(op: Union[Op, Variable], var: Constant) -> int:
return len(var.data)
......
......@@ -28,7 +28,7 @@ class RandomType(Type[T]):
@staticmethod
def may_share_memory(a: T, b: T):
return a._bit_generator is b._bit_generator
return a._bit_generator is b._bit_generator # type: ignore[attr-defined]
class RandomStateType(RandomType[np.random.RandomState]):
......
......@@ -83,6 +83,7 @@ warn_unreachable = True
show_error_codes = True
allow_redefinition = False
files = aesara,tests
plugins = numpy.typing.mypy_plugin
[mypy-versioneer]
check_untyped_defs = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论