Unverified 提交 aad78d57 authored 作者: pre-commit-ci[bot]'s avatar pre-commit-ci[bot] 提交者: GitHub

[pre-commit.ci] pre-commit autoupdate (#666)

* [pre-commit.ci] pre-commit autoupdate updates: - [github.com/astral-sh/ruff-pre-commit: v0.2.2 → v0.3.2](https://github.com/astral-sh/ruff-pre-commit/compare/v0.2.2...v0.3.2) - [github.com/pre-commit/mirrors-mypy: v1.8.0 → v1.9.0](https://github.com/pre-commit/mirrors-mypy/compare/v1.8.0...v1.9.0) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: 's avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
上级 4d365f68
...@@ -23,13 +23,13 @@ repos: ...@@ -23,13 +23,13 @@ repos:
)$ )$
- id: check-merge-conflict - id: check-merge-conflict
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.2 rev: v0.3.2
hooks: hooks:
- id: ruff - id: ruff
args: ["--fix", "--output-format=full"] args: ["--fix", "--output-format=full"]
- id: ruff-format - id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0 rev: v1.9.0
hooks: hooks:
- id: mypy - id: mypy
language: python language: python
......
...@@ -18,7 +18,6 @@ To learn more, check out: ...@@ -18,7 +18,6 @@ To learn more, check out:
""" """
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
# Set a default logger. It is important to do this before importing some other # Set a default logger. It is important to do this before importing some other
......
"""Define new Ops from existing Ops""" """Define new Ops from existing Ops"""
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Sequence from collections.abc import Sequence
from copy import copy from copy import copy
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
This module contains housekeeping functions for cleaning/purging the "compiledir". This module contains housekeeping functions for cleaning/purging the "compiledir".
It is used by the "pytensor-cache" CLI tool, located in the /bin folder of the repository. It is used by the "pytensor-cache" CLI tool, located in the /bin folder of the repository.
""" """
import logging import logging
import os import os
import pickle import pickle
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
Locking mechanism to ensure no two compilations occur simultaneously Locking mechanism to ensure no two compilations occur simultaneously
in the same compilation directory (which can cause crashes). in the same compilation directory (which can cause crashes).
""" """
import os import os
import threading import threading
from contextlib import contextmanager from contextlib import contextmanager
......
...@@ -5,7 +5,6 @@ TODO: add support for IfElse Op, LazyLinker, etc. ...@@ -5,7 +5,6 @@ TODO: add support for IfElse Op, LazyLinker, etc.
""" """
import copy import copy
import gc import gc
import logging import logging
...@@ -2317,10 +2316,7 @@ class DebugMode(Mode): ...@@ -2317,10 +2316,7 @@ class DebugMode(Mode):
raise ValueError("DebugMode has to check at least one of c and py code") raise ValueError("DebugMode has to check at least one of c and py code")
def __str__(self): def __str__(self):
return "DebugMode(linker={}, optimizer={})".format( return f"DebugMode(linker={self.provided_linker}, optimizer={self.provided_optimizer})"
self.provided_linker,
self.provided_optimizer,
)
register_mode("DEBUG_MODE", DebugMode(optimizer="fast_run")) register_mode("DEBUG_MODE", DebugMode(optimizer="fast_run"))
...@@ -35,8 +35,7 @@ def rebuild_collect_shared( ...@@ -35,8 +35,7 @@ def rebuild_collect_shared(
list[Variable], list[Variable],
list[SharedVariable], list[SharedVariable],
], ],
]: ]: ...
...
@overload @overload
...@@ -58,8 +57,7 @@ def rebuild_collect_shared( ...@@ -58,8 +57,7 @@ def rebuild_collect_shared(
list[Variable], list[Variable],
list[SharedVariable], list[SharedVariable],
], ],
]: ]: ...
...
@overload @overload
...@@ -81,8 +79,7 @@ def rebuild_collect_shared( ...@@ -81,8 +79,7 @@ def rebuild_collect_shared(
list[Variable], list[Variable],
list[SharedVariable], list[SharedVariable],
], ],
]: ]: ...
...
@overload @overload
...@@ -104,8 +101,7 @@ def rebuild_collect_shared( ...@@ -104,8 +101,7 @@ def rebuild_collect_shared(
list[Variable], list[Variable],
list[SharedVariable], list[SharedVariable],
], ],
]: ]: ...
...
def rebuild_collect_shared( def rebuild_collect_shared(
......
...@@ -3,7 +3,6 @@ Define `SymbolicInput`, `SymbolicOutput`, `In`, `Out`. ...@@ -3,7 +3,6 @@ Define `SymbolicInput`, `SymbolicOutput`, `In`, `Out`.
""" """
import logging import logging
from pytensor.link.basic import Container from pytensor.link.basic import Container
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
Author: Christof Angermueller <cangermueller@gmail.com> Author: Christof Angermueller <cangermueller@gmail.com>
""" """
import json import json
import os import os
import shutil import shutil
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
Author: Christof Angermueller <cangermueller@gmail.com> Author: Christof Angermueller <cangermueller@gmail.com>
""" """
import os import os
from functools import reduce from functools import reduce
......
"""Core graph classes.""" """Core graph classes."""
import abc import abc
import warnings import warnings
from collections import deque from collections import deque
...@@ -643,8 +644,7 @@ class AtomicVariable(Variable[_TypeType, None]): ...@@ -643,8 +644,7 @@ class AtomicVariable(Variable[_TypeType, None]):
super().__init__(type=type, owner=None, index=None, name=name, **kwargs) super().__init__(type=type, owner=None, index=None, name=name, **kwargs)
@abc.abstractmethod @abc.abstractmethod
def signature(self): def signature(self): ...
...
def merge_signature(self): def merge_signature(self):
return self.signature() return self.signature()
...@@ -1309,8 +1309,7 @@ def general_toposort( ...@@ -1309,8 +1309,7 @@ def general_toposort(
compute_deps_cache: Callable[[T], Optional[Union[OrderedSet, list[T]]]], compute_deps_cache: Callable[[T], Optional[Union[OrderedSet, list[T]]]],
deps_cache: Optional[dict[T, list[T]]], deps_cache: Optional[dict[T, list[T]]],
clients: Optional[dict[T, list[T]]], clients: Optional[dict[T, list[T]]],
) -> list[T]: ) -> list[T]: ...
...
@overload @overload
...@@ -1320,8 +1319,7 @@ def general_toposort( ...@@ -1320,8 +1319,7 @@ def general_toposort(
compute_deps_cache: None, compute_deps_cache: None,
deps_cache: None, deps_cache: None,
clients: Optional[dict[T, list[T]]], clients: Optional[dict[T, list[T]]],
) -> list[T]: ) -> list[T]: ...
...
def general_toposort( def general_toposort(
......
...@@ -3,6 +3,7 @@ Classes and functions for validating graphs that contain view ...@@ -3,6 +3,7 @@ Classes and functions for validating graphs that contain view
and inplace operations. and inplace operations.
""" """
import itertools import itertools
from collections import OrderedDict, deque from collections import OrderedDict, deque
......
"""A container for specifying and manipulating a graph with distinct inputs and outputs.""" """A container for specifying and manipulating a graph with distinct inputs and outputs."""
import time import time
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Iterable, Sequence from collections.abc import Iterable, Sequence
......
...@@ -40,8 +40,7 @@ def clone_replace( ...@@ -40,8 +40,7 @@ def clone_replace(
output: Sequence[Variable], output: Sequence[Variable],
replace: Optional[ReplaceTypes] = None, replace: Optional[ReplaceTypes] = None,
**rebuild_kwds, **rebuild_kwds,
) -> list[Variable]: ) -> list[Variable]: ...
...
@overload @overload
...@@ -51,8 +50,7 @@ def clone_replace( ...@@ -51,8 +50,7 @@ def clone_replace(
Union[Iterable[tuple[Variable, Variable]], dict[Variable, Variable]] Union[Iterable[tuple[Variable, Variable]], dict[Variable, Variable]]
] = None, ] = None,
**rebuild_kwds, **rebuild_kwds,
) -> Variable: ) -> Variable: ...
...
def clone_replace( def clone_replace(
...@@ -95,8 +93,7 @@ def graph_replace( ...@@ -95,8 +93,7 @@ def graph_replace(
replace: Optional[ReplaceTypes] = None, replace: Optional[ReplaceTypes] = None,
*, *,
strict=True, strict=True,
) -> Variable: ) -> Variable: ...
...
@overload @overload
...@@ -105,8 +102,7 @@ def graph_replace( ...@@ -105,8 +102,7 @@ def graph_replace(
replace: Optional[ReplaceTypes] = None, replace: Optional[ReplaceTypes] = None,
*, *,
strict=True, strict=True,
) -> list[Variable]: ) -> list[Variable]: ...
...
def graph_replace( def graph_replace(
...@@ -229,16 +225,14 @@ def _vectorize_not_needed(op, node, *batched_inputs): ...@@ -229,16 +225,14 @@ def _vectorize_not_needed(op, node, *batched_inputs):
def vectorize_graph( def vectorize_graph(
outputs: Variable, outputs: Variable,
replace: Mapping[Variable, Variable], replace: Mapping[Variable, Variable],
) -> Variable: ) -> Variable: ...
...
@overload @overload
def vectorize_graph( def vectorize_graph(
outputs: Sequence[Variable], outputs: Sequence[Variable],
replace: Mapping[Variable, Variable], replace: Mapping[Variable, Variable],
) -> Sequence[Variable]: ) -> Sequence[Variable]: ...
...
def vectorize_graph( def vectorize_graph(
......
"""This module defines the base classes for graph rewriting.""" """This module defines the base classes for graph rewriting."""
import abc import abc
import copy import copy
import functools import functools
...@@ -123,8 +124,7 @@ class GraphRewriter(Rewriter): ...@@ -123,8 +124,7 @@ class GraphRewriter(Rewriter):
"""Rewrite a `FunctionGraph`.""" """Rewrite a `FunctionGraph`."""
return self.rewrite(fgraph) return self.rewrite(fgraph)
def add_requirements(self, fgraph): def add_requirements(self, fgraph): ...
...
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, "name", None) name = getattr(self, "name", None)
...@@ -1683,9 +1683,8 @@ class PatternNodeRewriter(NodeRewriter): ...@@ -1683,9 +1683,8 @@ class PatternNodeRewriter(NodeRewriter):
else: else:
return str(pattern) return str(pattern)
return "{} -> {}".format( return (
pattern_to_str(self.in_pattern), f"{pattern_to_str(self.in_pattern)} -> {pattern_to_str(self.out_pattern)}"
pattern_to_str(self.out_pattern),
) )
def __repr__(self): def __repr__(self):
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
Defines Linkers that deal with C implementations. Defines Linkers that deal with C implementations.
""" """
import logging import logging
import sys import sys
from collections import defaultdict from collections import defaultdict
...@@ -1566,16 +1567,16 @@ class CLinker(Linker): ...@@ -1566,16 +1567,16 @@ class CLinker(Linker):
# Static methods that can run and destroy the struct built by # Static methods that can run and destroy the struct built by
# instantiate. # instantiate.
static = """ static = f"""
static int {struct_name}_executor({struct_name} *self) {{ static int {self.struct_name}_executor({self.struct_name} *self) {{
return self->run(); return self->run();
}} }}
static void {struct_name}_destructor(PyObject *capsule) {{ static void {self.struct_name}_destructor(PyObject *capsule) {{
{struct_name} *self = ({struct_name} *)PyCapsule_GetContext(capsule); {self.struct_name} *self = ({self.struct_name} *)PyCapsule_GetContext(capsule);
delete self; delete self;
}} }}
""".format(struct_name=self.struct_name) """
# We add all the support code, compile args, headers and libs we need. # We add all the support code, compile args, headers and libs we need.
for support_code in self.support_code() + self.c_support_code_apply: for support_code in self.support_code() + self.c_support_code_apply:
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
Generate and compile C modules for Python. Generate and compile C modules for Python.
""" """
import atexit import atexit
import importlib import importlib
import logging import logging
......
...@@ -113,7 +113,6 @@ for more info about enumeration aliases). ...@@ -113,7 +113,6 @@ for more info about enumeration aliases).
""" """
import hashlib import hashlib
import re import re
......
...@@ -503,14 +503,8 @@ def raise_with_op( ...@@ -503,14 +503,8 @@ def raise_with_op(
detailed_err_msg += f", TotalSize: {item[3]} Byte(s)\n" detailed_err_msg += f", TotalSize: {item[3]} Byte(s)\n"
else: else:
detailed_err_msg += "\n" detailed_err_msg += "\n"
detailed_err_msg += " TotalSize: {} Byte(s) {:.3f} GB\n".format( detailed_err_msg += f" TotalSize: {total_size} Byte(s) {total_size / 1024 / 1024 / 1024:.3f} GB\n"
total_size, detailed_err_msg += f" TotalSize inputs: {total_size_inputs} Byte(s) {total_size_inputs / 1024 / 1024 / 1024:.3f} GB\n"
total_size / 1024 / 1024 / 1024,
)
detailed_err_msg += " TotalSize inputs: {} Byte(s) {:.3f} GB\n".format(
total_size_inputs,
total_size_inputs / 1024 / 1024 / 1024,
)
else: else:
hints.append( hints.append(
......
...@@ -5,6 +5,7 @@ A VM is not actually different from a Linker, we just decided ...@@ -5,6 +5,7 @@ A VM is not actually different from a Linker, we just decided
VM was a better name at some point. VM was a better name at some point.
""" """
import platform import platform
import sys import sys
import time import time
......
...@@ -3,7 +3,6 @@ Function to detect memory sharing for ndarray AND sparse type. ...@@ -3,7 +3,6 @@ Function to detect memory sharing for ndarray AND sparse type.
numpy version support only ndarray. numpy version support only ndarray.
""" """
import numpy as np import numpy as np
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
Helper function to safely convert an array to a new data type. Helper function to safely convert an array to a new data type.
""" """
import numpy as np import numpy as np
from pytensor.configdefaults import config from pytensor.configdefaults import config
......
...@@ -646,13 +646,7 @@ def _debugprint( ...@@ -646,13 +646,7 @@ def _debugprint(
tot_time_percent = (tot_time_dict[node] / profile.fct_call_time) * 100 tot_time_percent = (tot_time_dict[node] / profile.fct_call_time) * 100
print( print(
"{} --> {:8.2e}s {:4.1f}% {:8.2e}s {:4.1f}%".format( f"{var_output} --> {op_time:8.2e}s {op_time_percent:4.1f}% {tot_time:8.2e}s {tot_time_percent:4.1f}%",
var_output,
op_time,
op_time_percent,
tot_time,
tot_time_percent,
),
file=file, file=file,
) )
else: else:
...@@ -1157,14 +1151,14 @@ if use_ascii: ...@@ -1157,14 +1151,14 @@ if use_ascii:
epsilon="\\epsilon", epsilon="\\epsilon",
) )
else: else:
special = dict(middle_dot="\u00B7", big_sigma="\u03A3") special = dict(middle_dot="\u00b7", big_sigma="\u03a3")
greek = dict( greek = dict(
alpha="\u03B1", alpha="\u03b1",
beta="\u03B2", beta="\u03b2",
gamma="\u03B3", gamma="\u03b3",
delta="\u03B4", delta="\u03b4",
epsilon="\u03B5", epsilon="\u03b5",
) )
......
...@@ -273,7 +273,6 @@ def convert(x, dtype=None): ...@@ -273,7 +273,6 @@ def convert(x, dtype=None):
class ScalarType(CType, HasDataType, HasShape): class ScalarType(CType, HasDataType, HasShape):
""" """
Internal class, should not be used by clients. Internal class, should not be used by clients.
......
...@@ -33,7 +33,6 @@ functions: ``scan()``, ``map()``, ``reduce()``, ``foldl()``, ...@@ -33,7 +33,6 @@ functions: ``scan()``, ``map()``, ``reduce()``, ``foldl()``,
""" """
__docformat__ = "restructedtext en" __docformat__ = "restructedtext en"
__authors__ = ( __authors__ = (
"Razvan Pascanu " "Razvan Pascanu "
......
...@@ -3368,13 +3368,7 @@ def profile_printer( ...@@ -3368,13 +3368,7 @@ def profile_printer(
total_scan_fct_time += scan_fct_time total_scan_fct_time += scan_fct_time
total_scan_op_time += scan_op_time total_scan_op_time += scan_op_time
print( print(
" {:5.1f}s {:5.1f}s {:5.1f}s {:5.1f}% {:5.1f}%".format( f" {v:5.1f}s {scan_fct_time:5.1f}s {scan_op_time:5.1f}s {scan_fct_time / v * 100:5.1f}% {scan_op_time / v * 100:5.1f}%",
v,
scan_fct_time,
scan_op_time,
scan_fct_time / v * 100,
scan_op_time / v * 100,
),
node, node,
file=file, file=file,
) )
...@@ -3388,13 +3382,7 @@ def profile_printer( ...@@ -3388,13 +3382,7 @@ def profile_printer(
print(" No scan have its inner profile enabled.", file=file) print(" No scan have its inner profile enabled.", file=file)
else: else:
print( print(
"total {:5.1f}s {:5.1f}s {:5.1f}s {:5.1f}% {:5.1f}%".format( f"total {total_super_scan_time:5.1f}s {total_scan_fct_time:5.1f}s {total_scan_op_time:5.1f}s {total_scan_fct_time / total_super_scan_time * 100:5.1f}% {total_scan_op_time / total_super_scan_time * 100:5.1f}%",
total_super_scan_time,
total_scan_fct_time,
total_scan_op_time,
total_scan_fct_time / total_super_scan_time * 100,
total_scan_op_time / total_super_scan_time * 100,
),
file=file, file=file,
) )
......
...@@ -5,6 +5,7 @@ To update the `Scan` Cython code you must ...@@ -5,6 +5,7 @@ To update the `Scan` Cython code you must
- update the version value in this file and in `scan_perform.pyx` - update the version value in this file and in `scan_perform.pyx`
""" """
from pytensor.scan.scan_perform import get_version, perform # noqa: F401 from pytensor.scan.scan_perform import get_version, perform # noqa: F401
......
...@@ -7,6 +7,7 @@ http://www-users.cs.umn.edu/~saad/software/SPARSKIT/paper.ps ...@@ -7,6 +7,7 @@ http://www-users.cs.umn.edu/~saad/software/SPARSKIT/paper.ps
TODO: Automatic methods for determining best sparse format? TODO: Automatic methods for determining best sparse format?
""" """
from typing import Literal from typing import Literal
from warnings import warn from warnings import warn
...@@ -486,13 +487,7 @@ class SparseConstant(TensorConstant, _sparse_py_operators): ...@@ -486,13 +487,7 @@ class SparseConstant(TensorConstant, _sparse_py_operators):
return SparseConstantSignature((self.type, self.data)) return SparseConstantSignature((self.type, self.data))
def __str__(self): def __str__(self):
return "{}{{{},{},shape={},nnz={}}}".format( return f"{self.__class__.__name__}{{{self.format},{self.dtype},shape={self.data.shape},nnz={self.data.nnz}}}"
self.__class__.__name__,
self.format,
self.dtype,
self.data.shape,
self.data.nnz,
)
def __repr__(self): def __repr__(self):
return str(self) return str(self)
......
...@@ -767,10 +767,7 @@ class GemmRelated(COp): ...@@ -767,10 +767,7 @@ class GemmRelated(COp):
def build_gemm_call(self): def build_gemm_call(self):
if hasattr(self, "inplace"): if hasattr(self, "inplace"):
setup_z_Nz_Sz = "if(%(params)s->inplace){{{}}}else{{{}}}".format( setup_z_Nz_Sz = f"if(%(params)s->inplace){{{self.setup_z_Nz_Sz_inplace}}}else{{{self.setup_z_Nz_Sz_outplace}}}"
self.setup_z_Nz_Sz_inplace,
self.setup_z_Nz_Sz_outplace,
)
else: else:
setup_z_Nz_Sz = self.setup_z_Nz_Sz setup_z_Nz_Sz = self.setup_z_Nz_Sz
......
""" Header text for the C and Fortran BLAS interfaces. """Header text for the C and Fortran BLAS interfaces.
There is no standard name or location for this header, so we just insert it There is no standard name or location for this header, so we just insert it
ourselves into the C code. ourselves into the C code.
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
Abstract conv interface Abstract conv interface
""" """
import logging import logging
import sys import sys
import warnings import warnings
...@@ -2175,8 +2174,8 @@ class BaseAbstractConv(Op): ...@@ -2175,8 +2174,8 @@ class BaseAbstractConv(Op):
) )
): ):
raise ValueError( raise ValueError(
"invalid border mode {}. The tuple can only contain integers " f"invalid border mode {border_mode}. The tuple can only contain integers "
" or pairs of integers".format(border_mode) " or pairs of integers"
) )
if isinstance(mode, tuple): if isinstance(mode, tuple):
if convdim != 2: if convdim != 2:
...@@ -2348,29 +2347,29 @@ class BaseAbstractConv(Op): ...@@ -2348,29 +2347,29 @@ class BaseAbstractConv(Op):
for n in range(output_channel_offset): for n in range(output_channel_offset):
for im0 in range(input_channel_offset): for im0 in range(input_channel_offset):
if unshared: if unshared:
out[ out[b, g * output_channel_offset + n, ...] += (
b, g * output_channel_offset + n, ... self.unshared2d(
] += self.unshared2d( img[b, g * input_channel_offset + im0, ...],
img[b, g * input_channel_offset + im0, ...], dilated_kern[
dilated_kern[ g * output_channel_offset + n, im0, ...
g * output_channel_offset + n, im0, ... ],
], out_shape[2:],
out_shape[2:], direction,
direction, )
) )
else: else:
# some cast generates a warning here # some cast generates a warning here
out[ out[b, g * output_channel_offset + n, ...] += (
b, g * output_channel_offset + n, ... _convolve2d(
] += _convolve2d( img[b, g * input_channel_offset + im0, ...],
img[b, g * input_channel_offset + im0, ...], dilated_kern[
dilated_kern[ g * output_channel_offset + n, im0, ...
g * output_channel_offset + n, im0, ... ],
], 1,
1, val,
val, bval,
bval, 0,
0, )
) )
elif self.convdim == 3: elif self.convdim == 3:
...@@ -2550,10 +2549,8 @@ class AbstractConv(BaseAbstractConv): ...@@ -2550,10 +2549,8 @@ class AbstractConv(BaseAbstractConv):
) )
if kern.shape[1 : 1 + self.convdim] != out_shape[2 : 2 + self.convdim]: if kern.shape[1 : 1 + self.convdim] != out_shape[2 : 2 + self.convdim]:
raise ValueError( raise ValueError(
"Kernel shape {} does not match " "computed output size {}".format( f"Kernel shape {kern.shape[1 : 1 + self.convdim]} does not match "
kern.shape[1 : 1 + self.convdim], f"computed output size {out_shape[2 : 2 + self.convdim]}"
out_shape[2 : 2 + self.convdim],
)
) )
if any(self.subsample[i] > 1 for i in range(self.convdim)): if any(self.subsample[i] > 1 for i in range(self.convdim)):
# Expand regions in kernel to correct for subsampling # Expand regions in kernel to correct for subsampling
...@@ -3236,10 +3233,8 @@ class AbstractConv_gradInputs(BaseAbstractConv): ...@@ -3236,10 +3233,8 @@ class AbstractConv_gradInputs(BaseAbstractConv):
if tuple(expected_topgrad_shape) != tuple(topgrad.shape): if tuple(expected_topgrad_shape) != tuple(topgrad.shape):
raise ValueError( raise ValueError(
"invalid input_shape for gradInputs: the given input_shape " "invalid input_shape for gradInputs: the given input_shape "
"would produce an output of shape {}, but the given topgrad " f"would produce an output of shape {tuple(expected_topgrad_shape)}, but the given topgrad "
"has shape {}".format( f"has shape {tuple(topgrad.shape)}"
tuple(expected_topgrad_shape), tuple(topgrad.shape)
)
) )
if any(self.subsample[i] > 1 for i in range(self.convdim)): if any(self.subsample[i] > 1 for i in range(self.convdim)):
new_shape = ( new_shape = (
......
...@@ -49,10 +49,8 @@ class LoadFromDisk(Op): ...@@ -49,10 +49,8 @@ class LoadFromDisk(Op):
out[0][0] = result out[0][0] = result
def __str__(self): def __str__(self):
return "Load{{dtype: {}, shape: {}, mmep: {}}}".format( return (
self.dtype, f"Load{{dtype: {self.dtype}, shape: {self.shape}, mmep: {self.mmap_mode}}}"
self.shape,
self.mmap_mode,
) )
......
""" Tensor optimizations addressing the ops in basic.py. """Tensor optimizations addressing the ops in basic.py.
Notes Notes
----- -----
......
...@@ -100,12 +100,7 @@ class SliceConstant(Constant): ...@@ -100,12 +100,7 @@ class SliceConstant(Constant):
return (SliceConstant, self.data.start, self.data.stop, self.data.step) return (SliceConstant, self.data.start, self.data.stop, self.data.step)
def __str__(self): def __str__(self):
return "{}{{{}, {}, {}}}".format( return f"{self.__class__.__name__}{{{self.data.start}, {self.data.stop}, {self.data.step}}}"
self.__class__.__name__,
self.data.start,
self.data.stop,
self.data.step,
)
SliceType.constant_type = SliceConstant SliceType.constant_type = SliceConstant
......
"""Defines Updates object for storing a (SharedVariable, new_value) mapping. """Defines Updates object for storing a (SharedVariable, new_value) mapping."""
"""
import logging import logging
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
......
...@@ -8,6 +8,7 @@ Usage ...@@ -8,6 +8,7 @@ Usage
----- -----
python scripts/run_mypy.py [--verbose] python scripts/run_mypy.py [--verbose]
""" """
import argparse import argparse
import importlib import importlib
import os import os
......
...@@ -4,6 +4,7 @@ We don't have real tests for the cache, but it would be great to make them! ...@@ -4,6 +4,7 @@ We don't have real tests for the cache, but it would be great to make them!
But this one tests a current behavior that isn't good: the c_code isn't But this one tests a current behavior that isn't good: the c_code isn't
deterministic based on the input type and the op. deterministic based on the input type and the op.
""" """
import multiprocessing import multiprocessing
import os import os
import re import re
...@@ -369,8 +370,9 @@ def test_cache_race_condition(): ...@@ -369,8 +370,9 @@ def test_cache_race_condition():
# The module cache must (initially) be `None` for all processes so that # The module cache must (initially) be `None` for all processes so that
# `ModuleCache.refresh` is called # `ModuleCache.refresh` is called
with patch.object(compiledir_prop, "val", dir_name, create=True), patch.object( with (
pytensor.link.c.cmodule, "_module_cache", None patch.object(compiledir_prop, "val", dir_name, create=True),
patch.object(pytensor.link.c.cmodule, "_module_cache", None),
): ):
assert pytensor.config.compiledir == dir_name assert pytensor.config.compiledir == dir_name
......
""" """
Questions and notes about scan that should be answered : Questions and notes about scan that should be answered :
* Scan seems to do copies of every input variable. Is that needed? * Scan seems to do copies of every input variable. Is that needed?
answer : probably not, but it doesn't hurt also ( what we copy is answer : probably not, but it doesn't hurt also ( what we copy is
pytensor variables, which just cary information about the type / dimension pytensor variables, which just cary information about the type / dimension
of the data) of the data)
* There is some of scan functionality that is not well documented * There is some of scan functionality that is not well documented
""" """
import os import os
......
...@@ -160,13 +160,7 @@ class BaseCorr3dMM(OpenMPOp, _NoPythonOp): ...@@ -160,13 +160,7 @@ class BaseCorr3dMM(OpenMPOp, _NoPythonOp):
padD = property(lambda self: self.pad[2]) padD = property(lambda self: self.pad[2])
def __str__(self): def __str__(self):
return "{}{{{}, {}, {}, {}}}".format( return f"{self.__class__.__name__}{{{self.border_mode}, {self.subsample!s}, {self.filter_dilation!s}, {self.num_groups!s}}}"
self.__class__.__name__,
self.border_mode,
str(self.subsample),
str(self.filter_dilation),
str(self.num_groups),
)
@staticmethod @staticmethod
def as_common_dtype(in1, in2): def as_common_dtype(in1, in2):
......
...@@ -109,8 +109,8 @@ class BaseCorrMM(OpenMPOp, _NoPythonOp): ...@@ -109,8 +109,8 @@ class BaseCorrMM(OpenMPOp, _NoPythonOp):
border += ((int(mode), int(mode)),) border += ((int(mode), int(mode)),)
else: else:
raise ValueError( raise ValueError(
"invalid border mode {}. The tuple can only contain " f"invalid border mode {border_mode}. The tuple can only contain "
"integers or tuples of length 2".format(border_mode) "integers or tuples of length 2"
) )
border_mode = border border_mode = border
elif border_mode not in ("valid", "full", "half"): elif border_mode not in ("valid", "full", "half"):
...@@ -176,14 +176,7 @@ class BaseCorrMM(OpenMPOp, _NoPythonOp): ...@@ -176,14 +176,7 @@ class BaseCorrMM(OpenMPOp, _NoPythonOp):
padW_r = property(lambda self: self.pad[1][1]) padW_r = property(lambda self: self.pad[1][1])
def __str__(self): def __str__(self):
return "{}{{{}, {}, {}, {} {}}}".format( return f"{self.__class__.__name__}{{{self.border_mode}, {self.subsample!s}, {self.filter_dilation!s}, {self.num_groups!s} {self.unshared!s}}}"
self.__class__.__name__,
self.border_mode,
str(self.subsample),
str(self.filter_dilation),
str(self.num_groups),
str(self.unshared),
)
@staticmethod @staticmethod
def as_common_dtype(in1, in2): def as_common_dtype(in1, in2):
......
...@@ -912,8 +912,9 @@ def test_infer_static_shape(): ...@@ -912,8 +912,9 @@ def test_infer_static_shape():
with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"): with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):
infer_static_shape([constant(1.0)]) infer_static_shape([constant(1.0)])
with config.change_flags(exception_verbosity="high"), pytest.raises( with (
TypeError, match=r"A\. x" config.change_flags(exception_verbosity="high"),
pytest.raises(TypeError, match=r"A\. x"),
): ):
infer_static_shape([dscalar("x")]) infer_static_shape([dscalar("x")])
......
""" This file don't test everything. It only test one past crash error.""" """This file don't test everything. It only test one past crash error."""
import pytensor import pytensor
from pytensor import as_symbolic from pytensor import as_symbolic
......
...@@ -473,8 +473,9 @@ def makeTester( ...@@ -473,8 +473,9 @@ def makeTester(
f = inplace_func(inputrs, node.outputs, mode=mode, name="test_good") f = inplace_func(inputrs, node.outputs, mode=mode, name="test_good")
except Exception as exc: except Exception as exc:
err_msg = ( err_msg = (
"Test {}::{}: Error occurred while" " trying to make a Function" f"Test {self.op}::{testname}: Error occurred while"
).format(self.op, testname) " trying to make a Function"
)
exc.args += (err_msg,) exc.args += (err_msg,)
raise raise
if isinstance(self.expected, dict) and testname in self.expected: if isinstance(self.expected, dict) and testname in self.expected:
...@@ -513,29 +514,17 @@ def makeTester( ...@@ -513,29 +514,17 @@ def makeTester(
or not np.allclose(variable, expected, atol=eps, rtol=eps) or not np.allclose(variable, expected, atol=eps, rtol=eps)
) )
assert not condition, ( assert not condition, (
"Test {}::{}: Output {} gave the wrong" f"Test {self.op}::{testname}: Output {i} gave the wrong"
" value. With inputs {}, expected {} (dtype {})," f" value. With inputs {inputs}, expected {expected} (dtype {expected.dtype}),"
" got {} (dtype {}). eps={:f}" f" got {variable} (dtype {variable.dtype}). eps={eps:f}"
" np.allclose returns {} {}" f" np.allclose returns {np.allclose(variable, expected, atol=eps, rtol=eps)} {np.allclose(variable, expected)}"
).format(
self.op,
testname,
i,
inputs,
expected,
expected.dtype,
variable,
variable.dtype,
eps,
np.allclose(variable, expected, atol=eps, rtol=eps),
np.allclose(variable, expected),
) )
for description, check in self.checks.items(): for description, check in self.checks.items():
assert check(inputs, variables), ( assert check(inputs, variables), (
"Test {}::{}: Failed check: {} (inputs" f"Test {self.op}::{testname}: Failed check: {description} (inputs"
" were {}, outputs were {})" f" were {inputs}, outputs were {variables})"
).format(self.op, testname, description, inputs, variables) )
@pytest.mark.skipif(skip, reason="Skipped") @pytest.mark.skipif(skip, reason="Skipped")
def test_bad_build(self): def test_bad_build(self):
...@@ -569,8 +558,9 @@ def makeTester( ...@@ -569,8 +558,9 @@ def makeTester(
) )
except Exception as exc: except Exception as exc:
err_msg = ( err_msg = (
"Test {}::{}: Error occurred while trying" " to make a Function" f"Test {self.op}::{testname}: Error occurred while trying"
).format(self.op, testname) " to make a Function"
)
exc.args += (err_msg,) exc.args += (err_msg,)
raise raise
......
"""Test config options.""" """Test config options."""
import configparser as stdlib_configparser import configparser as stdlib_configparser
import io import io
import pickle import pickle
......
""" """
Tests of printing functionality Tests of printing functionality
""" """
import logging import logging
from io import StringIO from io import StringIO
from textwrap import dedent from textwrap import dedent
......
...@@ -11,7 +11,6 @@ the docstring of the functions: check_mat_rop_lop, check_rop_lop, ...@@ -11,7 +11,6 @@ the docstring of the functions: check_mat_rop_lop, check_rop_lop,
check_nondiff_rop, check_nondiff_rop,
""" """
import numpy as np import numpy as np
import pytest import pytest
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论