提交 63da6d16 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Virgile Andreani

Fix types in tensor/random/utils.py and tensor/utils.py

上级 6a295b9f
......@@ -250,6 +250,7 @@ def core_VonMisesRV(op, node):
@numba_core_rv_funcify.register(ptr.ChoiceWithoutReplacement)
def core_ChoiceWithoutReplacement(op: ptr.ChoiceWithoutReplacement, node):
assert isinstance(op.signature, str)
[core_shape_len_sig] = _parse_gufunc_signature(op.signature)[0][-1]
core_shape_len = int(core_shape_len_sig)
implicit_arange = op.ndims_params[0] == 0
......
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from functools import wraps
from itertools import zip_longest
from types import ModuleType
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING
import numpy as np
......@@ -22,7 +22,9 @@ if TYPE_CHECKING:
from pytensor.tensor.random.op import RandomVariable
def params_broadcast_shapes(param_shapes, ndims_params, use_pytensor=True):
def params_broadcast_shapes(
param_shapes: Sequence, ndims_params: Sequence[int], use_pytensor: bool = True
) -> list[tuple[int, ...]]:
"""Broadcast parameters that have different dimensions.
Parameters
......@@ -36,12 +38,12 @@ def params_broadcast_shapes(param_shapes, ndims_params, use_pytensor=True):
Returns
=======
bcast_shapes : list of ndarray
bcast_shapes : list of tuples of ints
The broadcasted values of `params`.
"""
max_fn = maximum if use_pytensor else max
rev_extra_dims = []
rev_extra_dims: list[int] = []
for ndim_param, param_shape in zip(ndims_params, param_shapes):
# We need this in order to use `len`
param_shape = tuple(param_shape)
......@@ -71,7 +73,9 @@ def params_broadcast_shapes(param_shapes, ndims_params, use_pytensor=True):
return bcast_shapes
def broadcast_params(params, ndims_params):
def broadcast_params(
params: Sequence[np.ndarray | TensorVariable], ndims_params: Sequence[int]
) -> list[np.ndarray]:
"""Broadcast parameters that have different dimensions.
>>> ndims_params = [1, 2]
......@@ -215,7 +219,9 @@ class RandomStream:
self,
seed: int | None = None,
namespace: ModuleType | None = None,
rng_ctor: Literal[np.random.Generator] = np.random.default_rng,
rng_ctor: Callable[
[np.random.SeedSequence], np.random.Generator
] = np.random.default_rng,
):
if namespace is None:
from pytensor.tensor.random import basic # pylint: disable=import-self
......
......@@ -6,10 +6,11 @@ import numpy as np
from numpy.core.numeric import normalize_axis_tuple # type: ignore
import pytensor
from pytensor.graph import FunctionGraph, Variable
from pytensor.utils import hash_from_code
def hash_from_ndarray(data):
def hash_from_ndarray(data) -> str:
"""
Return a hash from an ndarray.
......@@ -36,7 +37,9 @@ def hash_from_ndarray(data):
)
def shape_of_variables(fgraph, input_shapes):
def shape_of_variables(
fgraph: FunctionGraph, input_shapes
) -> dict[Variable, tuple[int, ...]]:
"""
Compute the numeric shape of all intermediate variables given input shapes.
......@@ -73,16 +76,14 @@ def shape_of_variables(fgraph, input_shapes):
fgraph.attach_feature(ShapeFeature())
shape_feature = fgraph.shape_feature # type: ignore[attr-defined]
input_dims = [
dimension
for inp in fgraph.inputs
for dimension in fgraph.shape_feature.shape_of[inp]
dimension for inp in fgraph.inputs for dimension in shape_feature.shape_of[inp]
]
output_dims = [
dimension
for shape in fgraph.shape_feature.shape_of.values()
for dimension in shape
dimension for shape in shape_feature.shape_of.values() for dimension in shape
]
compute_shapes = pytensor.function(input_dims, output_dims)
......@@ -100,10 +101,8 @@ def shape_of_variables(fgraph, input_shapes):
sym_to_num_dict = dict(zip(output_dims, numeric_output_dims))
l = {}
for var in fgraph.shape_feature.shape_of:
l[var] = tuple(
sym_to_num_dict[sym] for sym in fgraph.shape_feature.shape_of[var]
)
for var in shape_feature.shape_of:
l[var] = tuple(sym_to_num_dict[sym] for sym in shape_feature.shape_of[var])
return l
......@@ -177,7 +176,7 @@ _SIGNATURE = f"^(?:{_ARGUMENT_LIST})?->{_ARGUMENT_LIST}$"
def _parse_gufunc_signature(
signature,
signature: str,
) -> tuple[
list[tuple[str, ...]], ...
]: # mypy doesn't know it's alwayl a length two tuple
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论