提交 4b41e092 作者: Virgile Andreani 提交者: Ricardo Vieira

Add exceptions for hot loops

上级 54fba943
......@@ -863,5 +863,6 @@ class OpFromGraph(Op, HasInnerGraph):
def perform(self, node, inputs, outputs):
variables = self.fn(*inputs)
assert len(variables) == len(outputs)
for output, variable in zip(outputs, variables, strict=True):
# strict=False because asserted above
for output, variable in zip(outputs, variables, strict=False):
output[0] = variable
......@@ -1002,8 +1002,9 @@ class Function:
# if we are allowing garbage collection, remove the
# output reference from the internal storage cells
if getattr(self.vm, "allow_gc", False):
# strict=False because we are in a hot loop
for o_container, o_variable in zip(
self.output_storage, self.maker.fgraph.outputs, strict=True
self.output_storage, self.maker.fgraph.outputs, strict=False
):
if o_variable.owner is not None:
# this node is the variable of computation
......@@ -1012,8 +1013,9 @@ class Function:
if getattr(self.vm, "need_update_inputs", True):
# Update the inputs that have an update function
# strict=False because we are in a hot loop
for input, storage in reversed(
list(zip(self.maker.expanded_inputs, input_storage, strict=True))
list(zip(self.maker.expanded_inputs, input_storage, strict=False))
):
if input.update is not None:
storage.data = outputs.pop()
......@@ -1044,7 +1046,8 @@ class Function:
assert len(self.output_keys) == len(outputs)
if output_subset is None:
return dict(zip(self.output_keys, outputs, strict=True))
# strict=False because we are in a hot loop
return dict(zip(self.output_keys, outputs, strict=False))
else:
return {
self.output_keys[index]: outputs[index]
......@@ -1111,8 +1114,9 @@ def _pickle_Function(f):
ins = list(f.input_storage)
input_storage = []
# strict=False because we are in a hot loop
for (input, indices, inputs), (required, refeed, default) in zip(
f.indices, f.defaults, strict=True
f.indices, f.defaults, strict=False
):
input_storage.append(ins[0])
del ins[0]
......
......@@ -305,7 +305,8 @@ class IfElse(_NoPythonOp):
if len(ls) > 0:
return ls
else:
for out, t in zip(outputs, input_true_branch, strict=True):
# strict=False because we are in a hot loop
for out, t in zip(outputs, input_true_branch, strict=False):
compute_map[out][0] = 1
val = storage_map[t][0]
if self.as_view:
......@@ -325,7 +326,8 @@ class IfElse(_NoPythonOp):
if len(ls) > 0:
return ls
else:
for out, f in zip(outputs, inputs_false_branch, strict=True):
# strict=False because we are in a hot loop
for out, f in zip(outputs, inputs_false_branch, strict=False):
compute_map[out][0] = 1
# can't view both outputs unless destroyhandler
# improves
......
......@@ -539,12 +539,14 @@ class WrapLinker(Linker):
def f():
for inputs in input_lists[1:]:
for input1, input2 in zip(inputs0, inputs, strict=True):
# strict=False because we are in a hot loop
for input1, input2 in zip(inputs0, inputs, strict=False):
input2.storage[0] = copy(input1.storage[0])
for x in to_reset:
x[0] = None
pre(self, [input.data for input in input_lists[0]], order, thunk_groups)
for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=True)):
# strict=False because we are in a hot loop
for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=False)):
try:
wrapper(self.fgraph, i, node, *thunks)
except Exception:
......@@ -666,8 +668,9 @@ class JITLinker(PerformLinker):
):
outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs])
# strict=False because we are in a hot loop
for o_var, o_storage, o_val in zip(
fgraph.outputs, thunk_outputs, outputs, strict=True
fgraph.outputs, thunk_outputs, outputs, strict=False
):
compute_map[o_var][0] = True
o_storage[0] = self.output_filter(o_var, o_val)
......
......@@ -1993,25 +1993,26 @@ class DualLinker(Linker):
)
def f():
for input1, input2 in zip(i1, i2, strict=True):
# strict=False because we are in a hot loop
for input1, input2 in zip(i1, i2, strict=False):
# Set the inputs to be the same in both branches.
# The copy is necessary in order for inplace ops not to
# interfere.
input2.storage[0] = copy(input1.storage[0])
for thunk1, thunk2, node1, node2 in zip(
thunks1, thunks2, order1, order2, strict=True
thunks1, thunks2, order1, order2, strict=False
):
for output, storage in zip(node1.outputs, thunk1.outputs, strict=True):
for output, storage in zip(node1.outputs, thunk1.outputs, strict=False):
if output in no_recycling:
storage[0] = None
for output, storage in zip(node2.outputs, thunk2.outputs, strict=True):
for output, storage in zip(node2.outputs, thunk2.outputs, strict=False):
if output in no_recycling:
storage[0] = None
try:
thunk1()
thunk2()
for output1, output2 in zip(
thunk1.outputs, thunk2.outputs, strict=True
thunk1.outputs, thunk2.outputs, strict=False
):
self.checker(output1, output2)
except Exception:
......
......@@ -401,9 +401,10 @@ def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
else:
def py_perform_return(inputs):
# strict=False because we are in a hot loop
return tuple(
out_type.filter(out[0])
for out_type, out in zip(output_types, py_perform(inputs), strict=True)
for out_type, out in zip(output_types, py_perform(inputs), strict=False)
)
@numba_njit
......
......@@ -34,7 +34,8 @@ def pytorch_funcify_Shape_i(op, **kwargs):
def pytorch_funcify_SpecifyShape(op, node, **kwargs):
def specifyshape(x, *shape):
assert x.ndim == len(shape)
for actual, expected in zip(x.shape, shape, strict=True):
# strict=False because asserted above
for actual, expected in zip(x.shape, shape, strict=False):
if expected is None:
continue
if actual != expected:
......
......@@ -190,8 +190,9 @@ def streamline(
for x in no_recycling:
x[0] = None
try:
# strict=False because we are in a hot loop
for thunk, node, old_storage in zip(
thunks, order, post_thunk_old_storage, strict=True
thunks, order, post_thunk_old_storage, strict=False
):
thunk()
for old_s in old_storage:
......@@ -206,7 +207,8 @@ def streamline(
for x in no_recycling:
x[0] = None
try:
for thunk, node in zip(thunks, order, strict=True):
# strict=False because we are in a hot loop
for thunk, node in zip(thunks, order, strict=False):
thunk()
except Exception:
raise_with_op(fgraph, node, thunk)
......
......@@ -1150,8 +1150,9 @@ class ScalarOp(COp):
else:
variables = from_return_values(self.impl(*inputs))
assert len(variables) == len(output_storage)
# strict=False because we are in a hot loop
for out, storage, variable in zip(
node.outputs, output_storage, variables, strict=True
node.outputs, output_storage, variables, strict=False
):
dtype = out.dtype
storage[0] = self._cast_scalar(variable, dtype)
......@@ -4328,7 +4329,8 @@ class Composite(ScalarInnerGraphOp):
def perform(self, node, inputs, output_storage):
outputs = self.py_perform_fn(*inputs)
for storage, out_val in zip(output_storage, outputs, strict=True):
# strict=False because we are in a hot loop
for storage, out_val in zip(output_storage, outputs, strict=False):
storage[0] = out_val
def grad(self, inputs, output_grads):
......
......@@ -93,7 +93,7 @@ class ScalarLoop(ScalarInnerGraphOp):
)
else:
update = outputs
for i, u in zip(init[: len(update)], update, strict=True):
for i, u in zip(init, update, strict=False):
if i.type != u.type:
raise TypeError(
"Init and update types must be the same: "
......@@ -207,7 +207,8 @@ class ScalarLoop(ScalarInnerGraphOp):
for i in range(n_steps):
carry = inner_fn(*carry, *constant)
for storage, out_val in zip(output_storage, carry, strict=True):
# strict=False because we are in a hot loop
for storage, out_val in zip(output_storage, carry, strict=False):
storage[0] = out_val
@property
......
......@@ -1278,8 +1278,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if len(self.inner_outputs) != len(other.inner_outputs):
return False
# strict=False because length already compared above
for self_in, other_in in zip(
self.inner_inputs, other.inner_inputs, strict=True
self.inner_inputs, other.inner_inputs, strict=False
):
if self_in.type != other_in.type:
return False
......
......@@ -3463,7 +3463,8 @@ class PermuteRowElements(Op):
# Make sure the output is big enough
out_s = []
for xdim, ydim in zip(x_s, y_s, strict=True):
# strict=False because we are in a hot loop
for xdim, ydim in zip(x_s, y_s, strict=False):
if xdim == ydim:
outdim = xdim
elif xdim == 1:
......
......@@ -342,16 +342,17 @@ class Blockwise(Op):
def _check_runtime_broadcast(self, node, inputs):
batch_ndim = self.batch_ndim(node)
# strict=False because we are in a hot loop
for dims_and_bcast in zip(
*[
zip(
input.shape[:batch_ndim],
sinput.type.broadcastable[:batch_ndim],
strict=True,
strict=False,
)
for input, sinput in zip(inputs, node.inputs, strict=True)
for input, sinput in zip(inputs, node.inputs, strict=False)
],
strict=True,
strict=False,
):
if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
raise ValueError(
......@@ -374,8 +375,9 @@ class Blockwise(Op):
if not isinstance(res, tuple):
res = (res,)
# strict=False because we are in a hot loop
for node_out, out_storage, r in zip(
node.outputs, output_storage, res, strict=True
node.outputs, output_storage, res, strict=False
):
out_dtype = getattr(node_out, "dtype", None)
if out_dtype and out_dtype != r.dtype:
......
......@@ -737,8 +737,9 @@ class Elemwise(OpenMPOp):
if nout == 1:
variables = [variables]
# strict=False because we are in a hot loop
for i, (variable, storage, nout) in enumerate(
zip(variables, output_storage, node.outputs, strict=True)
zip(variables, output_storage, node.outputs, strict=False)
):
storage[0] = variable = np.asarray(variable, dtype=nout.dtype)
......@@ -753,12 +754,13 @@ class Elemwise(OpenMPOp):
@staticmethod
def _check_runtime_broadcast(node, inputs):
# strict=False because we are in a hot loop
for dims_and_bcast in zip(
*[
zip(input.shape, sinput.type.broadcastable, strict=False)
for input, sinput in zip(inputs, node.inputs, strict=True)
for input, sinput in zip(inputs, node.inputs, strict=False)
],
strict=True,
strict=False,
):
if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
raise ValueError(
......
......@@ -1862,7 +1862,8 @@ class CategoricalRV(RandomVariable):
# to `p.shape[:-1]` in the call to `vsearchsorted` below.
if len(size) < (p.ndim - 1):
raise ValueError("`size` is incompatible with the shape of `p`")
for s, ps in zip(reversed(size), reversed(p.shape[:-1]), strict=True):
# strict=False because we are in a hot loop
for s, ps in zip(reversed(size), reversed(p.shape[:-1]), strict=False):
if s == 1 and ps != 1:
raise ValueError("`size` is incompatible with the shape of `p`")
......
......@@ -44,7 +44,8 @@ def params_broadcast_shapes(
max_fn = maximum if use_pytensor else max
rev_extra_dims: list[int] = []
for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=True):
# strict=False because we are in a hot loop
for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False):
# We need this in order to use `len`
param_shape = tuple(param_shape)
extras = tuple(param_shape[: (len(param_shape) - ndim_param)])
......@@ -63,11 +64,12 @@ def params_broadcast_shapes(
extra_dims = tuple(reversed(rev_extra_dims))
# strict=False because we are in a hot loop
bcast_shapes = [
(extra_dims + tuple(param_shape)[-ndim_param:])
if ndim_param > 0
else extra_dims
for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=True)
for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False)
]
return bcast_shapes
......@@ -110,10 +112,11 @@ def broadcast_params(
use_pytensor = False
param_shapes = []
for p in params:
# strict=False because we are in a hot loop
param_shape = tuple(
1 if bcast else s
for s, bcast in zip(
p.shape, getattr(p, "broadcastable", (False,) * p.ndim), strict=True
p.shape, getattr(p, "broadcastable", (False,) * p.ndim), strict=False
)
)
use_pytensor |= isinstance(p, Variable)
......@@ -124,9 +127,10 @@ def broadcast_params(
)
broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to
# strict=False because we are in a hot loop
bcast_params = [
broadcast_to_fn(param, shape)
for shape, param in zip(shapes, params, strict=True)
for shape, param in zip(shapes, params, strict=False)
]
return bcast_params
......
......@@ -683,7 +683,7 @@ def local_subtensor_of_alloc(fgraph, node):
# Slices to take from val
val_slices = []
for i, (sl, dim) in enumerate(zip(slices, dims[: len(slices)], strict=True)):
for i, (sl, dim) in enumerate(zip(slices, dims, strict=False)):
# If val was not copied over that dim,
# we need to take the appropriate subtensor on it.
if i >= n_added_dims:
......
......@@ -448,8 +448,9 @@ class SpecifyShape(COp):
raise AssertionError(
f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}."
)
# strict=False because we are in a hot loop
if not all(
xs == s for xs, s in zip(x.shape, shape, strict=True) if s is not None
xs == s for xs, s in zip(x.shape, shape, strict=False) if s is not None
):
raise AssertionError(
f"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}."
......@@ -578,15 +579,12 @@ def specify_shape(
x = ptb.as_tensor_variable(x) # type: ignore[arg-type,unused-ignore]
# The above is a type error in Python 3.9 but not 3.12.
# Thus we need to ignore unused-ignore on 3.12.
new_shape_info = any(
s != xts for (s, xts) in zip(shape, x.type.shape, strict=False) if s is not None
)
# If shape does not match x.ndim, we rely on the `Op` to raise a ValueError
if len(shape) != x.type.ndim:
return _specify_shape(x, *shape)
new_shape_matches = all(
s == xts for (s, xts) in zip(shape, x.type.shape, strict=True) if s is not None
)
if new_shape_matches:
if not new_shape_info and len(shape) == x.type.ndim:
return x
return _specify_shape(x, *shape)
......
......@@ -248,9 +248,10 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
" PyTensor C code does not support that.",
)
# strict=False because we are in a hot loop
if not all(
ds == ts if ts is not None else True
for ds, ts in zip(data.shape, self.shape, strict=True)
for ds, ts in zip(data.shape, self.shape, strict=False)
):
raise TypeError(
f"The type's shape ({self.shape}) is not compatible with the data's ({data.shape})"
......@@ -319,6 +320,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
return False
def is_super(self, otype):
# The zip over the two shapes below uses strict=False because we are in a hot loop
if (
isinstance(otype, type(self))
and otype.dtype == self.dtype
......@@ -327,7 +329,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
# but not less
and all(
sb == ob or sb is None
for sb, ob in zip(self.shape, otype.shape, strict=True)
for sb, ob in zip(self.shape, otype.shape, strict=False)
)
):
return True
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论