提交 ed62da3e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix typing issues in aesara.scalar.basic

上级 470b9d60
...@@ -16,7 +16,7 @@ from collections.abc import Callable ...@@ -16,7 +16,7 @@ from collections.abc import Callable
from copy import copy from copy import copy
from itertools import chain from itertools import chain
from textwrap import dedent from textwrap import dedent
from typing import Dict, Mapping, Optional, Tuple, Type, Union from typing import Any, Dict, Mapping, Optional, Tuple, Type, Union
import numpy as np import numpy as np
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
...@@ -42,13 +42,6 @@ from aesara.utils import ( ...@@ -42,13 +42,6 @@ from aesara.utils import (
) )
builtin_bool = bool
builtin_complex = complex
builtin_int = int
builtin_float = float
# We capture the builtins that we are going to replace to follow the numpy API
_abs = builtins.abs _abs = builtins.abs
...@@ -144,7 +137,7 @@ class NumpyAutocaster: ...@@ -144,7 +137,7 @@ class NumpyAutocaster:
def __call__(self, x): def __call__(self, x):
# Make sure we only deal with scalars. # Make sure we only deal with scalars.
assert isinstance(x, (int, builtin_float)) or ( assert isinstance(x, (int, builtins.float)) or (
isinstance(x, np.ndarray) and x.ndim == 0 isinstance(x, np.ndarray) and x.ndim == 0
) )
...@@ -260,7 +253,7 @@ def convert(x, dtype=None): ...@@ -260,7 +253,7 @@ def convert(x, dtype=None):
# This is to imitate numpy behavior which tries to fit # This is to imitate numpy behavior which tries to fit
# bigger numbers into a uint64. # bigger numbers into a uint64.
x_ = _asarray(x, dtype="uint64") x_ = _asarray(x, dtype="uint64")
elif isinstance(x, builtin_float): elif isinstance(x, builtins.float):
x_ = autocast_float(x) x_ = autocast_float(x)
elif isinstance(x, np.ndarray): elif isinstance(x, np.ndarray):
x_ = x x_ = x
...@@ -681,19 +674,16 @@ class ScalarType(CType, HasDataType): ...@@ -681,19 +674,16 @@ class ScalarType(CType, HasDataType):
Scalar = ScalarType Scalar = ScalarType
def get_scalar_type(dtype) -> ScalarType: def get_scalar_type(dtype, cache: Dict[str, ScalarType] = {}) -> ScalarType:
""" """
Return a ScalarType(dtype) object. Return a ScalarType(dtype) object.
This caches objects to save allocation and run time. This caches objects to save allocation and run time.
""" """
if dtype not in get_scalar_type.cache: if dtype not in cache:
get_scalar_type.cache[dtype] = ScalarType(dtype=dtype) cache[dtype] = ScalarType(dtype=dtype)
return get_scalar_type.cache[dtype] return cache[dtype]
get_scalar_type.cache = {}
# Register C code for ViewOp on Scalars. # Register C code for ViewOp on Scalars.
...@@ -866,7 +856,7 @@ def constant(x, name=None, dtype=None) -> ScalarConstant: ...@@ -866,7 +856,7 @@ def constant(x, name=None, dtype=None) -> ScalarConstant:
return ScalarConstant(get_scalar_type(str(x.dtype)), x, name=name) return ScalarConstant(get_scalar_type(str(x.dtype)), x, name=name)
def as_scalar(x, name=None) -> ScalarConstant: def as_scalar(x: Any, name: Optional[str] = None) -> ScalarVariable:
from aesara.tensor.basic import scalar_from_tensor from aesara.tensor.basic import scalar_from_tensor
from aesara.tensor.type import TensorType from aesara.tensor.type import TensorType
...@@ -880,19 +870,15 @@ def as_scalar(x, name=None) -> ScalarConstant: ...@@ -880,19 +870,15 @@ def as_scalar(x, name=None) -> ScalarConstant:
else: else:
x = x.outputs[0] x = x.outputs[0]
if isinstance(x, Variable): if isinstance(x, Variable):
if isinstance(x.type, ScalarType): if isinstance(x, ScalarVariable):
return x return x
elif isinstance(x.type, TensorType) and x.ndim == 0: elif isinstance(x.type, TensorType) and x.type.ndim == 0:
return scalar_from_tensor(x) return scalar_from_tensor(x)
else: else:
raise TypeError("Variable type field must be a ScalarType.", x, x.type) raise TypeError(f"Cannot convert {x} to a scalar type")
try:
return constant(x)
except TypeError:
raise TypeError(f"Cannot convert {x} to ScalarType", type(x))
return constant(x)
# Easy constructors
ints = apply_across_args(int64) ints = apply_across_args(int64)
floats = apply_across_args(float64) floats = apply_across_args(float64)
...@@ -1287,7 +1273,7 @@ class BinaryScalarOp(ScalarOp): ...@@ -1287,7 +1273,7 @@ class BinaryScalarOp(ScalarOp):
# - `associative`: whether op(op(a, b), c) == op(a, op(b, c)) # - `associative`: whether op(op(a, b), c) == op(a, op(b, c))
commutative: Optional[builtins.bool] = None commutative: Optional[builtins.bool] = None
associative: Optional[builtins.bool] = None associative: Optional[builtins.bool] = None
identity: Optional[builtins.bool] = None identity: Optional[Union[builtins.bool, builtins.float, builtins.int]] = None
""" """
For an associative operation, the identity object corresponds to the neutral For an associative operation, the identity object corresponds to the neutral
element. For instance, it will be ``0`` for addition, ``1`` for multiplication, element. For instance, it will be ``0`` for addition, ``1`` for multiplication,
...@@ -3863,6 +3849,9 @@ class Real(UnaryScalarOp): ...@@ -3863,6 +3849,9 @@ class Real(UnaryScalarOp):
(gz,) = gout (gz,) = gout
return [complex(gz, 0)] return [complex(gz, 0)]
def c_code(self, *args, **kwargs):
raise NotImplementedError()
real = Real(real_out, name="real") real = Real(real_out, name="real")
...@@ -3883,6 +3872,9 @@ class Imag(UnaryScalarOp): ...@@ -3883,6 +3872,9 @@ class Imag(UnaryScalarOp):
else: else:
return [x.zeros_like(dtype=config.floatX)] return [x.zeros_like(dtype=config.floatX)]
def c_code(self, *args, **kwargs):
raise NotImplementedError()
imag = Imag(real_out, name="imag") imag = Imag(real_out, name="imag")
...@@ -3920,6 +3912,9 @@ class Angle(UnaryScalarOp): ...@@ -3920,6 +3912,9 @@ class Angle(UnaryScalarOp):
else: else:
return [c.zeros_like(dtype=config.floatX)] return [c.zeros_like(dtype=config.floatX)]
def c_code(self, *args, **kwargs):
raise NotImplementedError()
angle = Angle(specific_out(float64), name="angle") angle = Angle(specific_out(float64), name="angle")
...@@ -3946,6 +3941,9 @@ class Complex(BinaryScalarOp): ...@@ -3946,6 +3941,9 @@ class Complex(BinaryScalarOp):
(gz,) = gout (gz,) = gout
return [cast(real(gz), x.type.dtype), cast(imag(gz), y.type.dtype)] return [cast(real(gz), x.type.dtype), cast(imag(gz), y.type.dtype)]
def c_code(self, *args, **kwargs):
raise NotImplementedError()
complex = Complex(name="complex") complex = Complex(name="complex")
...@@ -3990,6 +3988,9 @@ class ComplexFromPolar(BinaryScalarOp): ...@@ -3990,6 +3988,9 @@ class ComplexFromPolar(BinaryScalarOp):
gtheta = gz * complex_from_polar(r, -theta) gtheta = gz * complex_from_polar(r, -theta)
return [gr, gtheta] return [gr, gtheta]
def c_code(self, *args, **kwargs):
raise NotImplementedError()
complex_from_polar = ComplexFromPolar(name="complex_from_polar") complex_from_polar = ComplexFromPolar(name="complex_from_polar")
......
...@@ -12,6 +12,7 @@ from collections.abc import Sequence ...@@ -12,6 +12,7 @@ from collections.abc import Sequence
from functools import partial from functools import partial
from numbers import Number from numbers import Number
from typing import Dict, Iterable, Optional, Tuple, Union from typing import Dict, Iterable, Optional, Tuple, Union
from typing import cast as type_cast
import numpy as np import numpy as np
from numpy.core.multiarray import normalize_axis_index from numpy.core.multiarray import normalize_axis_index
...@@ -574,6 +575,9 @@ class ScalarFromTensor(COp): ...@@ -574,6 +575,9 @@ class ScalarFromTensor(COp):
__props__ = () __props__ = ()
def __call__(self, *args, **kwargs) -> ScalarVariable:
return type_cast(ScalarVariable, super().__call__(*args, **kwargs))
def make_node(self, t): def make_node(self, t):
if not isinstance(t.type, TensorType) or t.type.ndim > 0: if not isinstance(t.type, TensorType) or t.type.ndim > 0:
raise TypeError("Input must be a scalar `TensorType`") raise TypeError("Input must be a scalar `TensorType`")
......
...@@ -139,10 +139,6 @@ check_untyped_defs = False ...@@ -139,10 +139,6 @@ check_untyped_defs = False
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
[mypy-aesara.scalar.basic]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.scalar.math] [mypy-aesara.scalar.math]
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论