提交 e27a43a0 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Virgile Andreani

Apply the ruff auto fixes

上级 207b5e03
......@@ -12,7 +12,7 @@ from pytensor.compile.profiling import ProfileStats
from pytensor.graph import Variable
__all__ = ["types", "pfunc"]
__all__ = ["pfunc", "types"]
__docformat__ = "restructuredtext en"
_logger = logging.getLogger("pytensor.compile.function")
......
......@@ -215,7 +215,7 @@ def add_supervisor_to_fgraph(
input
for spec, input in zip(input_specs, fgraph.inputs, strict=True)
if not (
spec.mutable or has_destroy_handler and fgraph.has_destroyers([input])
spec.mutable or (has_destroy_handler and fgraph.has_destroyers([input]))
)
)
)
......
......@@ -481,7 +481,7 @@ class EnumType(CType, dict):
return tuple(sorted(self.aliases))
def __repr__(self):
names_to_aliases = {constant_name: "" for constant_name in self}
names_to_aliases = dict.fromkeys(self, "")
for alias in self.aliases:
names_to_aliases[self.aliases[alias]] = f"({alias})"
args = ", ".join(f"{k}{names_to_aliases[k]}:{self[k]}" for k in sorted(self))
......
......@@ -9,7 +9,7 @@ class OrderedSet(MutableSet):
if iterable is None:
self.values = {}
else:
self.values = {value: None for value in iterable}
self.values = dict.fromkeys(iterable)
def __contains__(self, value) -> bool:
return value in self.values
......
......@@ -295,8 +295,10 @@ N.B.:
if hasattr(var.owner, "op"):
if (
isinstance(var.owner.op, HasInnerGraph)
or hasattr(var.owner.op, "scalar_op")
and isinstance(var.owner.op.scalar_op, HasInnerGraph)
or (
hasattr(var.owner.op, "scalar_op")
and isinstance(var.owner.op.scalar_op, HasInnerGraph)
)
) and var not in inner_graph_vars:
inner_graph_vars.append(var)
if print_op_info:
......@@ -675,8 +677,10 @@ def _debugprint(
if hasattr(in_var, "owner") and hasattr(in_var.owner, "op"):
if (
isinstance(in_var.owner.op, HasInnerGraph)
or hasattr(in_var.owner.op, "scalar_op")
and isinstance(in_var.owner.op.scalar_op, HasInnerGraph)
or (
hasattr(in_var.owner.op, "scalar_op")
and isinstance(in_var.owner.op.scalar_op, HasInnerGraph)
)
) and in_var not in inner_graph_ops:
inner_graph_ops.append(in_var)
......@@ -882,7 +886,9 @@ class OperatorPrinter(Printer):
max_i = len(node.inputs) - 1
for i, input in enumerate(node.inputs):
new_precedence = self.precedence
if self.assoc == "left" and i != 0 or self.assoc == "right" and i != max_i:
if (self.assoc == "left" and i != 0) or (
self.assoc == "right" and i != max_i
):
new_precedence += 1e-6
with set_precedence(pstate, new_precedence):
......
......@@ -219,7 +219,7 @@ class ScalarLoop(ScalarInnerGraphOp):
for i, c in enumerate(fgraph.inputs[n_update:], start=n_update + 1)
}
out_subd = {u: f"%(o{i})s" for i, u in enumerate(fgraph.outputs[:n_update])}
until_subd = {u: "until" for u in fgraph.outputs[n_update:]}
until_subd = dict.fromkeys(fgraph.outputs[n_update:], "until")
subd = {**carry_subd, **constant_subd, **until_subd}
for var in fgraph.variables:
......
......@@ -4577,73 +4577,73 @@ def ix_(*args):
__all__ = [
"take_along_axis",
"expand_dims",
"atleast_Nd",
"alloc",
"arange",
"as_tensor",
"as_tensor_variable",
"atleast_1d",
"atleast_2d",
"atleast_3d",
"atleast_Nd",
"cast",
"choose",
"swapaxes",
"moveaxis",
"stacklists",
"concatenate",
"constant",
"default",
"diag",
"diagonal",
"inverse_permutation",
"permute_row_elements",
"mgrid",
"ogrid",
"arange",
"tile",
"empty",
"empty_like",
"expand_dims",
"extract_diag",
"eye",
"fill",
"flatnonzero",
"flatten",
"is_flat",
"vertical_stack",
"horizontal_stack",
"full",
"full_like",
"get_scalar_constant_value",
"get_underlying_scalar_constant_value",
"get_vector_length",
"concatenate",
"stack",
"roll",
"join",
"split",
"transpose",
"matrix_transpose",
"default",
"tensor_copy",
"horizontal_stack",
"identity",
"transfer",
"alloc",
"identity_like",
"eye",
"triu",
"tril",
"tri",
"nonzero_values",
"flatnonzero",
"inverse_permutation",
"is_flat",
"join",
"matrix_transpose",
"mgrid",
"moveaxis",
"nonzero",
"nonzero_values",
"ogrid",
"ones",
"zeros",
"zeros_like",
"ones_like",
"fill",
"permute_row_elements",
"roll",
"scalar_from_tensor",
"second",
"where",
"split",
"stack",
"stacklists",
"swapaxes",
"switch",
"cast",
"scalar_from_tensor",
"take_along_axis",
"tensor_copy",
"tensor_from_scalar",
"get_scalar_constant_value",
"get_underlying_scalar_constant_value",
"constant",
"as_tensor_variable",
"as_tensor",
"extract_diag",
"full",
"full_like",
"empty",
"empty_like",
"tile",
"trace",
"transfer",
"transpose",
"tri",
"tril",
"tril_indices",
"tril_indices_from",
"triu",
"triu_indices",
"triu_indices_from",
"vertical_stack",
"where",
"zeros",
"zeros_like",
]
......@@ -821,7 +821,7 @@ class Elemwise(OpenMPOp):
# for each input:
# same as range(ndim), but with 'x' at all broadcastable positions
orders = [
[s == 1 and "x" or i for i, s in enumerate(input.type.shape)]
[(s == 1 and "x") or i for i, s in enumerate(input.type.shape)]
for input in inputs
]
......
......@@ -1558,7 +1558,7 @@ def broadcast_shape_iter(
(one,) * (max_dims - len(a))
+ tuple(
one
if sh == 1 or isinstance(sh, Constant) and sh.value == 1
if sh == 1 or (isinstance(sh, Constant) and sh.value == 1)
else (ps.as_scalar(sh) if not isinstance(sh, Variable) else sh)
for sh in a
)
......@@ -2067,25 +2067,25 @@ def concat_with_broadcast(tensor_list, axis=0):
__all__ = [
"searchsorted",
"cumsum",
"cumprod",
"diff",
"bincount",
"squeeze",
"compress",
"repeat",
"bartlett",
"fill_diagonal",
"fill_diagonal_offset",
"unique",
"unravel_index",
"ravel_multi_index",
"bincount",
"broadcast_arrays",
"broadcast_shape",
"broadcast_to",
"compress",
"concat_with_broadcast",
"cumprod",
"cumsum",
"diff",
"fill_diagonal",
"fill_diagonal_offset",
"geomspace",
"logspace",
"linspace",
"broadcast_arrays",
"logspace",
"ravel_multi_index",
"repeat",
"searchsorted",
"squeeze",
"unique",
"unravel_index",
]
......@@ -4162,154 +4162,154 @@ equal = eq
not_equal = neq
__all__ = [
"max_and_argmax",
"max",
"matmul",
"vecdot",
"matvec",
"vecmat",
"argmax",
"min",
"argmin",
"smallest",
"largest",
"lt",
"less",
"gt",
"greater",
"le",
"less_equal",
"ge",
"greater_equal",
"eq",
"equal",
"neq",
"not_equal",
"isnan",
"isinf",
"isposinf",
"isneginf",
"abs",
"add",
"all",
"allclose",
"isclose",
"and_",
"angle",
"any",
"arccos",
"arccosh",
"arcsin",
"arcsinh",
"arctan",
"arctan2",
"arctanh",
"argmax",
"argmin",
"betainc",
"betaincinv",
"bitwise_and",
"or_",
"bitwise_not",
"bitwise_or",
"xor",
"bitwise_xor",
"invert",
"bitwise_not",
"abs",
"exp",
"exp2",
"expm1",
"neg",
"reciprocal",
"log",
"log2",
"log10",
"log1p",
"sgn",
"sign",
"ceil",
"floor",
"trunc",
"iround",
"round",
"round_half_to_even",
"round_half_away_from_zero",
"sqr",
"square",
"cov",
"sqrt",
"deg2rad",
"rad2deg",
"ceil_intdiv",
"chi2sf",
"clip",
"complex",
"complex_from_polar",
"conj",
"conjugate",
"cos",
"arccos",
"sin",
"arcsin",
"tan",
"arctan",
"arctan2",
"cosh",
"arccosh",
"sinh",
"arcsinh",
"tanh",
"arctanh",
"cov",
"deg2rad",
"dense_dot",
"digamma",
"divmod",
"dot",
"eq",
"equal",
"erf",
"erfc",
"erfcinv",
"erfcx",
"erfinv",
"erfcinv",
"owens_t",
"exp",
"exp2",
"expit",
"expm1",
"floor",
"floor_div",
"gamma",
"gammaln",
"psi",
"digamma",
"tri_gamma",
"polygamma",
"chi2sf",
"gammainc",
"gammaincc",
"gammau",
"gammal",
"gammaincinv",
"gammainccinv",
"j0",
"j1",
"jv",
"gammaincinv",
"gammal",
"gammaln",
"gammau",
"ge",
"greater",
"greater_equal",
"gt",
"hyp2f1",
"i0",
"i1",
"imag",
"int_div",
"invert",
"iround",
"isclose",
"isinf",
"isnan",
"isneginf",
"isposinf",
"iv",
"ive",
"j0",
"j1",
"jv",
"kn",
"kv",
"kve",
"sigmoid",
"expit",
"softplus",
"log1pexp",
"largest",
"le",
"less",
"less_equal",
"log",
"log1mexp",
"betainc",
"betaincinv",
"real",
"imag",
"angle",
"complex",
"conj",
"conjugate",
"complex_from_polar",
"sum",
"prod",
"log1p",
"log1pexp",
"log2",
"log10",
"logaddexp",
"logsumexp",
"lt",
"matmul",
"matvec",
"max",
"max_and_argmax",
"maximum",
"mean",
"median",
"var",
"std",
"std",
"maximum",
"min",
"minimum",
"divmod",
"add",
"sub",
"mul",
"true_div",
"int_div",
"floor_div",
"ceil_intdiv",
"mod",
"pow",
"clip",
"dot",
"dense_dot",
"tensordot",
"mul",
"nan_to_num",
"neg",
"neq",
"not_equal",
"or_",
"outer",
"any",
"all",
"ptp",
"owens_t",
"polygamma",
"pow",
"power",
"logaddexp",
"logsumexp",
"hyp2f1",
"nan_to_num",
"prod",
"psi",
"ptp",
"rad2deg",
"real",
"reciprocal",
"round",
"round_half_away_from_zero",
"round_half_to_even",
"sgn",
"sigmoid",
"sign",
"sin",
"sinh",
"smallest",
"softplus",
"sqr",
"sqrt",
"square",
"std",
"std",
"sub",
"sum",
"tan",
"tanh",
"tensordot",
"tri_gamma",
"true_div",
"trunc",
"var",
"vecdot",
"vecmat",
"xor",
]
......@@ -1114,19 +1114,19 @@ def kron(a, b):
__all__ = [
"pinv",
"inv",
"trace",
"matrix_dot",
"det",
"eig",
"eigh",
"svd",
"inv",
"kron",
"lstsq",
"matrix_dot",
"matrix_power",
"norm",
"pinv",
"slogdet",
"svd",
"tensorinv",
"tensorsolve",
"kron",
"trace",
]
......@@ -999,4 +999,4 @@ def root(
return solution, success
__all__ = ["minimize_scalar", "minimize", "root_scalar", "root"]
__all__ = ["minimize", "minimize_scalar", "root", "root_scalar"]
......@@ -689,4 +689,4 @@ def pad(
return cast(TensorVariable, op)
__all__ = ["pad", "flip"]
__all__ = ["flip", "pad"]
......@@ -2132,42 +2132,42 @@ def permutation(x, **kwargs):
__all__ = [
"permutation",
"choice",
"integers",
"categorical",
"multinomial",
"bernoulli",
"beta",
"betabinom",
"nbinom",
"binomial",
"laplace",
"bernoulli",
"truncexpon",
"wald",
"invgamma",
"halfcauchy",
"categorical",
"cauchy",
"hypergeometric",
"geometric",
"poisson",
"chisquare",
"choice",
"dirichlet",
"multivariate_normal",
"vonmises",
"logistic",
"weibull",
"exponential",
"gumbel",
"pareto",
"chisquare",
"gamma",
"lognormal",
"gengamma",
"geometric",
"gumbel",
"halfcauchy",
"halfnormal",
"hypergeometric",
"integers",
"invgamma",
"laplace",
"logistic",
"lognormal",
"multinomial",
"multivariate_normal",
"nbinom",
"negative_binomial",
"normal",
"beta",
"triangular",
"uniform",
"pareto",
"permutation",
"poisson",
"standard_normal",
"negative_binomial",
"gengamma",
"t",
"triangular",
"truncexpon",
"uniform",
"vonmises",
"wald",
"weibull",
]
......@@ -392,7 +392,7 @@ def local_dimshuffle_lift(fgraph, node):
ret = inode.op(*new_inputs, return_list=True)
return ret
if inode and isinstance(inode.op, DimShuffle):
new_order = [x == "x" and "x" or inode.op.new_order[x] for x in new_order]
new_order = [(x == "x" and "x") or inode.op.new_order[x] for x in new_order]
inp = inode.inputs[0]
if is_dimshuffle_useless(new_order, inp):
......
......@@ -350,8 +350,10 @@ def local_lift_through_linalg(
outer_op = node.op
if y.owner and (
isinstance(y.owner.op, Blockwise)
and isinstance(y.owner.op.core_op, BlockDiagonal)
(
isinstance(y.owner.op, Blockwise)
and isinstance(y.owner.op.core_op, BlockDiagonal)
)
or isinstance(y.owner.op, KroneckerProduct)
):
input_matrices = y.owner.inputs
......
......@@ -448,8 +448,7 @@ class ShapeFeature(Feature):
assert all(
(
not hasattr(r.type, "shape")
or r.type.shape[i] != 1
and other_r.type.shape[i] != 1
or (r.type.shape[i] != 1 and other_r.type.shape[i] != 1)
)
or self.lscalar_one.equals(merged_shape[i])
or self.lscalar_one.equals(
......
......@@ -2088,18 +2088,18 @@ def qr(
__all__ = [
"block_diag",
"cho_solve",
"cholesky",
"solve",
"eigvalsh",
"expm",
"solve_discrete_lyapunov",
"solve_continuous_lyapunov",
"solve_discrete_are",
"solve_triangular",
"block_diag",
"cho_solve",
"lu",
"lu_factor",
"lu_solve",
"qr",
"solve",
"solve_continuous_lyapunov",
"solve_discrete_are",
"solve_discrete_lyapunov",
"solve_triangular",
]
......@@ -812,11 +812,11 @@ def betaln(a, b):
__all__ = [
"softmax",
"log_softmax",
"poch",
"factorial",
"logit",
"beta",
"betaln",
"factorial",
"log_softmax",
"logit",
"poch",
"softmax",
]
......@@ -140,4 +140,4 @@ def as_symbolic_None(x, **kwargs):
return NoneConst
__all__ = ["make_slice", "slicetype", "none_type_t", "NoneConst", "NoneSliceConst"]
__all__ = ["NoneConst", "NoneSliceConst", "make_slice", "none_type_t", "slicetype"]
......@@ -14,16 +14,16 @@ import numpy as np
__all__ = [
"get_unbound_function",
"maybe_add_to_os_environ_pathlist",
"subprocess_Popen",
"call_subprocess_Popen",
"output_subprocess_Popen",
"LOCAL_BITWIDTH",
"PYTHON_INT_BITWIDTH",
"NPY_RAVEL_AXIS",
"NDARRAY_C_VERSION",
"NPY_RAVEL_AXIS",
"PYTHON_INT_BITWIDTH",
"NoDuplicateOptWarningFilter",
"call_subprocess_Popen",
"get_unbound_function",
"maybe_add_to_os_environ_pathlist",
"output_subprocess_Popen",
"subprocess_Popen",
]
......
......@@ -491,7 +491,7 @@ def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwa
if isinstance(dim, str):
dims_dict = {dim: 1}
elif isinstance(dim, Sequence) and not isinstance(dim, dict):
dims_dict = {d: 1 for d in dim}
dims_dict = dict.fromkeys(dim, 1)
elif isinstance(dim, dict):
dims_dict = {}
for name, val in dim.items():
......
......@@ -651,10 +651,10 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
)
else:
# Default to 5 for head and tail
indexers = {dim: 5 for dim in self.type.dims}
indexers = dict.fromkeys(self.type.dims, 5)
elif not isinstance(indexers, dict):
indexers = {dim: indexers for dim in self.type.dims}
indexers = dict.fromkeys(self.type.dims, indexers)
if kind == "head":
indices = {dim: slice(None, value) for dim, value in indexers.items()}
......
......@@ -2651,9 +2651,8 @@ class TestArithmeticCast:
# behavior.
return
if (
{a_type, b_type} == {"complex128", "float32"}
or {a_type, b_type} == {"complex128", "float16"}
if {a_type, b_type} == {"complex128", "float32"} or (
{a_type, b_type} == {"complex128", "float16"}
and set(combo) == {"scalar", "array"}
and pytensor_dtype == "complex128"
and numpy_dtype == "complex64"
......
......@@ -48,11 +48,8 @@ class MyOp(Op):
return self.name
def __eq__(self, other):
return (
self is other
or isinstance(other, MyOp)
and self.x is not None
and self.x == other.x
return self is other or (
isinstance(other, MyOp) and self.x is not None and self.x == other.x
)
def __hash__(self):
......
# ruff: noqa: E402
import pytest
......
# ruff: noqa: E402
import pytest
......
# ruff: noqa: E402
import pytest
......
# ruff: noqa: E402
import pytest
......
# ruff: noqa: E402
import pytest
......
# ruff: noqa: E402
import pytest
......
# ruff: noqa: E402
import pytest
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论