提交 ed62da3e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix typing issues in aesara.scalar.basic

上级 470b9d60
......@@ -16,7 +16,7 @@ from collections.abc import Callable
from copy import copy
from itertools import chain
from textwrap import dedent
from typing import Dict, Mapping, Optional, Tuple, Type, Union
from typing import Any, Dict, Mapping, Optional, Tuple, Type, Union
import numpy as np
from typing_extensions import TypeAlias
......@@ -42,13 +42,6 @@ from aesara.utils import (
)
builtin_bool = bool
builtin_complex = complex
builtin_int = int
builtin_float = float
# We capture the builtins that we are going to replace to follow the numpy API
_abs = builtins.abs
......@@ -144,7 +137,7 @@ class NumpyAutocaster:
def __call__(self, x):
# Make sure we only deal with scalars.
assert isinstance(x, (int, builtin_float)) or (
assert isinstance(x, (int, builtins.float)) or (
isinstance(x, np.ndarray) and x.ndim == 0
)
......@@ -260,7 +253,7 @@ def convert(x, dtype=None):
# This is to imitate numpy behavior which tries to fit
# bigger numbers into a uint64.
x_ = _asarray(x, dtype="uint64")
elif isinstance(x, builtin_float):
elif isinstance(x, builtins.float):
x_ = autocast_float(x)
elif isinstance(x, np.ndarray):
x_ = x
......@@ -681,19 +674,16 @@ class ScalarType(CType, HasDataType):
Scalar = ScalarType
def get_scalar_type(dtype) -> ScalarType:
def get_scalar_type(dtype, cache: Dict[str, ScalarType] = {}) -> ScalarType:
"""
Return a ScalarType(dtype) object.
This caches objects to save allocation and run time.
"""
if dtype not in get_scalar_type.cache:
get_scalar_type.cache[dtype] = ScalarType(dtype=dtype)
return get_scalar_type.cache[dtype]
get_scalar_type.cache = {}
if dtype not in cache:
cache[dtype] = ScalarType(dtype=dtype)
return cache[dtype]
# Register C code for ViewOp on Scalars.
......@@ -866,7 +856,7 @@ def constant(x, name=None, dtype=None) -> ScalarConstant:
return ScalarConstant(get_scalar_type(str(x.dtype)), x, name=name)
def as_scalar(x, name=None) -> ScalarConstant:
def as_scalar(x: Any, name: Optional[str] = None) -> ScalarVariable:
from aesara.tensor.basic import scalar_from_tensor
from aesara.tensor.type import TensorType
......@@ -880,19 +870,15 @@ def as_scalar(x, name=None) -> ScalarConstant:
else:
x = x.outputs[0]
if isinstance(x, Variable):
if isinstance(x.type, ScalarType):
if isinstance(x, ScalarVariable):
return x
elif isinstance(x.type, TensorType) and x.ndim == 0:
elif isinstance(x.type, TensorType) and x.type.ndim == 0:
return scalar_from_tensor(x)
else:
raise TypeError("Variable type field must be a ScalarType.", x, x.type)
try:
return constant(x)
except TypeError:
raise TypeError(f"Cannot convert {x} to ScalarType", type(x))
raise TypeError(f"Cannot convert {x} to a scalar type")
return constant(x)
# Easy constructors
ints = apply_across_args(int64)
floats = apply_across_args(float64)
......@@ -1287,7 +1273,7 @@ class BinaryScalarOp(ScalarOp):
# - `associative`: whether op(op(a, b), c) == op(a, op(b, c))
commutative: Optional[builtins.bool] = None
associative: Optional[builtins.bool] = None
identity: Optional[builtins.bool] = None
identity: Optional[Union[builtins.bool, builtins.float, builtins.int]] = None
"""
For an associative operation, the identity object corresponds to the neutral
element. For instance, it will be ``0`` for addition, ``1`` for multiplication,
......@@ -3863,6 +3849,9 @@ class Real(UnaryScalarOp):
(gz,) = gout
return [complex(gz, 0)]
def c_code(self, *args, **kwargs):
raise NotImplementedError()
real = Real(real_out, name="real")
......@@ -3883,6 +3872,9 @@ class Imag(UnaryScalarOp):
else:
return [x.zeros_like(dtype=config.floatX)]
def c_code(self, *args, **kwargs):
raise NotImplementedError()
imag = Imag(real_out, name="imag")
......@@ -3920,6 +3912,9 @@ class Angle(UnaryScalarOp):
else:
return [c.zeros_like(dtype=config.floatX)]
def c_code(self, *args, **kwargs):
raise NotImplementedError()
angle = Angle(specific_out(float64), name="angle")
......@@ -3946,6 +3941,9 @@ class Complex(BinaryScalarOp):
(gz,) = gout
return [cast(real(gz), x.type.dtype), cast(imag(gz), y.type.dtype)]
def c_code(self, *args, **kwargs):
raise NotImplementedError()
complex = Complex(name="complex")
......@@ -3990,6 +3988,9 @@ class ComplexFromPolar(BinaryScalarOp):
gtheta = gz * complex_from_polar(r, -theta)
return [gr, gtheta]
def c_code(self, *args, **kwargs):
raise NotImplementedError()
complex_from_polar = ComplexFromPolar(name="complex_from_polar")
......
......@@ -12,6 +12,7 @@ from collections.abc import Sequence
from functools import partial
from numbers import Number
from typing import Dict, Iterable, Optional, Tuple, Union
from typing import cast as type_cast
import numpy as np
from numpy.core.multiarray import normalize_axis_index
......@@ -574,6 +575,9 @@ class ScalarFromTensor(COp):
__props__ = ()
def __call__(self, *args, **kwargs) -> ScalarVariable:
return type_cast(ScalarVariable, super().__call__(*args, **kwargs))
def make_node(self, t):
if not isinstance(t.type, TensorType) or t.type.ndim > 0:
raise TypeError("Input must be a scalar `TensorType`")
......
......@@ -139,10 +139,6 @@ check_untyped_defs = False
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.scalar.basic]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.scalar.math]
ignore_errors = True
check_untyped_defs = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论