提交 861816d9 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Virgile Andreani

Apply rule RUF005

上级 b6316e83
...@@ -324,7 +324,7 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node): ...@@ -324,7 +324,7 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
batch_sampling_keys = jax.random.split(rng_key, np.prod(size)) batch_sampling_keys = jax.random.split(rng_key, np.prod(size))
# Ravel the batch dimensions because vmap only works along a single axis # Ravel the batch dimensions because vmap only works along a single axis
raveled_batch_a = a.reshape((-1,) + a.shape[batch_ndim:]) raveled_batch_a = a.reshape((-1, *a.shape[batch_ndim:]))
if p is None: if p is None:
raveled_sample = jax.vmap( raveled_sample = jax.vmap(
lambda key, a: jax.random.choice( lambda key, a: jax.random.choice(
...@@ -332,7 +332,7 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node): ...@@ -332,7 +332,7 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
) )
)(batch_sampling_keys, raveled_batch_a) )(batch_sampling_keys, raveled_batch_a)
else: else:
raveled_batch_p = p.reshape((-1,) + p.shape[batch_ndim:]) raveled_batch_p = p.reshape((-1, *p.shape[batch_ndim:]))
raveled_sample = jax.vmap( raveled_sample = jax.vmap(
lambda key, a, p: jax.random.choice( lambda key, a, p: jax.random.choice(
key, a, shape=core_shape, replace=False, p=p key, a, shape=core_shape, replace=False, p=p
...@@ -363,7 +363,7 @@ def jax_sample_fn_permutation(op, node): ...@@ -363,7 +363,7 @@ def jax_sample_fn_permutation(op, node):
x = jax.numpy.broadcast_to(x, size + x.shape[batch_ndim:]) x = jax.numpy.broadcast_to(x, size + x.shape[batch_ndim:])
batch_sampling_keys = jax.random.split(rng_key, np.prod(size)) batch_sampling_keys = jax.random.split(rng_key, np.prod(size))
raveled_batch_x = x.reshape((-1,) + x.shape[batch_ndim:]) raveled_batch_x = x.reshape((-1, *x.shape[batch_ndim:]))
raveled_sample = jax.vmap(lambda key, x: jax.random.permutation(key, x))( raveled_sample = jax.vmap(lambda key, x: jax.random.permutation(key, x))(
batch_sampling_keys, raveled_batch_x batch_sampling_keys, raveled_batch_x
) )
......
...@@ -2104,7 +2104,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2104,7 +2104,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# are read and written. # are read and written.
# This way, there will be no information overwritten # This way, there will be no information overwritten
# before it is read (as it used to happen). # before it is read (as it used to happen).
shape = (pdx,) + output_storage[idx][0].shape[1:] shape = (pdx, *output_storage[idx][0].shape[1:])
tmp = np.empty(shape, dtype=node.outputs[idx].type.dtype) tmp = np.empty(shape, dtype=node.outputs[idx].type.dtype)
tmp[:] = output_storage[idx][0][:pdx] tmp[:] = output_storage[idx][0][:pdx]
output_storage[idx][0][: store_steps[idx] - pdx] = output_storage[ output_storage[idx][0][: store_steps[idx] - pdx] = output_storage[
...@@ -2113,7 +2113,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2113,7 +2113,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
output_storage[idx][0][store_steps[idx] - pdx :] = tmp output_storage[idx][0][store_steps[idx] - pdx :] = tmp
del tmp del tmp
else: else:
shape = (store_steps[idx] - pdx,) + output_storage[idx][0].shape[1:] shape = (store_steps[idx] - pdx, *output_storage[idx][0].shape[1:])
tmp = np.empty(shape, dtype=node.outputs[idx].type.dtype) tmp = np.empty(shape, dtype=node.outputs[idx].type.dtype)
tmp[:] = output_storage[idx][0][pdx:] tmp[:] = output_storage[idx][0][pdx:]
output_storage[idx][0][store_steps[idx] - pdx :] = output_storage[ output_storage[idx][0][store_steps[idx] - pdx :] = output_storage[
...@@ -2304,7 +2304,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2304,7 +2304,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if x is None: if x is None:
scan_outs.append(None) scan_outs.append(None)
else: else:
scan_outs.append((Shape_i(0)(o),) + x[1:]) scan_outs.append((Shape_i(0)(o), *x[1:]))
return scan_outs return scan_outs
def connection_pattern(self, node): def connection_pattern(self, node):
......
...@@ -4051,8 +4051,8 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1): ...@@ -4051,8 +4051,8 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
# Re-order axes so they correspond to diagonals at axis1, axis2 # Re-order axes so they correspond to diagonals at axis1, axis2
axes = list(range(diag.type.ndim - 1)) axes = list(range(diag.type.ndim - 1))
last_idx = axes[-1] last_idx = axes[-1]
axes = axes[:axis1] + [last_idx + 1] + axes[axis1:] axes = [*axes[:axis1], last_idx + 1, *axes[axis1:]]
axes = axes[:axis2] + [last_idx + 2] + axes[axis2:] axes = [*axes[:axis2], last_idx + 2, *axes[axis2:]]
result = result.transpose(axes) result = result.transpose(axes)
return AllocDiag( return AllocDiag(
...@@ -4525,7 +4525,7 @@ def _make_along_axis_idx(arr_shape, indices, axis): ...@@ -4525,7 +4525,7 @@ def _make_along_axis_idx(arr_shape, indices, axis):
if dim is None: if dim is None:
fancy_index.append(indices) fancy_index.append(indices)
else: else:
ind_shape = shape_ones[:dim] + (-1,) + shape_ones[dim + 1 :] ind_shape = (*shape_ones[:dim], -1, *shape_ones[dim + 1 :])
fancy_index.append(arange(n).reshape(ind_shape)) fancy_index.append(arange(n).reshape(ind_shape))
return tuple(fancy_index) return tuple(fancy_index)
......
...@@ -244,7 +244,7 @@ def get_conv_gradweights_shape( ...@@ -244,7 +244,7 @@ def get_conv_gradweights_shape(
for i in range(len(subsample)) for i in range(len(subsample))
) )
if unshared: if unshared:
return (nchan,) + top_shape[2:] + (nkern,) + out_shp return (nchan, *top_shape[2:], nkern, *out_shp)
else: else:
return (nchan, nkern, *out_shp) return (nchan, nkern, *out_shp)
...@@ -2906,9 +2906,9 @@ class AbstractConv_gradWeights(BaseAbstractConv): ...@@ -2906,9 +2906,9 @@ class AbstractConv_gradWeights(BaseAbstractConv):
def correct_for_groups(mat): def correct_for_groups(mat):
mshp0 = mat.shape[0] // self.num_groups mshp0 = mat.shape[0] // self.num_groups
mshp1 = mat.shape[1] * self.num_groups mshp1 = mat.shape[1] * self.num_groups
mat = mat.reshape((self.num_groups, mshp0) + mat.shape[1:]) mat = mat.reshape((self.num_groups, mshp0, *mat.shape[1:]))
mat = mat.transpose((1, 0, 2, *range(3, 3 + self.convdim))) mat = mat.transpose((1, 0, 2, *range(3, 3 + self.convdim)))
mat = mat.reshape((mshp0, mshp1) + mat.shape[-self.convdim :]) mat = mat.reshape((mshp0, mshp1, *mat.shape[-self.convdim :]))
return mat return mat
if self.num_groups > 1: if self.num_groups > 1:
...@@ -3283,7 +3283,7 @@ class AbstractConv_gradInputs(BaseAbstractConv): ...@@ -3283,7 +3283,7 @@ class AbstractConv_gradInputs(BaseAbstractConv):
def correct_for_groups(mat): def correct_for_groups(mat):
mshp0 = mat.shape[0] // self.num_groups mshp0 = mat.shape[0] // self.num_groups
mshp1 = mat.shape[-self.convdim - 1] * self.num_groups mshp1 = mat.shape[-self.convdim - 1] * self.num_groups
mat = mat.reshape((self.num_groups, mshp0) + mat.shape[1:]) mat = mat.reshape((self.num_groups, mshp0, *mat.shape[1:]))
if self.unshared: if self.unshared:
# for 2D -> (1, 2, 3, 0, 4, 5, 6) # for 2D -> (1, 2, 3, 0, 4, 5, 6)
mat = mat.transpose( mat = mat.transpose(
...@@ -3294,14 +3294,16 @@ class AbstractConv_gradInputs(BaseAbstractConv): ...@@ -3294,14 +3294,16 @@ class AbstractConv_gradInputs(BaseAbstractConv):
) )
) )
mat = mat.reshape( mat = mat.reshape(
(mshp0,) (
+ mat.shape[1 : 1 + self.convdim] mshp0,
+ (mshp1,) *mat.shape[1 : 1 + self.convdim],
+ mat.shape[-self.convdim :] mshp1,
*mat.shape[-self.convdim :],
)
) )
else: else:
mat = mat.transpose((1, 0, 2, *range(3, 3 + self.convdim))) mat = mat.transpose((1, 0, 2, *range(3, 3 + self.convdim)))
mat = mat.reshape((mshp0, mshp1) + mat.shape[-self.convdim :]) mat = mat.reshape((mshp0, mshp1, *mat.shape[-self.convdim :]))
return mat return mat
kern = correct_for_groups(kern) kern = correct_for_groups(kern)
......
...@@ -563,7 +563,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -563,7 +563,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
assert y_clone != y assert y_clone != y
y_clone.name = "y_clone" y_clone.name = "y_clone"
out_new = test_ofg.make_node(*(out.owner.inputs[:1] + [y_clone])).outputs[0] out_new = test_ofg.make_node(*([*out.owner.inputs[:1], y_clone])).outputs[0]
assert "on_unused_input" in out_new.owner.op.kwargs assert "on_unused_input" in out_new.owner.op.kwargs
assert out_new.owner.op.shared_inputs == [y_clone] assert out_new.owner.op.shared_inputs == [y_clone]
......
...@@ -144,12 +144,12 @@ class multiple_outputs_numeric_grad: ...@@ -144,12 +144,12 @@ class multiple_outputs_numeric_grad:
t = t.flatten() t = t.flatten()
t[pos] += _eps t[pos] += _eps
t = t.reshape(pt[i].shape) t = t.reshape(pt[i].shape)
f_eps = f(*(pt[:i] + [t] + pt[i + 1 :])) f_eps = f(*([*pt[:i], t, *pt[i + 1 :]]))
_g.append(np.asarray((f_eps - f_x) / _eps)) _g.append(np.asarray((f_eps - f_x) / _eps))
gx.append(np.asarray(_g).reshape(pt[i].shape)) gx.append(np.asarray(_g).reshape(pt[i].shape))
else: else:
t = np.array(pt[i] + _eps) t = np.array(pt[i] + _eps)
f_eps = f(*(pt[:i] + [t] + pt[i + 1 :])) f_eps = f(*([*pt[:i], t, *pt[i + 1 :]]))
gx.append(np.asarray((f_eps - f_x) / _eps)) gx.append(np.asarray((f_eps - f_x) / _eps))
self.gx = gx self.gx = gx
......
...@@ -266,7 +266,7 @@ class TestConvGradInputsShape: ...@@ -266,7 +266,7 @@ class TestConvGradInputsShape:
computed_image_shape = get_conv_gradinputs_shape( computed_image_shape = get_conv_gradinputs_shape(
kernel_shape, output_shape, b, (2, 3), (d, d) kernel_shape, output_shape, b, (2, 3), (d, d)
) )
image_shape_with_None = image_shape[:2] + (None, None) image_shape_with_None = (*image_shape[:2], None, None)
assert computed_image_shape == image_shape_with_None assert computed_image_shape == image_shape_with_None
# compute the kernel_shape given this output_shape # compute the kernel_shape given this output_shape
...@@ -276,7 +276,7 @@ class TestConvGradInputsShape: ...@@ -276,7 +276,7 @@ class TestConvGradInputsShape:
# if border_mode == 'half', the shape should be None # if border_mode == 'half', the shape should be None
if b == "half": if b == "half":
kernel_shape_with_None = kernel_shape[:2] + (None, None) kernel_shape_with_None = (*kernel_shape[:2], None, None)
assert computed_kernel_shape == kernel_shape_with_None assert computed_kernel_shape == kernel_shape_with_None
else: else:
assert computed_kernel_shape == kernel_shape assert computed_kernel_shape == kernel_shape
...@@ -285,7 +285,7 @@ class TestConvGradInputsShape: ...@@ -285,7 +285,7 @@ class TestConvGradInputsShape:
computed_kernel_shape = get_conv_gradweights_shape( computed_kernel_shape = get_conv_gradweights_shape(
kernel_shape, output_shape, b, (2, 3), (d, d) kernel_shape, output_shape, b, (2, 3), (d, d)
) )
kernel_shape_with_None = kernel_shape[:2] + (None, None) kernel_shape_with_None = (*kernel_shape[:2], None, None)
assert computed_kernel_shape == kernel_shape_with_None assert computed_kernel_shape == kernel_shape_with_None
......
...@@ -1019,7 +1019,7 @@ class TestRavelMultiIndex(utt.InferShapeTester): ...@@ -1019,7 +1019,7 @@ class TestRavelMultiIndex(utt.InferShapeTester):
) )
# create some invalid indices to test the mode # create some invalid indices to test the mode
if mode in ("wrap", "clip"): if mode in ("wrap", "clip"):
multi_index = (multi_index[0] - 1,) + multi_index[1:] multi_index = (multi_index[0] - 1, *multi_index[1:])
# test with scalars and higher-dimensional indices # test with scalars and higher-dimensional indices
if index_ndim == 0: if index_ndim == 0:
multi_index = tuple(i[-1] for i in multi_index) multi_index = tuple(i[-1] for i in multi_index)
......
...@@ -1272,7 +1272,7 @@ class TestSubtensor(utt.OptimizationTestMixin): ...@@ -1272,7 +1272,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
if len(inc_shape) == len(data_shape) and ( if len(inc_shape) == len(data_shape) and (
len(inc_shapes) == 0 or inc_shape[0] != 1 len(inc_shapes) == 0 or inc_shape[0] != 1
): ):
inc_shape = (n_to_inc,) + inc_shape[1:] inc_shape = (n_to_inc, *inc_shape[1:])
# Symbolic variable with increment value. # Symbolic variable with increment value.
inc_var_static_shape = tuple( inc_var_static_shape = tuple(
...@@ -2822,15 +2822,15 @@ test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True ...@@ -2822,15 +2822,15 @@ test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True
(np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), test_idx[:2]), (np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), test_idx[:2]),
( (
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
test_idx[:2] + (slice(None, None),), (*test_idx[:2], slice(None, None)),
), ),
( (
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
(slice(None, None),) + test_idx[:1], (slice(None, None), *test_idx[:1]),
), ),
( (
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
(slice(None, None), None) + test_idx[1:2], (slice(None, None), None, *test_idx[1:2]),
), ),
( (
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
...@@ -2842,15 +2842,15 @@ test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True ...@@ -2842,15 +2842,15 @@ test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True
), ),
( (
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
test_idx[:1] + (slice(None, None),) + test_idx[1:2], (*test_idx[:1], slice(None, None), *test_idx[1:2]),
), ),
( (
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
test_idx[:1] + (slice(None, None),) + test_idx[1:2] + (slice(None, None),), (*test_idx[:1], slice(None, None), *test_idx[1:2], slice(None, None)),
), ),
( (
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
test_idx[:1] + (None,) + test_idx[1:2], (*test_idx[:1], None, *test_idx[1:2]),
), ),
(np.arange(np.prod((5, 4))).reshape((5, 4)), ([1, 3, 2], slice(1, 3))), (np.arange(np.prod((5, 4))).reshape((5, 4)), ([1, 3, 2], slice(1, 3))),
(np.arange(np.prod((5, 4))).reshape((5, 4)), (slice(1, 3), [1, 3, 2])), (np.arange(np.prod((5, 4))).reshape((5, 4)), (slice(1, 3), [1, 3, 2])),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论