Commit b5682ed9 authored by Virgile Andreani, committed by Virgile Andreani

Clear mypy errors in scalar/basic.py

Parent 6cf729b9
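
The commit adds explicit type annotations to the scalar output-type helpers so that mypy can check pytensor/scalar/basic.py cleanly. A minimal, self-contained sketch of the pattern (the ScalarType stand-in and the names below are illustrative, not the real PyTensor definitions):

class ScalarType:
    """Stand-in for pytensor.scalar.basic.ScalarType."""

bool_ = ScalarType()  # stand-in for the module-level `bool` scalar type
int8 = ScalarType()

def same_out_min8(type: ScalarType) -> tuple[ScalarType]:
    # tuple[ScalarType] (no ellipsis) is a fixed-length 1-tuple, matching
    # the single-output convention of these helpers; mypy now flags any
    # return path that yields the wrong arity or element type.
    if type is bool_:
        return (int8,)
    return (type,)

Annotating the result as a 1-tuple rather than tuple[ScalarType, ...] is the stricter choice: mypy will reject a helper that accidentally returns zero or several types.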
@@ -12,7 +12,7 @@ you probably want to use pytensor.tensor.[c,z,f,d,b,w,i,l,]scalar!
 import builtins
 import math
-from collections.abc import Callable, Mapping
+from collections.abc import Callable
 from copy import copy
 from itertools import chain
 from textwrap import dedent
@@ -59,7 +59,7 @@ class IntegerDivisionError(Exception):
     """


-def upcast(dtype, *dtypes):
+def upcast(dtype, *dtypes) -> str:
     # This tries to keep data in floatX or lower precision, unless we
     # explicitly request a higher precision datatype.
     keep_float32 = [
@@ -899,31 +899,31 @@ complexs64 = apply_across_args(complex64)
 complexs128 = apply_across_args(complex128)


-def upcast_out(*types):
+def upcast_out(*types) -> tuple[ScalarType]:
     dtype = ScalarType.upcast(*types)
     return (get_scalar_type(dtype),)


-def upcast_out_nobool(*types):
+def upcast_out_nobool(*types) -> tuple[ScalarType]:
     type = upcast_out(*types)
     if type[0] == bool:
         raise TypeError("bool output not supported")
     return type


-def upcast_out_min8(*types):
+def upcast_out_min8(*types) -> tuple[ScalarType]:
     type = upcast_out(*types)
     if type[0] == bool:
         return (int8,)
     return type


-def upgrade_to_float(*types):
+def upgrade_to_float(*types) -> tuple[ScalarType]:
     """
     Upgrade any int types to float32 or float64 to avoid losing precision.
     """
-    conv: Mapping[type, type] = {
+    conv: dict[ScalarType, ScalarType] = {
         bool: float32,
         int8: float32,
         int16: float32,
@@ -934,12 +934,11 @@ def upgrade_to_float(*types):
         uint32: float64,
         uint64: float64,
     }
-    return (
-        get_scalar_type(ScalarType.upcast(*[conv.get(type, type) for type in types])),
-    )
+    up = ScalarType.upcast(*[conv.get(type, type) for type in types])
+    return (get_scalar_type(up),)


-def upgrade_to_float64(*types):
+def upgrade_to_float64(*types) -> tuple[ScalarType]:
     """
     Upgrade any int and float32 to float64 to do as SciPy.
@@ -947,29 +946,29 @@ def upgrade_to_float64(*types):
     return (get_scalar_type("float64"),)


-def same_out(type):
+def same_out(type: ScalarType) -> tuple[ScalarType]:
     return (type,)


-def same_out_nobool(type):
+def same_out_nobool(type: ScalarType) -> tuple[ScalarType]:
     if type == bool:
         raise TypeError("bool input not supported")
     return (type,)


-def same_out_min8(type):
+def same_out_min8(type: ScalarType) -> tuple[ScalarType]:
     if type == bool:
         return (int8,)
     return (type,)


-def upcast_out_no_complex(*types):
+def upcast_out_no_complex(*types) -> tuple[ScalarType]:
     if any(type in complex_types for type in types):
         raise TypeError("complex type are not supported")
     return (get_scalar_type(dtype=ScalarType.upcast(*types)),)


-def same_out_float_only(type):
+def same_out_float_only(type) -> tuple[ScalarType]:
     if type not in float_types:
         raise TypeError("only float type are supported")
     return (type,)
@@ -11,7 +11,6 @@ pytensor/link/numba/dispatch/elemwise.py
 pytensor/link/numba/dispatch/scan.py
 pytensor/printing.py
 pytensor/raise_op.py
-pytensor/scalar/basic.py
 pytensor/sparse/basic.py
 pytensor/sparse/type.py
 pytensor/tensor/basic.py
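
This second hunk removes pytensor/scalar/basic.py from what appears to be the repository's list of files known to fail mypy, so the module is now held to a clean type-check. Assuming a standard setup, the result can be verified locally with: mypy pytensor/scalar/basic.py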