提交 1da2891c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add flake8-comprehensions plugin

上级 bc878138
......@@ -33,6 +33,8 @@ repos:
rev: 6.0.0
hooks:
- id: flake8
additional_dependencies:
- flake8-comprehensions
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
......
......@@ -969,9 +969,7 @@ def inline_ofg_expansion(fgraph, node):
return False
if not op.is_inline:
return False
return clone_replace(
op.inner_outputs, {u: v for u, v in zip(op.inner_inputs, node.inputs)}
)
return clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
# We want to run this before the first merge optimizer
......
......@@ -504,7 +504,7 @@ def grad(
if not isinstance(wrt, Sequence):
_wrt: List[Variable] = [wrt]
else:
_wrt = [x for x in wrt]
_wrt = list(wrt)
outputs = []
if cost is not None:
......@@ -791,8 +791,8 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False):
pgrads = dict(zip(params, grads))
# separate wrt from end grads:
wrt_grads = list(pgrads[k] for k in wrt)
end_grads = list(pgrads[k] for k in end)
wrt_grads = [pgrads[k] for k in wrt]
end_grads = [pgrads[k] for k in end]
if details:
return wrt_grads, end_grads, start_grads, cost_grads
......
......@@ -1629,7 +1629,7 @@ def as_string(
multi.add(op2)
else:
seen.add(input.owner)
multi_list = [x for x in multi]
multi_list = list(multi)
done: Set = set()
def multi_index(x):
......
......@@ -142,7 +142,7 @@ def graph_replace(
raise ValueError(f"{key} is not a part of graph")
sorted_replacements = sorted(
tuple(fg_replace.items()),
fg_replace.items(),
# sort based on the fg toposort, if a variable has no owner, it goes first
key=partial(toposort_key, fg, toposort),
reverse=True,
......
......@@ -2575,8 +2575,8 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
for i in range(len(loop_timing)):
loop_times = ""
if loop_process_count[i]:
d = list(
reversed(sorted(loop_process_count[i].items(), key=lambda a: a[1]))
d = sorted(
loop_process_count[i].items(), key=lambda a: a[1], reverse=True
)
loop_times = " ".join([str((str(k), v)) for k, v in d[:5]])
if len(d) > 5:
......
......@@ -633,11 +633,11 @@ class CLinker(Linker):
# The orphans field is listified to ensure a consistent order.
# list(fgraph.orphans.difference(self.outputs))
self.orphans = list(
self.orphans = [
r
for r in self.variables
if isinstance(r, AtomicVariable) and r not in self.inputs
)
]
# C type constants (pytensor.scalar.ScalarType). They don't request an object
self.consts = []
# Move c type from orphans (pytensor.scalar.ScalarType) to self.consts
......
......@@ -810,7 +810,7 @@ class ParamsType(CType):
struct_extract_method=struct_extract_method,
)
return list(sorted(list(c_support_code_set))) + [final_struct_code]
return sorted(c_support_code_set) + [final_struct_code]
def c_code_cache_version(self):
return ((3,), tuple(t.c_code_cache_version() for t in self.types))
......
......@@ -41,7 +41,7 @@ def jax_funcify_CAReduce(op, **kwargs):
elif scalar_op_name:
scalar_fn_name = scalar_op_name
to_reduce = reversed(sorted(axis))
to_reduce = sorted(axis, reverse=True)
if to_reduce:
# In this case, we need to use the `jax.lax` function (if there
......
......@@ -361,7 +361,7 @@ def create_multiaxis_reducer(
careduce_fn_name = f"careduce_{scalar_op}"
global_env = {}
to_reduce = reversed(sorted(axes))
to_reduce = sorted(axes, reverse=True)
careduce_lines_src = []
var_name = input_name
......
......@@ -796,7 +796,7 @@ class Print(Op):
return output_gradients
def R_op(self, inputs, eval_points):
return [x for x in eval_points]
return list(eval_points)
def __setstate__(self, dct):
dct.setdefault("global_fn", _print_fn)
......
......@@ -492,7 +492,7 @@ def scan(
# wrap sequences in a dictionary if they are not already dictionaries
for i in range(n_seqs):
if not isinstance(seqs[i], dict):
seqs[i] = dict([("input", seqs[i]), ("taps", [0])])
seqs[i] = {"input": seqs[i], "taps": [0]}
elif seqs[i].get("taps", None) is not None:
seqs[i]["taps"] = wrap_into_list(seqs[i]["taps"])
elif seqs[i].get("taps", None) is None:
......@@ -504,7 +504,7 @@ def scan(
if outs_info[i] is not None:
if not isinstance(outs_info[i], dict):
# by default any output has a tap value of -1
outs_info[i] = dict([("initial", outs_info[i]), ("taps", [-1])])
outs_info[i] = {"initial": outs_info[i], "taps": [-1]}
elif (
outs_info[i].get("initial", None) is None
and outs_info[i].get("taps", None) is not None
......
......@@ -1718,12 +1718,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
arg.shape[0]
for arg in inputs[self.seqs_arg_offset : self.shared_arg_offset]
]
store_steps += [
arg
for arg in inputs[
self.nit_sot_arg_offset : self.nit_sot_arg_offset + info.n_nit_sot
]
]
store_steps += list(
inputs[self.nit_sot_arg_offset : self.nit_sot_arg_offset + info.n_nit_sot]
)
# 2.1 Create storage space for outputs
for idx in range(self.n_outs):
......@@ -2270,7 +2267,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
)
offset = 1 + info.n_seqs
scan_outs = [x for x in input_shapes[offset : offset + n_outs]]
scan_outs = list(input_shapes[offset : offset + n_outs])
offset += n_outs
outs_shape_n = info.n_mit_mot_outs + info.n_mit_sot + info.n_sit_sot
for x in range(info.n_nit_sot):
......@@ -2301,7 +2298,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
shp.append(v_shp_i[0])
scan_outs.append(tuple(shp))
scan_outs += [x for x in input_shapes[offset : offset + info.n_shared_outs]]
scan_outs += list(input_shapes[offset : offset + info.n_shared_outs])
# if we are dealing with a repeat-until, then we do not know the
# leading dimension so we replace it for every entry with Shape_i
if info.as_while:
......
......@@ -388,7 +388,7 @@ def push_out_non_seq_scan(fgraph, node):
if out in local_fgraph_outs_set:
x = node.outputs[local_fgraph_outs_map[out]]
y = replace_with_out[idx]
y_shape = [shp for shp in y.shape]
y_shape = list(y.shape)
replace_with[x] = at.alloc(y, node.inputs[0], *y_shape)
# We need to add one extra dimension to the outputs
......
......@@ -283,7 +283,7 @@ class RandomVariable(Op):
shape = self._infer_shape(size, dist_params, param_shapes=param_shapes)
return [None, [s for s in shape]]
return [None, list(shape)]
def __call__(self, *args, size=None, name=None, rng=None, dtype=None, **kwargs):
res = super().__call__(rng, size, dtype, *args, **kwargs)
......
......@@ -1555,11 +1555,11 @@ def local_sum_prod_div_dimshuffle(fgraph, node):
)
if len(compatible_dims) > 0:
optimized_dimshuffle_order = list(
optimized_dimshuffle_order = [
ax
for i, ax in enumerate(dimshuffle_order)
if (i not in axis) or (ax != "x")
)
]
# Removing leading 'x' (since it will be done automatically)
while (
......@@ -1644,7 +1644,7 @@ def local_op_of_op(fgraph, node):
return [op_type(None, dtype=out_dtype)(node_inps.owner.inputs[0])]
# figure out which axes were in the original sum
newaxis = list(tuple(node_inps.owner.op.axis))
newaxis = list(node_inps.owner.op.axis)
for i in node.op.axis:
new_i = i
for ii in node_inps.owner.op.axis:
......
......@@ -810,7 +810,7 @@ def shape_padleft(t, n_ones=1):
"""
_t = at.as_tensor_variable(t)
pattern = ["x"] * n_ones + [i for i in range(_t.type.ndim)]
pattern = ["x"] * n_ones + list(range(_t.type.ndim))
return _t.dimshuffle(pattern)
......@@ -826,7 +826,7 @@ def shape_padright(t, n_ones=1):
"""
_t = at.as_tensor_variable(t)
pattern = [i for i in range(_t.type.ndim)] + ["x"] * n_ones
pattern = list(range(_t.type.ndim)) + ["x"] * n_ones
return _t.dimshuffle(pattern)
......@@ -861,7 +861,7 @@ def shape_padaxis(t, axis):
if axis < 0:
axis += ndim
pattern = [i for i in range(_t.type.ndim)]
pattern = list(range(_t.type.ndim))
pattern.insert(axis, "x")
return _t.dimshuffle(pattern)
......
......@@ -2604,7 +2604,7 @@ class AdvancedSubtensor(Op):
ishapes[0], index_shapes, indices_are_shapes=True
)
assert node.outputs[0].ndim == len(res_shape)
return [[s for s in res_shape]]
return [list(res_shape)]
def perform(self, node, inputs, out_):
(out,) = out_
......
[flake8]
select = C,E,F,W
ignore = E203,E231,E501,E741,W503,W504,C901
ignore = E203,E231,E501,E741,W503,W504,C408,C901
per-file-ignores =
**/__init__.py:F401,E402,F403
pytensor/tensor/linalg.py:F401,F403
......
......@@ -73,7 +73,7 @@ class TestNodeFinder:
assert hasattr(g, "get_nodes")
for type, num in ((add, 3), (sigmoid, 3), (dot, 2)):
if len([t for t in g.get_nodes(type)]) != num:
if len(list(g.get_nodes(type))) != num:
raise Exception("Expected: %i times %s" % (num, type))
new_e0 = add(y, z)
assert e0.owner in g.get_nodes(dot)
......@@ -82,7 +82,7 @@ class TestNodeFinder:
assert e0.owner not in g.get_nodes(dot)
assert new_e0.owner in g.get_nodes(add)
for type, num in ((add, 4), (sigmoid, 3), (dot, 1)):
if len([t for t in g.get_nodes(type)]) != num:
if len(list(g.get_nodes(type))) != num:
raise Exception("Expected: %i times %s" % (num, type))
......
......@@ -87,7 +87,7 @@ class TestOp:
r1, r2 = MyType(1)(), MyType(2)()
node = MyOp.make_node(r1, r2)
# Are the inputs what I provided?
assert [x for x in node.inputs] == [r1, r2]
assert list(node.inputs) == [r1, r2]
# Are the outputs what I expect?
assert [x.type for x in node.outputs] == [MyType(3)]
assert node.outputs[0].owner is node and node.outputs[0].index == 0
......
......@@ -1123,7 +1123,7 @@ class TestFusion:
out = dot(x, y) + x + y + z
f = function([x, y, z], out, mode=self.mode)
topo = [n for n in f.maker.fgraph.toposort()]
topo = list(f.maker.fgraph.toposort())
assert len(topo) == 2
assert topo[-1].op.inplace_pattern
......
......@@ -3994,9 +3994,9 @@ class TestSigmoidUtils:
exp_op = exp
assert is_1pexp(1 + exp_op(x), False) == (False, x)
assert is_1pexp(exp_op(x) + 1, False) == (False, x)
for neg_, exp_arg in map(
lambda x: is_1pexp(x, only_process_constants=False),
[(1 + exp_op(-x)), (exp_op(-x) + 1)],
for neg_, exp_arg in (
is_1pexp(x, only_process_constants=False)
for x in [(1 + exp_op(-x)), (exp_op(-x) + 1)]
):
assert not neg_ and is_same_graph(exp_arg, -x)
assert is_1pexp(1 - exp_op(x), False) is None
......
......@@ -2004,7 +2004,7 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
y_val_fn = function(
[x] + list(s), y, on_unused_input="ignore", mode=no_rewrites_mode
)
y_val = y_val_fn(*([x_val] + [s_ for s_ in s_val]))
y_val = y_val_fn(*([x_val] + list(s_val)))
# This optimization should appear in the canonicalizations
y_opt = rewrite_graph(y, clone=False)
......@@ -2017,7 +2017,7 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
assert isinstance(y_opt.owner.op, SpecifyShape)
y_opt_fn = function([x] + list(s), y_opt, on_unused_input="ignore")
y_opt_val = y_opt_fn(*([x_val] + [s_ for s_ in s_val]))
y_opt_val = y_opt_fn(*([x_val] + list(s_val)))
assert np.allclose(y_val, y_opt_val)
......
......@@ -2589,10 +2589,10 @@ TestBatchedDot = makeTester(
op=batched_dot,
expected=(
lambda xs, ys: np.asarray(
list(
[
x * y if x.ndim == 0 or y.ndim == 0 else np.dot(x, y)
for x, y in zip(xs, ys)
),
],
dtype=aes.upcast(xs.dtype, ys.dtype),
)
),
......@@ -2694,7 +2694,7 @@ def test_batched_dot_not_contiguous():
assert x.strides[0] == direction * np.dtype(config.floatX).itemsize
assert not (x.flags["C_CONTIGUOUS"] or x.flags["F_CONTIGUOUS"])
result = f(x, w)
ref_result = np.asarray(list(np.dot(u, v) for u, v in zip(x, w)))
ref_result = np.asarray([np.dot(u, v) for u, v in zip(x, w)])
utt.assert_allclose(ref_result, result)
for inverted in (0, 1):
......
......@@ -15,9 +15,7 @@ class TestRealImag:
x = zvector()
rng = np.random.default_rng(23)
xval = np.asarray(
list(
complex(rng.standard_normal(), rng.standard_normal()) for i in range(10)
)
[complex(rng.standard_normal(), rng.standard_normal()) for i in range(10)]
)
assert np.all(xval.real == pytensor.function([x], real(x))(xval))
assert np.all(xval.imag == pytensor.function([x], imag(x))(xval))
......
......@@ -490,50 +490,50 @@ class TestCAReduce(unittest_tools.InferShapeTester):
assert len(axis2) == len(tosum)
tosum = tuple(axis2)
if tensor_op == at_all:
for axis in reversed(sorted(tosum)):
for axis in sorted(tosum, reverse=True):
zv = np.all(zv, axis)
if len(tosum) == 0:
zv = zv != 0
elif tensor_op == at_any:
for axis in reversed(sorted(tosum)):
for axis in sorted(tosum, reverse=True):
zv = np.any(zv, axis)
if len(tosum) == 0:
zv = zv != 0
elif scalar_op == aes.add:
for axis in reversed(sorted(tosum)):
for axis in sorted(tosum, reverse=True):
zv = np.add.reduce(zv, axis)
if dtype == "bool":
# np.add of a bool upcast, while CAReduce don't
zv = zv.astype(dtype)
elif scalar_op == aes.mul:
for axis in reversed(sorted(tosum)):
for axis in sorted(tosum, reverse=True):
zv = np.multiply.reduce(zv, axis)
elif scalar_op == aes.scalar_maximum:
# There is no identity value for the maximum function
# So we can't support shape of dimensions 0.
if np.prod(zv.shape) == 0:
continue
for axis in reversed(sorted(tosum)):
for axis in sorted(tosum, reverse=True):
zv = np.maximum.reduce(zv, axis)
elif scalar_op == aes.scalar_minimum:
# There is no identity value for the minimum function
# So we can't support shape of dimensions 0.
if np.prod(zv.shape) == 0:
continue
for axis in reversed(sorted(tosum)):
for axis in sorted(tosum, reverse=True):
zv = np.minimum.reduce(zv, axis)
elif scalar_op == aes.or_:
for axis in reversed(sorted(tosum)):
for axis in sorted(tosum, reverse=True):
zv = np.bitwise_or.reduce(zv, axis)
elif scalar_op == aes.and_:
for axis in reversed(sorted(tosum)):
for axis in sorted(tosum, reverse=True):
zv = reduce_bitwise_and(zv, axis, dtype=dtype)
elif scalar_op == aes.xor:
# There is no identity value for the xor function
# So we can't support shape of dimensions 0.
if np.prod(zv.shape) == 0:
continue
for axis in reversed(sorted(tosum)):
for axis in sorted(tosum, reverse=True):
zv = np.bitwise_xor.reduce(zv, axis)
else:
raise NotImplementedError(
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论