提交 f2281495 authored 作者: Ben Mares's avatar Ben Mares 提交者: Ricardo Vieira

Fix mypy errors on `main`

上级 7092f551
...@@ -9,8 +9,10 @@ from pytensor.link.basic import JITLinker ...@@ -9,8 +9,10 @@ from pytensor.link.basic import JITLinker
class JAXLinker(JITLinker): class JAXLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using JAX.""" """A `Linker` that JIT-compiles NumPy-based operations using JAX."""
scalar_shape_inputs: tuple[int, ...]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.scalar_shape_inputs: tuple[int] = () # type: ignore[annotation-unchecked] self.scalar_shape_inputs = ()
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
......
...@@ -517,9 +517,9 @@ def make_loop_call( ...@@ -517,9 +517,9 @@ def make_loop_call(
output_slices = [] output_slices = []
for output, output_type, bc in zip(outputs, output_types, output_bc, strict=True): for output, output_type, bc in zip(outputs, output_types, output_bc, strict=True):
core_ndim = output_type.ndim - len(bc) core_ndim = output_type.ndim - len(bc)
size_type = output.shape.type.element # type: ignore size_type = output.shape.type.element # pyright: ignore[reportAttributeAccessIssue]
output_shape = cgutils.unpack_tuple(builder, output.shape) # type: ignore output_shape = cgutils.unpack_tuple(builder, output.shape) # pyright: ignore[reportAttributeAccessIssue]
output_strides = cgutils.unpack_tuple(builder, output.strides) # type: ignore output_strides = cgutils.unpack_tuple(builder, output.strides) # pyright: ignore[reportAttributeAccessIssue]
idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)] + [ idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)] + [
zero zero
...@@ -527,7 +527,7 @@ def make_loop_call( ...@@ -527,7 +527,7 @@ def make_loop_call(
ptr = cgutils.get_item_pointer2( ptr = cgutils.get_item_pointer2(
context, context,
builder, builder,
output.data, # type:ignore output.data,
output_shape, output_shape,
output_strides, output_strides,
output_type.layout, output_type.layout,
......
...@@ -41,7 +41,7 @@ using_numpy_2 = numpy_version >= "2.0.0rc1" ...@@ -41,7 +41,7 @@ using_numpy_2 = numpy_version >= "2.0.0rc1"
if using_numpy_2: if using_numpy_2:
ndarray_c_version = np._core._multiarray_umath._get_ndarray_c_version() ndarray_c_version = np._core._multiarray_umath._get_ndarray_c_version() # type: ignore[attr-defined]
else: else:
ndarray_c_version = np.core._multiarray_umath._get_ndarray_c_version() # type: ignore[attr-defined] ndarray_c_version = np.core._multiarray_umath._get_ndarray_c_version() # type: ignore[attr-defined]
......
...@@ -109,7 +109,7 @@ def safe_new( ...@@ -109,7 +109,7 @@ def safe_new(
except TestValueError: except TestValueError:
pass pass
return nw_x return type_cast(Variable, nw_x)
class until: class until:
......
...@@ -597,10 +597,14 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar ...@@ -597,10 +597,14 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
# Numpy einsum_path requires arrays even though only the shapes matter # Numpy einsum_path requires arrays even though only the shapes matter
# It's not trivial to duck-type our way around because of internal call to `asanyarray` # It's not trivial to duck-type our way around because of internal call to `asanyarray`
*[np.empty(shape) for shape in shapes], *[np.empty(shape) for shape in shapes],
einsum_call=True, # Not part of public API # einsum_call is not part of public API
einsum_call=True, # type: ignore[arg-type]
optimize="optimal", optimize="optimal",
) # type: ignore )
np_path = tuple(contraction[0] for contraction in contraction_list) np_path: PATH | tuple[tuple[int, ...]] = tuple(
contraction[0] # type: ignore[misc]
for contraction in contraction_list
)
if len(np_path) == 1 and len(np_path[0]) > 2: if len(np_path) == 1 and len(np_path[0]) > 2:
# When there's nothing to optimize, einsum_path reduces all entries simultaneously instead of doing # When there's nothing to optimize, einsum_path reduces all entries simultaneously instead of doing
...@@ -610,7 +614,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar ...@@ -610,7 +614,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
subscripts, tensor_operands, path subscripts, tensor_operands, path
) )
else: else:
path = np_path path = cast(PATH, np_path)
optimized = True optimized = True
......
...@@ -53,10 +53,10 @@ def introduce_explicit_core_shape_rv(fgraph, node): ...@@ -53,10 +53,10 @@ def introduce_explicit_core_shape_rv(fgraph, node):
# ← dirichlet_rv{"(a)->(a)"}.1 [id F] # ← dirichlet_rv{"(a)->(a)"}.1 [id F]
# └─ ··· # └─ ···
""" """
op: RandomVariable = node.op # type: ignore[annotation-unchecked] op: RandomVariable = node.op
next_rng, rv = node.outputs next_rng, rv = node.outputs
shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None) # type: ignore[annotation-unchecked] shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None)
if shape_feature: if shape_feature:
core_shape = [ core_shape = [
shape_feature.get_shape(rv, -i - 1) for i in reversed(range(op.ndim_supp)) shape_feature.get_shape(rv, -i - 1) for i in reversed(range(op.ndim_supp))
......
...@@ -102,7 +102,7 @@ def local_blockwise_alloc(fgraph, node): ...@@ -102,7 +102,7 @@ def local_blockwise_alloc(fgraph, node):
This is critical to remove many unnecessary Blockwise, or to reduce the work done by it This is critical to remove many unnecessary Blockwise, or to reduce the work done by it
""" """
op: Blockwise = node.op # type: ignore op: Blockwise = node.op
batch_ndim = op.batch_ndim(node) batch_ndim = op.batch_ndim(node)
if not batch_ndim: if not batch_ndim:
......
...@@ -65,10 +65,10 @@ def introduce_explicit_core_shape_blockwise(fgraph, node): ...@@ -65,10 +65,10 @@ def introduce_explicit_core_shape_blockwise(fgraph, node):
# [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].2 [id A] 6 # [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].2 [id A] 6
# └─ ··· # └─ ···
""" """
op: Blockwise = node.op # type: ignore[annotation-unchecked] op: Blockwise = node.op
batch_ndim = op.batch_ndim(node) batch_ndim = op.batch_ndim(node)
shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None) # type: ignore[annotation-unchecked] shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None)
if shape_feature: if shape_feature:
core_shapes = [ core_shapes = [
[shape_feature.get_shape(out, i) for i in range(batch_ndim, out.type.ndim)] [shape_feature.get_shape(out, i) for i in range(batch_ndim, out.type.ndim)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论