提交 63da6d16 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Virgile Andreani

Fix types in tensor/random/utils.py and tensor/utils.py

上级 6a295b9f
...@@ -250,6 +250,7 @@ def core_VonMisesRV(op, node): ...@@ -250,6 +250,7 @@ def core_VonMisesRV(op, node):
@numba_core_rv_funcify.register(ptr.ChoiceWithoutReplacement) @numba_core_rv_funcify.register(ptr.ChoiceWithoutReplacement)
def core_ChoiceWithoutReplacement(op: ptr.ChoiceWithoutReplacement, node): def core_ChoiceWithoutReplacement(op: ptr.ChoiceWithoutReplacement, node):
assert isinstance(op.signature, str)
[core_shape_len_sig] = _parse_gufunc_signature(op.signature)[0][-1] [core_shape_len_sig] = _parse_gufunc_signature(op.signature)[0][-1]
core_shape_len = int(core_shape_len_sig) core_shape_len = int(core_shape_len_sig)
implicit_arange = op.ndims_params[0] == 0 implicit_arange = op.ndims_params[0] == 0
......
from collections.abc import Sequence from collections.abc import Callable, Sequence
from functools import wraps from functools import wraps
from itertools import zip_longest from itertools import zip_longest
from types import ModuleType from types import ModuleType
from typing import TYPE_CHECKING, Literal from typing import TYPE_CHECKING
import numpy as np import numpy as np
...@@ -22,7 +22,9 @@ if TYPE_CHECKING: ...@@ -22,7 +22,9 @@ if TYPE_CHECKING:
from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.op import RandomVariable
def params_broadcast_shapes(param_shapes, ndims_params, use_pytensor=True): def params_broadcast_shapes(
param_shapes: Sequence, ndims_params: Sequence[int], use_pytensor: bool = True
) -> list[tuple[int, ...]]:
"""Broadcast parameters that have different dimensions. """Broadcast parameters that have different dimensions.
Parameters Parameters
...@@ -36,12 +38,12 @@ def params_broadcast_shapes(param_shapes, ndims_params, use_pytensor=True): ...@@ -36,12 +38,12 @@ def params_broadcast_shapes(param_shapes, ndims_params, use_pytensor=True):
Returns Returns
======= =======
bcast_shapes : list of ndarray bcast_shapes : list of tuples of ints
The broadcasted values of `params`. The broadcasted values of `params`.
""" """
max_fn = maximum if use_pytensor else max max_fn = maximum if use_pytensor else max
rev_extra_dims = [] rev_extra_dims: list[int] = []
for ndim_param, param_shape in zip(ndims_params, param_shapes): for ndim_param, param_shape in zip(ndims_params, param_shapes):
# We need this in order to use `len` # We need this in order to use `len`
param_shape = tuple(param_shape) param_shape = tuple(param_shape)
...@@ -71,7 +73,9 @@ def params_broadcast_shapes(param_shapes, ndims_params, use_pytensor=True): ...@@ -71,7 +73,9 @@ def params_broadcast_shapes(param_shapes, ndims_params, use_pytensor=True):
return bcast_shapes return bcast_shapes
def broadcast_params(params, ndims_params): def broadcast_params(
params: Sequence[np.ndarray | TensorVariable], ndims_params: Sequence[int]
) -> list[np.ndarray]:
"""Broadcast parameters that have different dimensions. """Broadcast parameters that have different dimensions.
>>> ndims_params = [1, 2] >>> ndims_params = [1, 2]
...@@ -215,7 +219,9 @@ class RandomStream: ...@@ -215,7 +219,9 @@ class RandomStream:
self, self,
seed: int | None = None, seed: int | None = None,
namespace: ModuleType | None = None, namespace: ModuleType | None = None,
rng_ctor: Literal[np.random.Generator] = np.random.default_rng, rng_ctor: Callable[
[np.random.SeedSequence], np.random.Generator
] = np.random.default_rng,
): ):
if namespace is None: if namespace is None:
from pytensor.tensor.random import basic # pylint: disable=import-self from pytensor.tensor.random import basic # pylint: disable=import-self
......
...@@ -6,10 +6,11 @@ import numpy as np ...@@ -6,10 +6,11 @@ import numpy as np
from numpy.core.numeric import normalize_axis_tuple # type: ignore from numpy.core.numeric import normalize_axis_tuple # type: ignore
import pytensor import pytensor
from pytensor.graph import FunctionGraph, Variable
from pytensor.utils import hash_from_code from pytensor.utils import hash_from_code
def hash_from_ndarray(data): def hash_from_ndarray(data) -> str:
""" """
Return a hash from an ndarray. Return a hash from an ndarray.
...@@ -36,7 +37,9 @@ def hash_from_ndarray(data): ...@@ -36,7 +37,9 @@ def hash_from_ndarray(data):
) )
def shape_of_variables(fgraph, input_shapes): def shape_of_variables(
fgraph: FunctionGraph, input_shapes
) -> dict[Variable, tuple[int, ...]]:
""" """
Compute the numeric shape of all intermediate variables given input shapes. Compute the numeric shape of all intermediate variables given input shapes.
...@@ -73,16 +76,14 @@ def shape_of_variables(fgraph, input_shapes): ...@@ -73,16 +76,14 @@ def shape_of_variables(fgraph, input_shapes):
fgraph.attach_feature(ShapeFeature()) fgraph.attach_feature(ShapeFeature())
shape_feature = fgraph.shape_feature # type: ignore[attr-defined]
input_dims = [ input_dims = [
dimension dimension for inp in fgraph.inputs for dimension in shape_feature.shape_of[inp]
for inp in fgraph.inputs
for dimension in fgraph.shape_feature.shape_of[inp]
] ]
output_dims = [ output_dims = [
dimension dimension for shape in shape_feature.shape_of.values() for dimension in shape
for shape in fgraph.shape_feature.shape_of.values()
for dimension in shape
] ]
compute_shapes = pytensor.function(input_dims, output_dims) compute_shapes = pytensor.function(input_dims, output_dims)
...@@ -100,10 +101,8 @@ def shape_of_variables(fgraph, input_shapes): ...@@ -100,10 +101,8 @@ def shape_of_variables(fgraph, input_shapes):
sym_to_num_dict = dict(zip(output_dims, numeric_output_dims)) sym_to_num_dict = dict(zip(output_dims, numeric_output_dims))
l = {} l = {}
for var in fgraph.shape_feature.shape_of: for var in shape_feature.shape_of:
l[var] = tuple( l[var] = tuple(sym_to_num_dict[sym] for sym in shape_feature.shape_of[var])
sym_to_num_dict[sym] for sym in fgraph.shape_feature.shape_of[var]
)
return l return l
...@@ -177,7 +176,7 @@ _SIGNATURE = f"^(?:{_ARGUMENT_LIST})?->{_ARGUMENT_LIST}$" ...@@ -177,7 +176,7 @@ _SIGNATURE = f"^(?:{_ARGUMENT_LIST})?->{_ARGUMENT_LIST}$"
def _parse_gufunc_signature( def _parse_gufunc_signature(
signature, signature: str,
) -> tuple[ ) -> tuple[
list[tuple[str, ...]], ... list[tuple[str, ...]], ...
]:  # mypy doesn't know it's always a length two tuple ]:  # mypy doesn't know it's always a length two tuple
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论