提交 f2281495 ，作者：Ben Mares，提交者：Ricardo Vieira

Fix mypy errors on `main`

上级 7092f551
......@@ -9,8 +9,10 @@ from pytensor.link.basic import JITLinker
class JAXLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using JAX."""
scalar_shape_inputs: tuple[int, ...]
def __init__(self, *args, **kwargs):
self.scalar_shape_inputs: tuple[int] = () # type: ignore[annotation-unchecked]
self.scalar_shape_inputs = ()
super().__init__(*args, **kwargs)
def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
......
......@@ -517,9 +517,9 @@ def make_loop_call(
output_slices = []
for output, output_type, bc in zip(outputs, output_types, output_bc, strict=True):
core_ndim = output_type.ndim - len(bc)
size_type = output.shape.type.element # type: ignore
output_shape = cgutils.unpack_tuple(builder, output.shape) # type: ignore
output_strides = cgutils.unpack_tuple(builder, output.strides) # type: ignore
size_type = output.shape.type.element # pyright: ignore[reportAttributeAccessIssue]
output_shape = cgutils.unpack_tuple(builder, output.shape) # pyright: ignore[reportAttributeAccessIssue]
output_strides = cgutils.unpack_tuple(builder, output.strides) # pyright: ignore[reportAttributeAccessIssue]
idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)] + [
zero
......@@ -527,7 +527,7 @@ def make_loop_call(
ptr = cgutils.get_item_pointer2(
context,
builder,
output.data, # type:ignore
output.data,
output_shape,
output_strides,
output_type.layout,
......
......@@ -41,7 +41,7 @@ using_numpy_2 = numpy_version >= "2.0.0rc1"
if using_numpy_2:
ndarray_c_version = np._core._multiarray_umath._get_ndarray_c_version()
ndarray_c_version = np._core._multiarray_umath._get_ndarray_c_version() # type: ignore[attr-defined]
else:
ndarray_c_version = np.core._multiarray_umath._get_ndarray_c_version() # type: ignore[attr-defined]
......
......@@ -109,7 +109,7 @@ def safe_new(
except TestValueError:
pass
return nw_x
return type_cast(Variable, nw_x)
class until:
......
......@@ -597,10 +597,14 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
# Numpy einsum_path requires arrays even though only the shapes matter
# It's not trivial to duck-type our way around because of internal call to `asanyarray`
*[np.empty(shape) for shape in shapes],
einsum_call=True, # Not part of public API
# einsum_call is not part of public API
einsum_call=True, # type: ignore[arg-type]
optimize="optimal",
) # type: ignore
np_path = tuple(contraction[0] for contraction in contraction_list)
)
np_path: PATH | tuple[tuple[int, ...]] = tuple(
contraction[0] # type: ignore[misc]
for contraction in contraction_list
)
if len(np_path) == 1 and len(np_path[0]) > 2:
# When there's nothing to optimize, einsum_path reduces all entries simultaneously instead of doing
......@@ -610,7 +614,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
subscripts, tensor_operands, path
)
else:
path = np_path
path = cast(PATH, np_path)
optimized = True
......
......@@ -53,10 +53,10 @@ def introduce_explicit_core_shape_rv(fgraph, node):
# ← dirichlet_rv{"(a)->(a)"}.1 [id F]
# └─ ···
"""
op: RandomVariable = node.op # type: ignore[annotation-unchecked]
op: RandomVariable = node.op
next_rng, rv = node.outputs
shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None) # type: ignore[annotation-unchecked]
shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None)
if shape_feature:
core_shape = [
shape_feature.get_shape(rv, -i - 1) for i in reversed(range(op.ndim_supp))
......
......@@ -102,7 +102,7 @@ def local_blockwise_alloc(fgraph, node):
This is critical to remove many unnecessary Blockwise, or to reduce the work done by it
"""
op: Blockwise = node.op # type: ignore
op: Blockwise = node.op
batch_ndim = op.batch_ndim(node)
if not batch_ndim:
......
......@@ -65,10 +65,10 @@ def introduce_explicit_core_shape_blockwise(fgraph, node):
# [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].2 [id A] 6
# └─ ···
"""
op: Blockwise = node.op # type: ignore[annotation-unchecked]
op: Blockwise = node.op
batch_ndim = op.batch_ndim(node)
shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None) # type: ignore[annotation-unchecked]
shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None)
if shape_feature:
core_shapes = [
[shape_feature.get_shape(out, i) for i in range(batch_ndim, out.type.ndim)]
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论