提交 10105bea 作者: ricardoV94 提交者: Ricardo Vieira

Don't specify zip strict kwarg in hot loops

Specifying the `strict` keyword seems to add a non-trivial overhead of about 100ns.
上级 5335a680
...@@ -130,7 +130,7 @@ exclude = ["doc/", "pytensor/_version.py"] ...@@ -130,7 +130,7 @@ exclude = ["doc/", "pytensor/_version.py"]
docstring-code-format = true docstring-code-format = true
[tool.ruff.lint] [tool.ruff.lint]
select = ["B905", "C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20", "NPY201"] select = ["C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20", "NPY201"]
ignore = ["C408", "C901", "E501", "E741", "RUF012", "PERF203", "ISC001"] ignore = ["C408", "C901", "E501", "E741", "RUF012", "PERF203", "ISC001"]
unfixable = [ unfixable = [
# zip-strict: the auto-fix adds `strict=False` but we might want `strict=True` instead # zip-strict: the auto-fix adds `strict=False` but we might want `strict=True` instead
......
...@@ -873,7 +873,6 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -873,7 +873,6 @@ class OpFromGraph(Op, HasInnerGraph):
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
variables = self.fn(*inputs) variables = self.fn(*inputs)
assert len(variables) == len(outputs) # zip strict not specified because we are in a hot loop
# strict=False because asserted above for output, variable in zip(outputs, variables):
for output, variable in zip(outputs, variables, strict=False):
output[0] = variable output[0] = variable
...@@ -924,7 +924,8 @@ class Function: ...@@ -924,7 +924,8 @@ class Function:
# Reinitialize each container's 'provided' counter # Reinitialize each container's 'provided' counter
if trust_input: if trust_input:
for arg_container, arg in zip(input_storage, args, strict=False): # zip strict not specified because we are in a hot loop
for arg_container, arg in zip(input_storage, args):
arg_container.storage[0] = arg arg_container.storage[0] = arg
else: else:
for arg_container in input_storage: for arg_container in input_storage:
...@@ -934,7 +935,8 @@ class Function: ...@@ -934,7 +935,8 @@ class Function:
raise TypeError("Too many parameter passed to pytensor function") raise TypeError("Too many parameter passed to pytensor function")
# Set positional arguments # Set positional arguments
for arg_container, arg in zip(input_storage, args, strict=False): # zip strict not specified because we are in a hot loop
for arg_container, arg in zip(input_storage, args):
# See discussion about None as input # See discussion about None as input
# https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5 # https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
if arg is None: if arg is None:
......
...@@ -305,8 +305,8 @@ class IfElse(_NoPythonOp): ...@@ -305,8 +305,8 @@ class IfElse(_NoPythonOp):
if len(ls) > 0: if len(ls) > 0:
return ls return ls
else: else:
# strict=False because we are in a hot loop # zip strict not specified because we are in a hot loop
for out, t in zip(outputs, input_true_branch, strict=False): for out, t in zip(outputs, input_true_branch):
compute_map[out][0] = 1 compute_map[out][0] = 1
val = storage_map[t][0] val = storage_map[t][0]
if self.as_view: if self.as_view:
...@@ -326,8 +326,8 @@ class IfElse(_NoPythonOp): ...@@ -326,8 +326,8 @@ class IfElse(_NoPythonOp):
if len(ls) > 0: if len(ls) > 0:
return ls return ls
else: else:
# strict=False because we are in a hot loop # zip strict not specified because we are in a hot loop
for out, f in zip(outputs, inputs_false_branch, strict=False): for out, f in zip(outputs, inputs_false_branch):
compute_map[out][0] = 1 compute_map[out][0] = 1
# can't view both outputs unless destroyhandler # can't view both outputs unless destroyhandler
# improves # improves
......
...@@ -539,14 +539,14 @@ class WrapLinker(Linker): ...@@ -539,14 +539,14 @@ class WrapLinker(Linker):
def f(): def f():
for inputs in input_lists[1:]: for inputs in input_lists[1:]:
# strict=False because we are in a hot loop # zip strict not specified because we are in a hot loop
for input1, input2 in zip(inputs0, inputs, strict=False): for input1, input2 in zip(inputs0, inputs):
input2.storage[0] = copy(input1.storage[0]) input2.storage[0] = copy(input1.storage[0])
for x in to_reset: for x in to_reset:
x[0] = None x[0] = None
pre(self, [input.data for input in input_lists[0]], order, thunk_groups) pre(self, [input.data for input in input_lists[0]], order, thunk_groups)
# strict=False because we are in a hot loop # zip strict not specified because we are in a hot loop
for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=False)): for i, (thunks, node) in enumerate(zip(thunk_groups, order)):
try: try:
wrapper(self.fgraph, i, node, *thunks) wrapper(self.fgraph, i, node, *thunks)
except Exception: except Exception:
...@@ -668,8 +668,8 @@ class JITLinker(PerformLinker): ...@@ -668,8 +668,8 @@ class JITLinker(PerformLinker):
# since the error may come from any of them? # since the error may come from any of them?
raise_with_op(self.fgraph, output_nodes[0], thunk) raise_with_op(self.fgraph, output_nodes[0], thunk)
# strict=False because we are in a hot loop # zip strict not specified because we are in a hot loop
for o_storage, o_val in zip(thunk_outputs, outputs, strict=False): for o_storage, o_val in zip(thunk_outputs, outputs):
o_storage[0] = o_val o_storage[0] = o_val
thunk.inputs = thunk_inputs thunk.inputs = thunk_inputs
......
...@@ -1988,27 +1988,23 @@ class DualLinker(Linker): ...@@ -1988,27 +1988,23 @@ class DualLinker(Linker):
) )
def f(): def f():
# strict=False because we are in a hot loop # zip strict not specified because we are in a hot loop
for input1, input2 in zip(i1, i2, strict=False): for input1, input2 in zip(i1, i2):
# Set the inputs to be the same in both branches. # Set the inputs to be the same in both branches.
# The copy is necessary in order for inplace ops not to # The copy is necessary in order for inplace ops not to
# interfere. # interfere.
input2.storage[0] = copy(input1.storage[0]) input2.storage[0] = copy(input1.storage[0])
for thunk1, thunk2, node1, node2 in zip( for thunk1, thunk2, node1, node2 in zip(thunks1, thunks2, order1, order2):
thunks1, thunks2, order1, order2, strict=False for output, storage in zip(node1.outputs, thunk1.outputs):
):
for output, storage in zip(node1.outputs, thunk1.outputs, strict=False):
if output in no_recycling: if output in no_recycling:
storage[0] = None storage[0] = None
for output, storage in zip(node2.outputs, thunk2.outputs, strict=False): for output, storage in zip(node2.outputs, thunk2.outputs):
if output in no_recycling: if output in no_recycling:
storage[0] = None storage[0] = None
try: try:
thunk1() thunk1()
thunk2() thunk2()
for output1, output2 in zip( for output1, output2 in zip(thunk1.outputs, thunk2.outputs):
thunk1.outputs, thunk2.outputs, strict=False
):
self.checker(output1, output2) self.checker(output1, output2)
except Exception: except Exception:
raise_with_op(fgraph, node1) raise_with_op(fgraph, node1)
......
...@@ -312,10 +312,10 @@ def generate_fallback_impl(op, node=None, storage_map=None, **kwargs): ...@@ -312,10 +312,10 @@ def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
else: else:
def py_perform_return(inputs): def py_perform_return(inputs):
# strict=False because we are in a hot loop # zip strict not specified because we are in a hot loop
return tuple( return tuple(
out_type.filter(out[0]) out_type.filter(out[0])
for out_type, out in zip(output_types, py_perform(inputs), strict=False) for out_type, out in zip(output_types, py_perform(inputs))
) )
@numba_njit @numba_njit
......
...@@ -166,10 +166,7 @@ class _CythonWrapper(numba.types.WrapperAddressProtocol): ...@@ -166,10 +166,7 @@ class _CythonWrapper(numba.types.WrapperAddressProtocol):
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
# no strict argument because of the JIT # no strict argument because of the JIT
# TODO: check # TODO: check
args = [ args = [dtype(arg) for arg, dtype in zip(args, self._signature.arg_dtypes)]
dtype(arg)
for arg, dtype in zip(args, self._signature.arg_dtypes) # noqa: B905
]
if self.has_pyx_skip_dispatch(): if self.has_pyx_skip_dispatch():
output = self._pyfunc(*args[:-1], **kwargs) output = self._pyfunc(*args[:-1], **kwargs)
else: else:
......
...@@ -186,7 +186,7 @@ def numba_funcify_RavelMultiIndex(op, node, **kwargs): ...@@ -186,7 +186,7 @@ def numba_funcify_RavelMultiIndex(op, node, **kwargs):
new_arr = arr.T.astype(np.float64).copy() new_arr = arr.T.astype(np.float64).copy()
for i, b in enumerate(new_arr): for i, b in enumerate(new_arr):
# no strict argument to this zip because numba doesn't support it # no strict argument to this zip because numba doesn't support it
for j, (d, v) in enumerate(zip(shape, b)): # noqa: B905 for j, (d, v) in enumerate(zip(shape, b)):
if v < 0 or v >= d: if v < 0 or v >= d:
mode_fn(new_arr, i, j, v, d) mode_fn(new_arr, i, j, v, d)
......
...@@ -183,7 +183,7 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs): ...@@ -183,7 +183,7 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
r, c = 0, 0 r, c = 0, 0
# no strict argument because it is incompatible with numba # no strict argument because it is incompatible with numba
for arr, shape in zip(arrs, shapes): # noqa: B905 for arr, shape in zip(arrs, shapes):
rr, cc = shape rr, cc = shape
out[r : r + rr, c : c + cc] = arr out[r : r + rr, c : c + cc] = arr
r += rr r += rr
......
...@@ -219,7 +219,7 @@ def numba_funcify_multiple_integer_vector_indexing( ...@@ -219,7 +219,7 @@ def numba_funcify_multiple_integer_vector_indexing(
shape_aft = x_shape[after_last_axis:] shape_aft = x_shape[after_last_axis:]
out_shape = (*shape_bef, *idx_shape, *shape_aft) out_shape = (*shape_bef, *idx_shape, *shape_aft)
out_buffer = np.empty(out_shape, dtype=x.dtype) out_buffer = np.empty(out_shape, dtype=x.dtype)
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905 for i, scalar_idxs in enumerate(zip(*vec_idxs)):
out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)] out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)]
return out_buffer return out_buffer
...@@ -253,7 +253,7 @@ def numba_funcify_multiple_integer_vector_indexing( ...@@ -253,7 +253,7 @@ def numba_funcify_multiple_integer_vector_indexing(
y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:]) y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])
for outer in np.ndindex(x_shape[:first_axis]): for outer in np.ndindex(x_shape[:first_axis]):
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905 for i, scalar_idxs in enumerate(zip(*vec_idxs)):
out[(*outer, *scalar_idxs)] = y[(*outer, i)] out[(*outer, *scalar_idxs)] = y[(*outer, i)]
return out return out
...@@ -275,7 +275,7 @@ def numba_funcify_multiple_integer_vector_indexing( ...@@ -275,7 +275,7 @@ def numba_funcify_multiple_integer_vector_indexing(
y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:]) y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])
for outer in np.ndindex(x_shape[:first_axis]): for outer in np.ndindex(x_shape[:first_axis]):
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905 for i, scalar_idxs in enumerate(zip(*vec_idxs)):
out[(*outer, *scalar_idxs)] += y[(*outer, i)] out[(*outer, *scalar_idxs)] += y[(*outer, i)]
return out return out
...@@ -314,7 +314,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs): ...@@ -314,7 +314,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
if not len(idxs) == len(vals): if not len(idxs) == len(vals):
raise ValueError("The number of indices and values must match.") raise ValueError("The number of indices and values must match.")
# no strict argument because incompatible with numba # no strict argument because incompatible with numba
for idx, val in zip(idxs, vals): # noqa: B905 for idx, val in zip(idxs, vals):
x[idx] = val x[idx] = val
return x return x
else: else:
...@@ -342,7 +342,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs): ...@@ -342,7 +342,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
raise ValueError("The number of indices and values must match.") raise ValueError("The number of indices and values must match.")
# no strict argument because unsupported by numba # no strict argument because unsupported by numba
# TODO: this doesn't come up in tests # TODO: this doesn't come up in tests
for idx, val in zip(idxs, vals): # noqa: B905 for idx, val in zip(idxs, vals):
x[idx] += val x[idx] += val
return x return x
......
...@@ -207,8 +207,8 @@ def streamline( ...@@ -207,8 +207,8 @@ def streamline(
for x in no_recycling: for x in no_recycling:
x[0] = None x[0] = None
try: try:
# strict=False because we are in a hot loop # zip strict not specified because we are in a hot loop
for thunk, node in zip(thunks, order, strict=False): for thunk, node in zip(thunks, order):
thunk() thunk()
except Exception: except Exception:
raise_with_op(fgraph, node, thunk) raise_with_op(fgraph, node, thunk)
......
...@@ -4416,8 +4416,8 @@ class Composite(ScalarInnerGraphOp): ...@@ -4416,8 +4416,8 @@ class Composite(ScalarInnerGraphOp):
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
outputs = self.py_perform_fn(*inputs) outputs = self.py_perform_fn(*inputs)
# strict=False because we are in a hot loop # zip strict not specified because we are in a hot loop
for storage, out_val in zip(output_storage, outputs, strict=False): for storage, out_val in zip(output_storage, outputs):
storage[0] = out_val storage[0] = out_val
def grad(self, inputs, output_grads): def grad(self, inputs, output_grads):
......
...@@ -196,8 +196,8 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -196,8 +196,8 @@ class ScalarLoop(ScalarInnerGraphOp):
for i in range(n_steps): for i in range(n_steps):
carry = inner_fn(*carry, *constant) carry = inner_fn(*carry, *constant)
# strict=False because we are in a hot loop # zip strict not specified because we are in a hot loop
for storage, out_val in zip(output_storage, carry, strict=False): for storage, out_val in zip(output_storage, carry):
storage[0] = out_val storage[0] = out_val
@property @property
......
...@@ -3589,8 +3589,8 @@ class PermuteRowElements(Op): ...@@ -3589,8 +3589,8 @@ class PermuteRowElements(Op):
# Make sure the output is big enough # Make sure the output is big enough
out_s = [] out_s = []
# strict=False because we are in a hot loop # zip strict not specified because we are in a hot loop
for xdim, ydim in zip(x_s, y_s, strict=False): for xdim, ydim in zip(x_s, y_s):
if xdim == ydim: if xdim == ydim:
outdim = xdim outdim = xdim
elif xdim == 1: elif xdim == 1:
......
...@@ -712,9 +712,9 @@ class Elemwise(OpenMPOp): ...@@ -712,9 +712,9 @@ class Elemwise(OpenMPOp):
if nout == 1: if nout == 1:
variables = [variables] variables = [variables]
# strict=False because we are in a hot loop # zip strict not specified because we are in a hot loop
for i, (variable, storage, nout) in enumerate( for i, (variable, storage, nout) in enumerate(
zip(variables, output_storage, node.outputs, strict=False) zip(variables, output_storage, node.outputs)
): ):
storage[0] = variable = np.asarray(variable, dtype=nout.dtype) storage[0] = variable = np.asarray(variable, dtype=nout.dtype)
...@@ -729,11 +729,11 @@ class Elemwise(OpenMPOp): ...@@ -729,11 +729,11 @@ class Elemwise(OpenMPOp):
@staticmethod @staticmethod
def _check_runtime_broadcast(node, inputs): def _check_runtime_broadcast(node, inputs):
# strict=False because we are in a hot loop # zip strict not specified because we are in a hot loop
for dims_and_bcast in zip( for dims_and_bcast in zip(
*[ *[
zip(input.shape, sinput.type.broadcastable, strict=False) zip(input.shape, sinput.type.broadcastable)
for input, sinput in zip(inputs, node.inputs, strict=False) for input, sinput in zip(inputs, node.inputs)
], ],
strict=False, strict=False,
): ):
......
...@@ -1865,8 +1865,8 @@ class CategoricalRV(RandomVariable): ...@@ -1865,8 +1865,8 @@ class CategoricalRV(RandomVariable):
# to `p.shape[:-1]` in the call to `vsearchsorted` below. # to `p.shape[:-1]` in the call to `vsearchsorted` below.
if len(size) < (p.ndim - 1): if len(size) < (p.ndim - 1):
raise ValueError("`size` is incompatible with the shape of `p`") raise ValueError("`size` is incompatible with the shape of `p`")
# strict=False because we are in a hot loop # zip strict not specified because we are in a hot loop
for s, ps in zip(reversed(size), reversed(p.shape[:-1]), strict=False): for s, ps in zip(reversed(size), reversed(p.shape[:-1])):
if s == 1 and ps != 1: if s == 1 and ps != 1:
raise ValueError("`size` is incompatible with the shape of `p`") raise ValueError("`size` is incompatible with the shape of `p`")
......
...@@ -44,8 +44,8 @@ def params_broadcast_shapes( ...@@ -44,8 +44,8 @@ def params_broadcast_shapes(
max_fn = maximum if use_pytensor else max max_fn = maximum if use_pytensor else max
rev_extra_dims: list[int] = [] rev_extra_dims: list[int] = []
# strict=False because we are in a hot loop # zip strict not specified because we are in a hot loop
for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False): for ndim_param, param_shape in zip(ndims_params, param_shapes):
# We need this in order to use `len` # We need this in order to use `len`
param_shape = tuple(param_shape) param_shape = tuple(param_shape)
extras = tuple(param_shape[: (len(param_shape) - ndim_param)]) extras = tuple(param_shape[: (len(param_shape) - ndim_param)])
...@@ -64,12 +64,12 @@ def params_broadcast_shapes( ...@@ -64,12 +64,12 @@ def params_broadcast_shapes(
extra_dims = tuple(reversed(rev_extra_dims)) extra_dims = tuple(reversed(rev_extra_dims))
# strict=False because we are in a hot loop # zip strict not specified because we are in a hot loop
bcast_shapes = [ bcast_shapes = [
(extra_dims + tuple(param_shape)[-ndim_param:]) (extra_dims + tuple(param_shape)[-ndim_param:])
if ndim_param > 0 if ndim_param > 0
else extra_dims else extra_dims
for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False) for ndim_param, param_shape in zip(ndims_params, param_shapes)
] ]
return bcast_shapes return bcast_shapes
...@@ -127,10 +127,9 @@ def broadcast_params( ...@@ -127,10 +127,9 @@ def broadcast_params(
) )
broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to
# strict=False because we are in a hot loop # zip strict not specified because we are in a hot loop
bcast_params = [ bcast_params = [
broadcast_to_fn(param, shape) broadcast_to_fn(param, shape) for shape, param in zip(shapes, params)
for shape, param in zip(shapes, params, strict=False)
] ]
return bcast_params return bcast_params
......
...@@ -447,10 +447,8 @@ class SpecifyShape(COp): ...@@ -447,10 +447,8 @@ class SpecifyShape(COp):
raise AssertionError( raise AssertionError(
f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}." f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}."
) )
# strict=False because we are in a hot loop # zip strict not specified because we are in a hot loop
if not all( if not all(xs == s for xs, s in zip(x.shape, shape) if s is not None):
xs == s for xs, s in zip(x.shape, shape, strict=False) if s is not None
):
raise AssertionError( raise AssertionError(
f"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}." f"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}."
) )
......
...@@ -261,10 +261,10 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): ...@@ -261,10 +261,10 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
" PyTensor C code does not support that.", " PyTensor C code does not support that.",
) )
# strict=False because we are in a hot loop # zip strict not specified because we are in a hot loop
if not all( if not all(
ds == ts if ts is not None else True ds == ts if ts is not None else True
for ds, ts in zip(data.shape, self.shape, strict=False) for ds, ts in zip(data.shape, self.shape)
): ):
raise TypeError( raise TypeError(
f"The type's shape ({self.shape}) is not compatible with the data's ({data.shape})" f"The type's shape ({self.shape}) is not compatible with the data's ({data.shape})"
...@@ -333,17 +333,14 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): ...@@ -333,17 +333,14 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
return False return False
def is_super(self, otype): def is_super(self, otype):
# strict=False because we are in a hot loop # zip strict not specified because we are in a hot loop
if ( if (
isinstance(otype, type(self)) isinstance(otype, type(self))
and otype.dtype == self.dtype and otype.dtype == self.dtype
and otype.ndim == self.ndim and otype.ndim == self.ndim
# `otype` is allowed to be as or more shape-specific than `self`, # `otype` is allowed to be as or more shape-specific than `self`,
# but not less # but not less
and all( and all(sb == ob or sb is None for sb, ob in zip(self.shape, otype.shape))
sb == ob or sb is None
for sb, ob in zip(self.shape, otype.shape, strict=False)
)
): ):
return True return True
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论