Commit 10105bea authored by ricardoV94 (committed by Ricardo Vieira)

Don't specify zip strict kwarg in hot loops

It seems that passing the `strict` keyword adds a non-trivial ~100ns of overhead per `zip` call, which matters in hot loops.
Parent 5335a680
......@@ -130,7 +130,7 @@ exclude = ["doc/", "pytensor/_version.py"]
docstring-code-format = true
[tool.ruff.lint]
select = ["B905", "C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20", "NPY201"]
select = ["C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20", "NPY201"]
ignore = ["C408", "C901", "E501", "E741", "RUF012", "PERF203", "ISC001"]
unfixable = [
# zip-strict: the auto-fix adds `strict=False` but we might want `strict=True` instead
......
......@@ -873,7 +873,6 @@ class OpFromGraph(Op, HasInnerGraph):
def perform(self, node, inputs, outputs):
variables = self.fn(*inputs)
assert len(variables) == len(outputs)
# strict=False because asserted above
for output, variable in zip(outputs, variables, strict=False):
# zip strict not specified because we are in a hot loop
for output, variable in zip(outputs, variables):
output[0] = variable
......@@ -924,7 +924,8 @@ class Function:
# Reinitialize each container's 'provided' counter
if trust_input:
for arg_container, arg in zip(input_storage, args, strict=False):
# zip strict not specified because we are in a hot loop
for arg_container, arg in zip(input_storage, args):
arg_container.storage[0] = arg
else:
for arg_container in input_storage:
......@@ -934,7 +935,8 @@ class Function:
raise TypeError("Too many parameter passed to pytensor function")
# Set positional arguments
for arg_container, arg in zip(input_storage, args, strict=False):
# zip strict not specified because we are in a hot loop
for arg_container, arg in zip(input_storage, args):
# See discussion about None as input
# https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
if arg is None:
......
......@@ -305,8 +305,8 @@ class IfElse(_NoPythonOp):
if len(ls) > 0:
return ls
else:
# strict=False because we are in a hot loop
for out, t in zip(outputs, input_true_branch, strict=False):
# zip strict not specified because we are in a hot loop
for out, t in zip(outputs, input_true_branch):
compute_map[out][0] = 1
val = storage_map[t][0]
if self.as_view:
......@@ -326,8 +326,8 @@ class IfElse(_NoPythonOp):
if len(ls) > 0:
return ls
else:
# strict=False because we are in a hot loop
for out, f in zip(outputs, inputs_false_branch, strict=False):
# zip strict not specified because we are in a hot loop
for out, f in zip(outputs, inputs_false_branch):
compute_map[out][0] = 1
# can't view both outputs unless destroyhandler
# improves
......
......@@ -539,14 +539,14 @@ class WrapLinker(Linker):
def f():
for inputs in input_lists[1:]:
# strict=False because we are in a hot loop
for input1, input2 in zip(inputs0, inputs, strict=False):
# zip strict not specified because we are in a hot loop
for input1, input2 in zip(inputs0, inputs):
input2.storage[0] = copy(input1.storage[0])
for x in to_reset:
x[0] = None
pre(self, [input.data for input in input_lists[0]], order, thunk_groups)
# strict=False because we are in a hot loop
for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=False)):
# zip strict not specified because we are in a hot loop
for i, (thunks, node) in enumerate(zip(thunk_groups, order)):
try:
wrapper(self.fgraph, i, node, *thunks)
except Exception:
......@@ -668,8 +668,8 @@ class JITLinker(PerformLinker):
# since the error may come from any of them?
raise_with_op(self.fgraph, output_nodes[0], thunk)
# strict=False because we are in a hot loop
for o_storage, o_val in zip(thunk_outputs, outputs, strict=False):
# zip strict not specified because we are in a hot loop
for o_storage, o_val in zip(thunk_outputs, outputs):
o_storage[0] = o_val
thunk.inputs = thunk_inputs
......
......@@ -1988,27 +1988,23 @@ class DualLinker(Linker):
)
def f():
# strict=False because we are in a hot loop
for input1, input2 in zip(i1, i2, strict=False):
# zip strict not specified because we are in a hot loop
for input1, input2 in zip(i1, i2):
# Set the inputs to be the same in both branches.
# The copy is necessary in order for inplace ops not to
# interfere.
input2.storage[0] = copy(input1.storage[0])
for thunk1, thunk2, node1, node2 in zip(
thunks1, thunks2, order1, order2, strict=False
):
for output, storage in zip(node1.outputs, thunk1.outputs, strict=False):
for thunk1, thunk2, node1, node2 in zip(thunks1, thunks2, order1, order2):
for output, storage in zip(node1.outputs, thunk1.outputs):
if output in no_recycling:
storage[0] = None
for output, storage in zip(node2.outputs, thunk2.outputs, strict=False):
for output, storage in zip(node2.outputs, thunk2.outputs):
if output in no_recycling:
storage[0] = None
try:
thunk1()
thunk2()
for output1, output2 in zip(
thunk1.outputs, thunk2.outputs, strict=False
):
for output1, output2 in zip(thunk1.outputs, thunk2.outputs):
self.checker(output1, output2)
except Exception:
raise_with_op(fgraph, node1)
......
......@@ -312,10 +312,10 @@ def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
else:
def py_perform_return(inputs):
# strict=False because we are in a hot loop
# zip strict not specified because we are in a hot loop
return tuple(
out_type.filter(out[0])
for out_type, out in zip(output_types, py_perform(inputs), strict=False)
for out_type, out in zip(output_types, py_perform(inputs))
)
@numba_njit
......
......@@ -166,10 +166,7 @@ class _CythonWrapper(numba.types.WrapperAddressProtocol):
def __call__(self, *args, **kwargs):
# no strict argument because of the JIT
# TODO: check
args = [
dtype(arg)
for arg, dtype in zip(args, self._signature.arg_dtypes) # noqa: B905
]
args = [dtype(arg) for arg, dtype in zip(args, self._signature.arg_dtypes)]
if self.has_pyx_skip_dispatch():
output = self._pyfunc(*args[:-1], **kwargs)
else:
......
......@@ -186,7 +186,7 @@ def numba_funcify_RavelMultiIndex(op, node, **kwargs):
new_arr = arr.T.astype(np.float64).copy()
for i, b in enumerate(new_arr):
# no strict argument to this zip because numba doesn't support it
for j, (d, v) in enumerate(zip(shape, b)): # noqa: B905
for j, (d, v) in enumerate(zip(shape, b)):
if v < 0 or v >= d:
mode_fn(new_arr, i, j, v, d)
......
......@@ -183,7 +183,7 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
r, c = 0, 0
# no strict argument because it is incompatible with numba
for arr, shape in zip(arrs, shapes): # noqa: B905
for arr, shape in zip(arrs, shapes):
rr, cc = shape
out[r : r + rr, c : c + cc] = arr
r += rr
......
......@@ -219,7 +219,7 @@ def numba_funcify_multiple_integer_vector_indexing(
shape_aft = x_shape[after_last_axis:]
out_shape = (*shape_bef, *idx_shape, *shape_aft)
out_buffer = np.empty(out_shape, dtype=x.dtype)
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
for i, scalar_idxs in enumerate(zip(*vec_idxs)):
out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)]
return out_buffer
......@@ -253,7 +253,7 @@ def numba_funcify_multiple_integer_vector_indexing(
y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])
for outer in np.ndindex(x_shape[:first_axis]):
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
for i, scalar_idxs in enumerate(zip(*vec_idxs)):
out[(*outer, *scalar_idxs)] = y[(*outer, i)]
return out
......@@ -275,7 +275,7 @@ def numba_funcify_multiple_integer_vector_indexing(
y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])
for outer in np.ndindex(x_shape[:first_axis]):
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
for i, scalar_idxs in enumerate(zip(*vec_idxs)):
out[(*outer, *scalar_idxs)] += y[(*outer, i)]
return out
......@@ -314,7 +314,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
if not len(idxs) == len(vals):
raise ValueError("The number of indices and values must match.")
# no strict argument because incompatible with numba
for idx, val in zip(idxs, vals): # noqa: B905
for idx, val in zip(idxs, vals):
x[idx] = val
return x
else:
......@@ -342,7 +342,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
raise ValueError("The number of indices and values must match.")
# no strict argument because unsupported by numba
# TODO: this doesn't come up in tests
for idx, val in zip(idxs, vals): # noqa: B905
for idx, val in zip(idxs, vals):
x[idx] += val
return x
......
......@@ -207,8 +207,8 @@ def streamline(
for x in no_recycling:
x[0] = None
try:
# strict=False because we are in a hot loop
for thunk, node in zip(thunks, order, strict=False):
# zip strict not specified because we are in a hot loop
for thunk, node in zip(thunks, order):
thunk()
except Exception:
raise_with_op(fgraph, node, thunk)
......
......@@ -4416,8 +4416,8 @@ class Composite(ScalarInnerGraphOp):
def perform(self, node, inputs, output_storage):
outputs = self.py_perform_fn(*inputs)
# strict=False because we are in a hot loop
for storage, out_val in zip(output_storage, outputs, strict=False):
# zip strict not specified because we are in a hot loop
for storage, out_val in zip(output_storage, outputs):
storage[0] = out_val
def grad(self, inputs, output_grads):
......
......@@ -196,8 +196,8 @@ class ScalarLoop(ScalarInnerGraphOp):
for i in range(n_steps):
carry = inner_fn(*carry, *constant)
# strict=False because we are in a hot loop
for storage, out_val in zip(output_storage, carry, strict=False):
# zip strict not specified because we are in a hot loop
for storage, out_val in zip(output_storage, carry):
storage[0] = out_val
@property
......
......@@ -3589,8 +3589,8 @@ class PermuteRowElements(Op):
# Make sure the output is big enough
out_s = []
# strict=False because we are in a hot loop
for xdim, ydim in zip(x_s, y_s, strict=False):
# zip strict not specified because we are in a hot loop
for xdim, ydim in zip(x_s, y_s):
if xdim == ydim:
outdim = xdim
elif xdim == 1:
......
......@@ -712,9 +712,9 @@ class Elemwise(OpenMPOp):
if nout == 1:
variables = [variables]
# strict=False because we are in a hot loop
# zip strict not specified because we are in a hot loop
for i, (variable, storage, nout) in enumerate(
zip(variables, output_storage, node.outputs, strict=False)
zip(variables, output_storage, node.outputs)
):
storage[0] = variable = np.asarray(variable, dtype=nout.dtype)
......@@ -729,11 +729,11 @@ class Elemwise(OpenMPOp):
@staticmethod
def _check_runtime_broadcast(node, inputs):
# strict=False because we are in a hot loop
# zip strict not specified because we are in a hot loop
for dims_and_bcast in zip(
*[
zip(input.shape, sinput.type.broadcastable, strict=False)
for input, sinput in zip(inputs, node.inputs, strict=False)
zip(input.shape, sinput.type.broadcastable)
for input, sinput in zip(inputs, node.inputs)
],
strict=False,
):
......
......@@ -1865,8 +1865,8 @@ class CategoricalRV(RandomVariable):
# to `p.shape[:-1]` in the call to `vsearchsorted` below.
if len(size) < (p.ndim - 1):
raise ValueError("`size` is incompatible with the shape of `p`")
# strict=False because we are in a hot loop
for s, ps in zip(reversed(size), reversed(p.shape[:-1]), strict=False):
# zip strict not specified because we are in a hot loop
for s, ps in zip(reversed(size), reversed(p.shape[:-1])):
if s == 1 and ps != 1:
raise ValueError("`size` is incompatible with the shape of `p`")
......
......@@ -44,8 +44,8 @@ def params_broadcast_shapes(
max_fn = maximum if use_pytensor else max
rev_extra_dims: list[int] = []
# strict=False because we are in a hot loop
for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False):
# zip strict not specified because we are in a hot loop
for ndim_param, param_shape in zip(ndims_params, param_shapes):
# We need this in order to use `len`
param_shape = tuple(param_shape)
extras = tuple(param_shape[: (len(param_shape) - ndim_param)])
......@@ -64,12 +64,12 @@ def params_broadcast_shapes(
extra_dims = tuple(reversed(rev_extra_dims))
# strict=False because we are in a hot loop
# zip strict not specified because we are in a hot loop
bcast_shapes = [
(extra_dims + tuple(param_shape)[-ndim_param:])
if ndim_param > 0
else extra_dims
for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False)
for ndim_param, param_shape in zip(ndims_params, param_shapes)
]
return bcast_shapes
......@@ -127,10 +127,9 @@ def broadcast_params(
)
broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to
# strict=False because we are in a hot loop
# zip strict not specified because we are in a hot loop
bcast_params = [
broadcast_to_fn(param, shape)
for shape, param in zip(shapes, params, strict=False)
broadcast_to_fn(param, shape) for shape, param in zip(shapes, params)
]
return bcast_params
......
......@@ -447,10 +447,8 @@ class SpecifyShape(COp):
raise AssertionError(
f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}."
)
# strict=False because we are in a hot loop
if not all(
xs == s for xs, s in zip(x.shape, shape, strict=False) if s is not None
):
# zip strict not specified because we are in a hot loop
if not all(xs == s for xs, s in zip(x.shape, shape) if s is not None):
raise AssertionError(
f"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}."
)
......
......@@ -261,10 +261,10 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
" PyTensor C code does not support that.",
)
# strict=False because we are in a hot loop
# zip strict not specified because we are in a hot loop
if not all(
ds == ts if ts is not None else True
for ds, ts in zip(data.shape, self.shape, strict=False)
for ds, ts in zip(data.shape, self.shape)
):
raise TypeError(
f"The type's shape ({self.shape}) is not compatible with the data's ({data.shape})"
......@@ -333,17 +333,14 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
return False
def is_super(self, otype):
# strict=False because we are in a hot loop
# zip strict not specified because we are in a hot loop
if (
isinstance(otype, type(self))
and otype.dtype == self.dtype
and otype.ndim == self.ndim
# `otype` is allowed to be as or more shape-specific than `self`,
# but not less
and all(
sb == ob or sb is None
for sb, ob in zip(self.shape, otype.shape, strict=False)
)
and all(sb == ob or sb is None for sb, ob in zip(self.shape, otype.shape))
):
return True
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment