提交 861816d9 · 作者:Virgile Andreani · 提交者:Virgile Andreani

Apply rule RUF005

上级 b6316e83
......@@ -324,7 +324,7 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
batch_sampling_keys = jax.random.split(rng_key, np.prod(size))
# Ravel the batch dimensions because vmap only works along a single axis
raveled_batch_a = a.reshape((-1,) + a.shape[batch_ndim:])
raveled_batch_a = a.reshape((-1, *a.shape[batch_ndim:]))
if p is None:
raveled_sample = jax.vmap(
lambda key, a: jax.random.choice(
......@@ -332,7 +332,7 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
)
)(batch_sampling_keys, raveled_batch_a)
else:
raveled_batch_p = p.reshape((-1,) + p.shape[batch_ndim:])
raveled_batch_p = p.reshape((-1, *p.shape[batch_ndim:]))
raveled_sample = jax.vmap(
lambda key, a, p: jax.random.choice(
key, a, shape=core_shape, replace=False, p=p
......@@ -363,7 +363,7 @@ def jax_sample_fn_permutation(op, node):
x = jax.numpy.broadcast_to(x, size + x.shape[batch_ndim:])
batch_sampling_keys = jax.random.split(rng_key, np.prod(size))
raveled_batch_x = x.reshape((-1,) + x.shape[batch_ndim:])
raveled_batch_x = x.reshape((-1, *x.shape[batch_ndim:]))
raveled_sample = jax.vmap(lambda key, x: jax.random.permutation(key, x))(
batch_sampling_keys, raveled_batch_x
)
......
......@@ -2104,7 +2104,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# are read and written.
# This way, there will be no information overwritten
# before it is read (as it used to happen).
shape = (pdx,) + output_storage[idx][0].shape[1:]
shape = (pdx, *output_storage[idx][0].shape[1:])
tmp = np.empty(shape, dtype=node.outputs[idx].type.dtype)
tmp[:] = output_storage[idx][0][:pdx]
output_storage[idx][0][: store_steps[idx] - pdx] = output_storage[
......@@ -2113,7 +2113,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
output_storage[idx][0][store_steps[idx] - pdx :] = tmp
del tmp
else:
shape = (store_steps[idx] - pdx,) + output_storage[idx][0].shape[1:]
shape = (store_steps[idx] - pdx, *output_storage[idx][0].shape[1:])
tmp = np.empty(shape, dtype=node.outputs[idx].type.dtype)
tmp[:] = output_storage[idx][0][pdx:]
output_storage[idx][0][store_steps[idx] - pdx :] = output_storage[
......@@ -2304,7 +2304,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if x is None:
scan_outs.append(None)
else:
scan_outs.append((Shape_i(0)(o),) + x[1:])
scan_outs.append((Shape_i(0)(o), *x[1:]))
return scan_outs
def connection_pattern(self, node):
......
......@@ -4051,8 +4051,8 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
# Re-order axes so they correspond to diagonals at axis1, axis2
axes = list(range(diag.type.ndim - 1))
last_idx = axes[-1]
axes = axes[:axis1] + [last_idx + 1] + axes[axis1:]
axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
axes = [*axes[:axis1], last_idx + 1, *axes[axis1:]]
axes = [*axes[:axis2], last_idx + 2, *axes[axis2:]]
result = result.transpose(axes)
return AllocDiag(
......@@ -4525,7 +4525,7 @@ def _make_along_axis_idx(arr_shape, indices, axis):
if dim is None:
fancy_index.append(indices)
else:
ind_shape = shape_ones[:dim] + (-1,) + shape_ones[dim + 1 :]
ind_shape = (*shape_ones[:dim], -1, *shape_ones[dim + 1 :])
fancy_index.append(arange(n).reshape(ind_shape))
return tuple(fancy_index)
......
......@@ -244,7 +244,7 @@ def get_conv_gradweights_shape(
for i in range(len(subsample))
)
if unshared:
return (nchan,) + top_shape[2:] + (nkern,) + out_shp
return (nchan, *top_shape[2:], nkern, *out_shp)
else:
return (nchan, nkern, *out_shp)
......@@ -2906,9 +2906,9 @@ class AbstractConv_gradWeights(BaseAbstractConv):
def correct_for_groups(mat):
mshp0 = mat.shape[0] // self.num_groups
mshp1 = mat.shape[1] * self.num_groups
mat = mat.reshape((self.num_groups, mshp0) + mat.shape[1:])
mat = mat.reshape((self.num_groups, mshp0, *mat.shape[1:]))
mat = mat.transpose((1, 0, 2, *range(3, 3 + self.convdim)))
mat = mat.reshape((mshp0, mshp1) + mat.shape[-self.convdim :])
mat = mat.reshape((mshp0, mshp1, *mat.shape[-self.convdim :]))
return mat
if self.num_groups > 1:
......@@ -3283,7 +3283,7 @@ class AbstractConv_gradInputs(BaseAbstractConv):
def correct_for_groups(mat):
mshp0 = mat.shape[0] // self.num_groups
mshp1 = mat.shape[-self.convdim - 1] * self.num_groups
mat = mat.reshape((self.num_groups, mshp0) + mat.shape[1:])
mat = mat.reshape((self.num_groups, mshp0, *mat.shape[1:]))
if self.unshared:
# for 2D -> (1, 2, 3, 0, 4, 5, 6)
mat = mat.transpose(
......@@ -3294,14 +3294,16 @@ class AbstractConv_gradInputs(BaseAbstractConv):
)
)
mat = mat.reshape(
(mshp0,)
+ mat.shape[1 : 1 + self.convdim]
+ (mshp1,)
+ mat.shape[-self.convdim :]
(
mshp0,
*mat.shape[1 : 1 + self.convdim],
mshp1,
*mat.shape[-self.convdim :],
)
)
else:
mat = mat.transpose((1, 0, 2, *range(3, 3 + self.convdim)))
mat = mat.reshape((mshp0, mshp1) + mat.shape[-self.convdim :])
mat = mat.reshape((mshp0, mshp1, *mat.shape[-self.convdim :]))
return mat
kern = correct_for_groups(kern)
......
......@@ -563,7 +563,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
assert y_clone != y
y_clone.name = "y_clone"
out_new = test_ofg.make_node(*(out.owner.inputs[:1] + [y_clone])).outputs[0]
out_new = test_ofg.make_node(*([*out.owner.inputs[:1], y_clone])).outputs[0]
assert "on_unused_input" in out_new.owner.op.kwargs
assert out_new.owner.op.shared_inputs == [y_clone]
......
......@@ -144,12 +144,12 @@ class multiple_outputs_numeric_grad:
t = t.flatten()
t[pos] += _eps
t = t.reshape(pt[i].shape)
f_eps = f(*(pt[:i] + [t] + pt[i + 1 :]))
f_eps = f(*([*pt[:i], t, *pt[i + 1 :]]))
_g.append(np.asarray((f_eps - f_x) / _eps))
gx.append(np.asarray(_g).reshape(pt[i].shape))
else:
t = np.array(pt[i] + _eps)
f_eps = f(*(pt[:i] + [t] + pt[i + 1 :]))
f_eps = f(*([*pt[:i], t, *pt[i + 1 :]]))
gx.append(np.asarray((f_eps - f_x) / _eps))
self.gx = gx
......
......@@ -266,7 +266,7 @@ class TestConvGradInputsShape:
computed_image_shape = get_conv_gradinputs_shape(
kernel_shape, output_shape, b, (2, 3), (d, d)
)
image_shape_with_None = image_shape[:2] + (None, None)
image_shape_with_None = (*image_shape[:2], None, None)
assert computed_image_shape == image_shape_with_None
# compute the kernel_shape given this output_shape
......@@ -276,7 +276,7 @@ class TestConvGradInputsShape:
# if border_mode == 'half', the shape should be None
if b == "half":
kernel_shape_with_None = kernel_shape[:2] + (None, None)
kernel_shape_with_None = (*kernel_shape[:2], None, None)
assert computed_kernel_shape == kernel_shape_with_None
else:
assert computed_kernel_shape == kernel_shape
......@@ -285,7 +285,7 @@ class TestConvGradInputsShape:
computed_kernel_shape = get_conv_gradweights_shape(
kernel_shape, output_shape, b, (2, 3), (d, d)
)
kernel_shape_with_None = kernel_shape[:2] + (None, None)
kernel_shape_with_None = (*kernel_shape[:2], None, None)
assert computed_kernel_shape == kernel_shape_with_None
......
......@@ -1019,7 +1019,7 @@ class TestRavelMultiIndex(utt.InferShapeTester):
)
# create some invalid indices to test the mode
if mode in ("wrap", "clip"):
multi_index = (multi_index[0] - 1,) + multi_index[1:]
multi_index = (multi_index[0] - 1, *multi_index[1:])
# test with scalars and higher-dimensional indices
if index_ndim == 0:
multi_index = tuple(i[-1] for i in multi_index)
......
......@@ -1272,7 +1272,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
if len(inc_shape) == len(data_shape) and (
len(inc_shapes) == 0 or inc_shape[0] != 1
):
inc_shape = (n_to_inc,) + inc_shape[1:]
inc_shape = (n_to_inc, *inc_shape[1:])
# Symbolic variable with increment value.
inc_var_static_shape = tuple(
......@@ -2822,15 +2822,15 @@ test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True
(np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), test_idx[:2]),
(
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
test_idx[:2] + (slice(None, None),),
(*test_idx[:2], slice(None, None)),
),
(
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
(slice(None, None),) + test_idx[:1],
(slice(None, None), *test_idx[:1]),
),
(
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
(slice(None, None), None) + test_idx[1:2],
(slice(None, None), None, *test_idx[1:2]),
),
(
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
......@@ -2842,15 +2842,15 @@ test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True
),
(
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
test_idx[:1] + (slice(None, None),) + test_idx[1:2],
(*test_idx[:1], slice(None, None), *test_idx[1:2]),
),
(
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
test_idx[:1] + (slice(None, None),) + test_idx[1:2] + (slice(None, None),),
(*test_idx[:1], slice(None, None), *test_idx[1:2], slice(None, None)),
),
(
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
test_idx[:1] + (None,) + test_idx[1:2],
(*test_idx[:1], None, *test_idx[1:2]),
),
(np.arange(np.prod((5, 4))).reshape((5, 4)), ([1, 3, 2], slice(1, 3))),
(np.arange(np.prod((5, 4))).reshape((5, 4)), (slice(1, 3), [1, 3, 2])),
......
Markdown 格式
0%
您将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论