提交 e27a43a0 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Virgile Andreani

Apply the ruff auto fixes

上级 207b5e03
...@@ -12,7 +12,7 @@ from pytensor.compile.profiling import ProfileStats ...@@ -12,7 +12,7 @@ from pytensor.compile.profiling import ProfileStats
from pytensor.graph import Variable from pytensor.graph import Variable
__all__ = ["types", "pfunc"] __all__ = ["pfunc", "types"]
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
_logger = logging.getLogger("pytensor.compile.function") _logger = logging.getLogger("pytensor.compile.function")
......
...@@ -215,7 +215,7 @@ def add_supervisor_to_fgraph( ...@@ -215,7 +215,7 @@ def add_supervisor_to_fgraph(
input input
for spec, input in zip(input_specs, fgraph.inputs, strict=True) for spec, input in zip(input_specs, fgraph.inputs, strict=True)
if not ( if not (
spec.mutable or has_destroy_handler and fgraph.has_destroyers([input]) spec.mutable or (has_destroy_handler and fgraph.has_destroyers([input]))
) )
) )
) )
......
...@@ -481,7 +481,7 @@ class EnumType(CType, dict): ...@@ -481,7 +481,7 @@ class EnumType(CType, dict):
return tuple(sorted(self.aliases)) return tuple(sorted(self.aliases))
def __repr__(self): def __repr__(self):
names_to_aliases = {constant_name: "" for constant_name in self} names_to_aliases = dict.fromkeys(self, "")
for alias in self.aliases: for alias in self.aliases:
names_to_aliases[self.aliases[alias]] = f"({alias})" names_to_aliases[self.aliases[alias]] = f"({alias})"
args = ", ".join(f"{k}{names_to_aliases[k]}:{self[k]}" for k in sorted(self)) args = ", ".join(f"{k}{names_to_aliases[k]}:{self[k]}" for k in sorted(self))
......
...@@ -9,7 +9,7 @@ class OrderedSet(MutableSet): ...@@ -9,7 +9,7 @@ class OrderedSet(MutableSet):
if iterable is None: if iterable is None:
self.values = {} self.values = {}
else: else:
self.values = {value: None for value in iterable} self.values = dict.fromkeys(iterable)
def __contains__(self, value) -> bool: def __contains__(self, value) -> bool:
return value in self.values return value in self.values
......
...@@ -295,8 +295,10 @@ N.B.: ...@@ -295,8 +295,10 @@ N.B.:
if hasattr(var.owner, "op"): if hasattr(var.owner, "op"):
if ( if (
isinstance(var.owner.op, HasInnerGraph) isinstance(var.owner.op, HasInnerGraph)
or hasattr(var.owner.op, "scalar_op") or (
and isinstance(var.owner.op.scalar_op, HasInnerGraph) hasattr(var.owner.op, "scalar_op")
and isinstance(var.owner.op.scalar_op, HasInnerGraph)
)
) and var not in inner_graph_vars: ) and var not in inner_graph_vars:
inner_graph_vars.append(var) inner_graph_vars.append(var)
if print_op_info: if print_op_info:
...@@ -675,8 +677,10 @@ def _debugprint( ...@@ -675,8 +677,10 @@ def _debugprint(
if hasattr(in_var, "owner") and hasattr(in_var.owner, "op"): if hasattr(in_var, "owner") and hasattr(in_var.owner, "op"):
if ( if (
isinstance(in_var.owner.op, HasInnerGraph) isinstance(in_var.owner.op, HasInnerGraph)
or hasattr(in_var.owner.op, "scalar_op") or (
and isinstance(in_var.owner.op.scalar_op, HasInnerGraph) hasattr(in_var.owner.op, "scalar_op")
and isinstance(in_var.owner.op.scalar_op, HasInnerGraph)
)
) and in_var not in inner_graph_ops: ) and in_var not in inner_graph_ops:
inner_graph_ops.append(in_var) inner_graph_ops.append(in_var)
...@@ -882,7 +886,9 @@ class OperatorPrinter(Printer): ...@@ -882,7 +886,9 @@ class OperatorPrinter(Printer):
max_i = len(node.inputs) - 1 max_i = len(node.inputs) - 1
for i, input in enumerate(node.inputs): for i, input in enumerate(node.inputs):
new_precedence = self.precedence new_precedence = self.precedence
if self.assoc == "left" and i != 0 or self.assoc == "right" and i != max_i: if (self.assoc == "left" and i != 0) or (
self.assoc == "right" and i != max_i
):
new_precedence += 1e-6 new_precedence += 1e-6
with set_precedence(pstate, new_precedence): with set_precedence(pstate, new_precedence):
......
...@@ -219,7 +219,7 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -219,7 +219,7 @@ class ScalarLoop(ScalarInnerGraphOp):
for i, c in enumerate(fgraph.inputs[n_update:], start=n_update + 1) for i, c in enumerate(fgraph.inputs[n_update:], start=n_update + 1)
} }
out_subd = {u: f"%(o{i})s" for i, u in enumerate(fgraph.outputs[:n_update])} out_subd = {u: f"%(o{i})s" for i, u in enumerate(fgraph.outputs[:n_update])}
until_subd = {u: "until" for u in fgraph.outputs[n_update:]} until_subd = dict.fromkeys(fgraph.outputs[n_update:], "until")
subd = {**carry_subd, **constant_subd, **until_subd} subd = {**carry_subd, **constant_subd, **until_subd}
for var in fgraph.variables: for var in fgraph.variables:
......
...@@ -4577,73 +4577,73 @@ def ix_(*args): ...@@ -4577,73 +4577,73 @@ def ix_(*args):
__all__ = [ __all__ = [
"take_along_axis", "alloc",
"expand_dims", "arange",
"atleast_Nd", "as_tensor",
"as_tensor_variable",
"atleast_1d", "atleast_1d",
"atleast_2d", "atleast_2d",
"atleast_3d", "atleast_3d",
"atleast_Nd",
"cast",
"choose", "choose",
"swapaxes", "concatenate",
"moveaxis", "constant",
"stacklists", "default",
"diag", "diag",
"diagonal", "diagonal",
"inverse_permutation", "empty",
"permute_row_elements", "empty_like",
"mgrid", "expand_dims",
"ogrid", "extract_diag",
"arange", "eye",
"tile", "fill",
"flatnonzero",
"flatten", "flatten",
"is_flat", "full",
"vertical_stack", "full_like",
"horizontal_stack", "get_scalar_constant_value",
"get_underlying_scalar_constant_value",
"get_vector_length", "get_vector_length",
"concatenate", "horizontal_stack",
"stack",
"roll",
"join",
"split",
"transpose",
"matrix_transpose",
"default",
"tensor_copy",
"identity", "identity",
"transfer",
"alloc",
"identity_like", "identity_like",
"eye", "inverse_permutation",
"triu", "is_flat",
"tril", "join",
"tri", "matrix_transpose",
"nonzero_values", "mgrid",
"flatnonzero", "moveaxis",
"nonzero", "nonzero",
"nonzero_values",
"ogrid",
"ones", "ones",
"zeros",
"zeros_like",
"ones_like", "ones_like",
"fill", "permute_row_elements",
"roll",
"scalar_from_tensor",
"second", "second",
"where", "split",
"stack",
"stacklists",
"swapaxes",
"switch", "switch",
"cast", "take_along_axis",
"scalar_from_tensor", "tensor_copy",
"tensor_from_scalar", "tensor_from_scalar",
"get_scalar_constant_value", "tile",
"get_underlying_scalar_constant_value",
"constant",
"as_tensor_variable",
"as_tensor",
"extract_diag",
"full",
"full_like",
"empty",
"empty_like",
"trace", "trace",
"transfer",
"transpose",
"tri",
"tril",
"tril_indices", "tril_indices",
"tril_indices_from", "tril_indices_from",
"triu",
"triu_indices", "triu_indices",
"triu_indices_from", "triu_indices_from",
"vertical_stack",
"where",
"zeros",
"zeros_like",
] ]
...@@ -821,7 +821,7 @@ class Elemwise(OpenMPOp): ...@@ -821,7 +821,7 @@ class Elemwise(OpenMPOp):
# for each input: # for each input:
# same as range(ndim), but with 'x' at all broadcastable positions # same as range(ndim), but with 'x' at all broadcastable positions
orders = [ orders = [
[s == 1 and "x" or i for i, s in enumerate(input.type.shape)] [(s == 1 and "x") or i for i, s in enumerate(input.type.shape)]
for input in inputs for input in inputs
] ]
......
...@@ -1558,7 +1558,7 @@ def broadcast_shape_iter( ...@@ -1558,7 +1558,7 @@ def broadcast_shape_iter(
(one,) * (max_dims - len(a)) (one,) * (max_dims - len(a))
+ tuple( + tuple(
one one
if sh == 1 or isinstance(sh, Constant) and sh.value == 1 if sh == 1 or (isinstance(sh, Constant) and sh.value == 1)
else (ps.as_scalar(sh) if not isinstance(sh, Variable) else sh) else (ps.as_scalar(sh) if not isinstance(sh, Variable) else sh)
for sh in a for sh in a
) )
...@@ -2067,25 +2067,25 @@ def concat_with_broadcast(tensor_list, axis=0): ...@@ -2067,25 +2067,25 @@ def concat_with_broadcast(tensor_list, axis=0):
__all__ = [ __all__ = [
"searchsorted",
"cumsum",
"cumprod",
"diff",
"bincount",
"squeeze",
"compress",
"repeat",
"bartlett", "bartlett",
"fill_diagonal", "bincount",
"fill_diagonal_offset", "broadcast_arrays",
"unique",
"unravel_index",
"ravel_multi_index",
"broadcast_shape", "broadcast_shape",
"broadcast_to", "broadcast_to",
"compress",
"concat_with_broadcast", "concat_with_broadcast",
"cumprod",
"cumsum",
"diff",
"fill_diagonal",
"fill_diagonal_offset",
"geomspace", "geomspace",
"logspace",
"linspace", "linspace",
"broadcast_arrays", "logspace",
"ravel_multi_index",
"repeat",
"searchsorted",
"squeeze",
"unique",
"unravel_index",
] ]
...@@ -4162,154 +4162,154 @@ equal = eq ...@@ -4162,154 +4162,154 @@ equal = eq
not_equal = neq not_equal = neq
__all__ = [ __all__ = [
"max_and_argmax", "abs",
"max", "add",
"matmul", "all",
"vecdot",
"matvec",
"vecmat",
"argmax",
"min",
"argmin",
"smallest",
"largest",
"lt",
"less",
"gt",
"greater",
"le",
"less_equal",
"ge",
"greater_equal",
"eq",
"equal",
"neq",
"not_equal",
"isnan",
"isinf",
"isposinf",
"isneginf",
"allclose", "allclose",
"isclose",
"and_", "and_",
"angle",
"any",
"arccos",
"arccosh",
"arcsin",
"arcsinh",
"arctan",
"arctan2",
"arctanh",
"argmax",
"argmin",
"betainc",
"betaincinv",
"bitwise_and", "bitwise_and",
"or_", "bitwise_not",
"bitwise_or", "bitwise_or",
"xor",
"bitwise_xor", "bitwise_xor",
"invert",
"bitwise_not",
"abs",
"exp",
"exp2",
"expm1",
"neg",
"reciprocal",
"log",
"log2",
"log10",
"log1p",
"sgn",
"sign",
"ceil", "ceil",
"floor", "ceil_intdiv",
"trunc", "chi2sf",
"iround", "clip",
"round", "complex",
"round_half_to_even", "complex_from_polar",
"round_half_away_from_zero", "conj",
"sqr", "conjugate",
"square",
"cov",
"sqrt",
"deg2rad",
"rad2deg",
"cos", "cos",
"arccos",
"sin",
"arcsin",
"tan",
"arctan",
"arctan2",
"cosh", "cosh",
"arccosh", "cov",
"sinh", "deg2rad",
"arcsinh", "dense_dot",
"tanh", "digamma",
"arctanh", "divmod",
"dot",
"eq",
"equal",
"erf", "erf",
"erfc", "erfc",
"erfcinv",
"erfcx", "erfcx",
"erfinv", "erfinv",
"erfcinv", "exp",
"owens_t", "exp2",
"expit",
"expm1",
"floor",
"floor_div",
"gamma", "gamma",
"gammaln",
"psi",
"digamma",
"tri_gamma",
"polygamma",
"chi2sf",
"gammainc", "gammainc",
"gammaincc", "gammaincc",
"gammau",
"gammal",
"gammaincinv",
"gammainccinv", "gammainccinv",
"j0", "gammaincinv",
"j1", "gammal",
"jv", "gammaln",
"gammau",
"ge",
"greater",
"greater_equal",
"gt",
"hyp2f1",
"i0", "i0",
"i1", "i1",
"imag",
"int_div",
"invert",
"iround",
"isclose",
"isinf",
"isnan",
"isneginf",
"isposinf",
"iv", "iv",
"ive", "ive",
"j0",
"j1",
"jv",
"kn", "kn",
"kv", "kv",
"kve", "kve",
"sigmoid", "largest",
"expit", "le",
"softplus", "less",
"log1pexp", "less_equal",
"log",
"log1mexp", "log1mexp",
"betainc", "log1p",
"betaincinv", "log1pexp",
"real", "log2",
"imag", "log10",
"angle", "logaddexp",
"complex", "logsumexp",
"conj", "lt",
"conjugate", "matmul",
"complex_from_polar", "matvec",
"sum", "max",
"prod", "max_and_argmax",
"maximum",
"mean", "mean",
"median", "median",
"var", "min",
"std",
"std",
"maximum",
"minimum", "minimum",
"divmod",
"add",
"sub",
"mul",
"true_div",
"int_div",
"floor_div",
"ceil_intdiv",
"mod", "mod",
"pow", "mul",
"clip", "nan_to_num",
"dot", "neg",
"dense_dot", "neq",
"tensordot", "not_equal",
"or_",
"outer", "outer",
"any", "owens_t",
"all", "polygamma",
"ptp", "pow",
"power", "power",
"logaddexp", "prod",
"logsumexp", "psi",
"hyp2f1", "ptp",
"nan_to_num", "rad2deg",
"real",
"reciprocal",
"round",
"round_half_away_from_zero",
"round_half_to_even",
"sgn",
"sigmoid",
"sign",
"sin",
"sinh",
"smallest",
"softplus",
"sqr",
"sqrt",
"square",
"std",
"std",
"sub",
"sum",
"tan",
"tanh",
"tensordot",
"tri_gamma",
"true_div",
"trunc",
"var",
"vecdot",
"vecmat",
"xor",
] ]
...@@ -1114,19 +1114,19 @@ def kron(a, b): ...@@ -1114,19 +1114,19 @@ def kron(a, b):
__all__ = [ __all__ = [
"pinv",
"inv",
"trace",
"matrix_dot",
"det", "det",
"eig", "eig",
"eigh", "eigh",
"svd", "inv",
"kron",
"lstsq", "lstsq",
"matrix_dot",
"matrix_power", "matrix_power",
"norm", "norm",
"pinv",
"slogdet", "slogdet",
"svd",
"tensorinv", "tensorinv",
"tensorsolve", "tensorsolve",
"kron", "trace",
] ]
...@@ -999,4 +999,4 @@ def root( ...@@ -999,4 +999,4 @@ def root(
return solution, success return solution, success
__all__ = ["minimize_scalar", "minimize", "root_scalar", "root"] __all__ = ["minimize", "minimize_scalar", "root", "root_scalar"]
...@@ -689,4 +689,4 @@ def pad( ...@@ -689,4 +689,4 @@ def pad(
return cast(TensorVariable, op) return cast(TensorVariable, op)
__all__ = ["pad", "flip"] __all__ = ["flip", "pad"]
...@@ -2132,42 +2132,42 @@ def permutation(x, **kwargs): ...@@ -2132,42 +2132,42 @@ def permutation(x, **kwargs):
__all__ = [ __all__ = [
"permutation", "bernoulli",
"choice", "beta",
"integers",
"categorical",
"multinomial",
"betabinom", "betabinom",
"nbinom",
"binomial", "binomial",
"laplace", "categorical",
"bernoulli",
"truncexpon",
"wald",
"invgamma",
"halfcauchy",
"cauchy", "cauchy",
"hypergeometric", "chisquare",
"geometric", "choice",
"poisson",
"dirichlet", "dirichlet",
"multivariate_normal",
"vonmises",
"logistic",
"weibull",
"exponential", "exponential",
"gumbel",
"pareto",
"chisquare",
"gamma", "gamma",
"lognormal", "gengamma",
"geometric",
"gumbel",
"halfcauchy",
"halfnormal", "halfnormal",
"hypergeometric",
"integers",
"invgamma",
"laplace",
"logistic",
"lognormal",
"multinomial",
"multivariate_normal",
"nbinom",
"negative_binomial",
"normal", "normal",
"beta", "pareto",
"triangular", "permutation",
"uniform", "poisson",
"standard_normal", "standard_normal",
"negative_binomial",
"gengamma",
"t", "t",
"triangular",
"truncexpon",
"uniform",
"vonmises",
"wald",
"weibull",
] ]
...@@ -392,7 +392,7 @@ def local_dimshuffle_lift(fgraph, node): ...@@ -392,7 +392,7 @@ def local_dimshuffle_lift(fgraph, node):
ret = inode.op(*new_inputs, return_list=True) ret = inode.op(*new_inputs, return_list=True)
return ret return ret
if inode and isinstance(inode.op, DimShuffle): if inode and isinstance(inode.op, DimShuffle):
new_order = [x == "x" and "x" or inode.op.new_order[x] for x in new_order] new_order = [(x == "x" and "x") or inode.op.new_order[x] for x in new_order]
inp = inode.inputs[0] inp = inode.inputs[0]
if is_dimshuffle_useless(new_order, inp): if is_dimshuffle_useless(new_order, inp):
......
...@@ -350,8 +350,10 @@ def local_lift_through_linalg( ...@@ -350,8 +350,10 @@ def local_lift_through_linalg(
outer_op = node.op outer_op = node.op
if y.owner and ( if y.owner and (
isinstance(y.owner.op, Blockwise) (
and isinstance(y.owner.op.core_op, BlockDiagonal) isinstance(y.owner.op, Blockwise)
and isinstance(y.owner.op.core_op, BlockDiagonal)
)
or isinstance(y.owner.op, KroneckerProduct) or isinstance(y.owner.op, KroneckerProduct)
): ):
input_matrices = y.owner.inputs input_matrices = y.owner.inputs
......
...@@ -448,8 +448,7 @@ class ShapeFeature(Feature): ...@@ -448,8 +448,7 @@ class ShapeFeature(Feature):
assert all( assert all(
( (
not hasattr(r.type, "shape") not hasattr(r.type, "shape")
or r.type.shape[i] != 1 or (r.type.shape[i] != 1 and other_r.type.shape[i] != 1)
and other_r.type.shape[i] != 1
) )
or self.lscalar_one.equals(merged_shape[i]) or self.lscalar_one.equals(merged_shape[i])
or self.lscalar_one.equals( or self.lscalar_one.equals(
......
...@@ -2088,18 +2088,18 @@ def qr( ...@@ -2088,18 +2088,18 @@ def qr(
__all__ = [ __all__ = [
"block_diag",
"cho_solve",
"cholesky", "cholesky",
"solve",
"eigvalsh", "eigvalsh",
"expm", "expm",
"solve_discrete_lyapunov",
"solve_continuous_lyapunov",
"solve_discrete_are",
"solve_triangular",
"block_diag",
"cho_solve",
"lu", "lu",
"lu_factor", "lu_factor",
"lu_solve", "lu_solve",
"qr", "qr",
"solve",
"solve_continuous_lyapunov",
"solve_discrete_are",
"solve_discrete_lyapunov",
"solve_triangular",
] ]
...@@ -812,11 +812,11 @@ def betaln(a, b): ...@@ -812,11 +812,11 @@ def betaln(a, b):
__all__ = [ __all__ = [
"softmax",
"log_softmax",
"poch",
"factorial",
"logit",
"beta", "beta",
"betaln", "betaln",
"factorial",
"log_softmax",
"logit",
"poch",
"softmax",
] ]
...@@ -140,4 +140,4 @@ def as_symbolic_None(x, **kwargs): ...@@ -140,4 +140,4 @@ def as_symbolic_None(x, **kwargs):
return NoneConst return NoneConst
__all__ = ["make_slice", "slicetype", "none_type_t", "NoneConst", "NoneSliceConst"] __all__ = ["NoneConst", "NoneSliceConst", "make_slice", "none_type_t", "slicetype"]
...@@ -14,16 +14,16 @@ import numpy as np ...@@ -14,16 +14,16 @@ import numpy as np
__all__ = [ __all__ = [
"get_unbound_function",
"maybe_add_to_os_environ_pathlist",
"subprocess_Popen",
"call_subprocess_Popen",
"output_subprocess_Popen",
"LOCAL_BITWIDTH", "LOCAL_BITWIDTH",
"PYTHON_INT_BITWIDTH",
"NPY_RAVEL_AXIS",
"NDARRAY_C_VERSION", "NDARRAY_C_VERSION",
"NPY_RAVEL_AXIS",
"PYTHON_INT_BITWIDTH",
"NoDuplicateOptWarningFilter", "NoDuplicateOptWarningFilter",
"call_subprocess_Popen",
"get_unbound_function",
"maybe_add_to_os_environ_pathlist",
"output_subprocess_Popen",
"subprocess_Popen",
] ]
......
...@@ -491,7 +491,7 @@ def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwa ...@@ -491,7 +491,7 @@ def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwa
if isinstance(dim, str): if isinstance(dim, str):
dims_dict = {dim: 1} dims_dict = {dim: 1}
elif isinstance(dim, Sequence) and not isinstance(dim, dict): elif isinstance(dim, Sequence) and not isinstance(dim, dict):
dims_dict = {d: 1 for d in dim} dims_dict = dict.fromkeys(dim, 1)
elif isinstance(dim, dict): elif isinstance(dim, dict):
dims_dict = {} dims_dict = {}
for name, val in dim.items(): for name, val in dim.items():
......
...@@ -651,10 +651,10 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -651,10 +651,10 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
) )
else: else:
# Default to 5 for head and tail # Default to 5 for head and tail
indexers = {dim: 5 for dim in self.type.dims} indexers = dict.fromkeys(self.type.dims, 5)
elif not isinstance(indexers, dict): elif not isinstance(indexers, dict):
indexers = {dim: indexers for dim in self.type.dims} indexers = dict.fromkeys(self.type.dims, indexers)
if kind == "head": if kind == "head":
indices = {dim: slice(None, value) for dim, value in indexers.items()} indices = {dim: slice(None, value) for dim, value in indexers.items()}
......
...@@ -2651,9 +2651,8 @@ class TestArithmeticCast: ...@@ -2651,9 +2651,8 @@ class TestArithmeticCast:
# behavior. # behavior.
return return
if ( if {a_type, b_type} == {"complex128", "float32"} or (
{a_type, b_type} == {"complex128", "float32"} {a_type, b_type} == {"complex128", "float16"}
or {a_type, b_type} == {"complex128", "float16"}
and set(combo) == {"scalar", "array"} and set(combo) == {"scalar", "array"}
and pytensor_dtype == "complex128" and pytensor_dtype == "complex128"
and numpy_dtype == "complex64" and numpy_dtype == "complex64"
......
...@@ -48,11 +48,8 @@ class MyOp(Op): ...@@ -48,11 +48,8 @@ class MyOp(Op):
return self.name return self.name
def __eq__(self, other): def __eq__(self, other):
return ( return self is other or (
self is other isinstance(other, MyOp) and self.x is not None and self.x == other.x
or isinstance(other, MyOp)
and self.x is not None
and self.x == other.x
) )
def __hash__(self): def __hash__(self):
......
# ruff: noqa: E402
import pytest import pytest
......
# ruff: noqa: E402
import pytest import pytest
......
# ruff: noqa: E402
import pytest import pytest
......
# ruff: noqa: E402
import pytest import pytest
......
# ruff: noqa: E402
import pytest import pytest
......
# ruff: noqa: E402
import pytest import pytest
......
# ruff: noqa: E402
import pytest import pytest
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论