提交 4b41e092 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Add exceptions for hot loops

上级 54fba943
...@@ -863,5 +863,6 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -863,5 +863,6 @@ class OpFromGraph(Op, HasInnerGraph):
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
variables = self.fn(*inputs) variables = self.fn(*inputs)
assert len(variables) == len(outputs) assert len(variables) == len(outputs)
for output, variable in zip(outputs, variables, strict=True): # strict=False because asserted above
for output, variable in zip(outputs, variables, strict=False):
output[0] = variable output[0] = variable
...@@ -1002,8 +1002,9 @@ class Function: ...@@ -1002,8 +1002,9 @@ class Function:
# if we are allowing garbage collection, remove the # if we are allowing garbage collection, remove the
# output reference from the internal storage cells # output reference from the internal storage cells
if getattr(self.vm, "allow_gc", False): if getattr(self.vm, "allow_gc", False):
# strict=False because we are in a hot loop
for o_container, o_variable in zip( for o_container, o_variable in zip(
self.output_storage, self.maker.fgraph.outputs, strict=True self.output_storage, self.maker.fgraph.outputs, strict=False
): ):
if o_variable.owner is not None: if o_variable.owner is not None:
# this node is the variable of computation # this node is the variable of computation
...@@ -1012,8 +1013,9 @@ class Function: ...@@ -1012,8 +1013,9 @@ class Function:
if getattr(self.vm, "need_update_inputs", True): if getattr(self.vm, "need_update_inputs", True):
# Update the inputs that have an update function # Update the inputs that have an update function
# strict=False because we are in a hot loop
for input, storage in reversed( for input, storage in reversed(
list(zip(self.maker.expanded_inputs, input_storage, strict=True)) list(zip(self.maker.expanded_inputs, input_storage, strict=False))
): ):
if input.update is not None: if input.update is not None:
storage.data = outputs.pop() storage.data = outputs.pop()
...@@ -1044,7 +1046,8 @@ class Function: ...@@ -1044,7 +1046,8 @@ class Function:
assert len(self.output_keys) == len(outputs) assert len(self.output_keys) == len(outputs)
if output_subset is None: if output_subset is None:
return dict(zip(self.output_keys, outputs, strict=True)) # strict=False because we are in a hot loop
return dict(zip(self.output_keys, outputs, strict=False))
else: else:
return { return {
self.output_keys[index]: outputs[index] self.output_keys[index]: outputs[index]
...@@ -1111,8 +1114,9 @@ def _pickle_Function(f): ...@@ -1111,8 +1114,9 @@ def _pickle_Function(f):
ins = list(f.input_storage) ins = list(f.input_storage)
input_storage = [] input_storage = []
# strict=False because we are in a hot loop
for (input, indices, inputs), (required, refeed, default) in zip( for (input, indices, inputs), (required, refeed, default) in zip(
f.indices, f.defaults, strict=True f.indices, f.defaults, strict=False
): ):
input_storage.append(ins[0]) input_storage.append(ins[0])
del ins[0] del ins[0]
......
...@@ -305,7 +305,8 @@ class IfElse(_NoPythonOp): ...@@ -305,7 +305,8 @@ class IfElse(_NoPythonOp):
if len(ls) > 0: if len(ls) > 0:
return ls return ls
else: else:
for out, t in zip(outputs, input_true_branch, strict=True): # strict=False because we are in a hot loop
for out, t in zip(outputs, input_true_branch, strict=False):
compute_map[out][0] = 1 compute_map[out][0] = 1
val = storage_map[t][0] val = storage_map[t][0]
if self.as_view: if self.as_view:
...@@ -325,7 +326,8 @@ class IfElse(_NoPythonOp): ...@@ -325,7 +326,8 @@ class IfElse(_NoPythonOp):
if len(ls) > 0: if len(ls) > 0:
return ls return ls
else: else:
for out, f in zip(outputs, inputs_false_branch, strict=True): # strict=False because we are in a hot loop
for out, f in zip(outputs, inputs_false_branch, strict=False):
compute_map[out][0] = 1 compute_map[out][0] = 1
# can't view both outputs unless destroyhandler # can't view both outputs unless destroyhandler
# improves # improves
......
...@@ -539,12 +539,14 @@ class WrapLinker(Linker): ...@@ -539,12 +539,14 @@ class WrapLinker(Linker):
def f(): def f():
for inputs in input_lists[1:]: for inputs in input_lists[1:]:
for input1, input2 in zip(inputs0, inputs, strict=True): # strict=False because we are in a hot loop
for input1, input2 in zip(inputs0, inputs, strict=False):
input2.storage[0] = copy(input1.storage[0]) input2.storage[0] = copy(input1.storage[0])
for x in to_reset: for x in to_reset:
x[0] = None x[0] = None
pre(self, [input.data for input in input_lists[0]], order, thunk_groups) pre(self, [input.data for input in input_lists[0]], order, thunk_groups)
for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=True)): # strict=False because we are in a hot loop
for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=False)):
try: try:
wrapper(self.fgraph, i, node, *thunks) wrapper(self.fgraph, i, node, *thunks)
except Exception: except Exception:
...@@ -666,8 +668,9 @@ class JITLinker(PerformLinker): ...@@ -666,8 +668,9 @@ class JITLinker(PerformLinker):
): ):
outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs]) outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs])
# strict=False because we are in a hot loop
for o_var, o_storage, o_val in zip( for o_var, o_storage, o_val in zip(
fgraph.outputs, thunk_outputs, outputs, strict=True fgraph.outputs, thunk_outputs, outputs, strict=False
): ):
compute_map[o_var][0] = True compute_map[o_var][0] = True
o_storage[0] = self.output_filter(o_var, o_val) o_storage[0] = self.output_filter(o_var, o_val)
......
...@@ -1993,25 +1993,26 @@ class DualLinker(Linker): ...@@ -1993,25 +1993,26 @@ class DualLinker(Linker):
) )
def f(): def f():
for input1, input2 in zip(i1, i2, strict=True): # strict=False because we are in a hot loop
for input1, input2 in zip(i1, i2, strict=False):
# Set the inputs to be the same in both branches. # Set the inputs to be the same in both branches.
# The copy is necessary in order for inplace ops not to # The copy is necessary in order for inplace ops not to
# interfere. # interfere.
input2.storage[0] = copy(input1.storage[0]) input2.storage[0] = copy(input1.storage[0])
for thunk1, thunk2, node1, node2 in zip( for thunk1, thunk2, node1, node2 in zip(
thunks1, thunks2, order1, order2, strict=True thunks1, thunks2, order1, order2, strict=False
): ):
for output, storage in zip(node1.outputs, thunk1.outputs, strict=True): for output, storage in zip(node1.outputs, thunk1.outputs, strict=False):
if output in no_recycling: if output in no_recycling:
storage[0] = None storage[0] = None
for output, storage in zip(node2.outputs, thunk2.outputs, strict=True): for output, storage in zip(node2.outputs, thunk2.outputs, strict=False):
if output in no_recycling: if output in no_recycling:
storage[0] = None storage[0] = None
try: try:
thunk1() thunk1()
thunk2() thunk2()
for output1, output2 in zip( for output1, output2 in zip(
thunk1.outputs, thunk2.outputs, strict=True thunk1.outputs, thunk2.outputs, strict=False
): ):
self.checker(output1, output2) self.checker(output1, output2)
except Exception: except Exception:
......
...@@ -401,9 +401,10 @@ def generate_fallback_impl(op, node=None, storage_map=None, **kwargs): ...@@ -401,9 +401,10 @@ def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
else: else:
def py_perform_return(inputs): def py_perform_return(inputs):
# strict=False because we are in a hot loop
return tuple( return tuple(
out_type.filter(out[0]) out_type.filter(out[0])
for out_type, out in zip(output_types, py_perform(inputs), strict=True) for out_type, out in zip(output_types, py_perform(inputs), strict=False)
) )
@numba_njit @numba_njit
......
...@@ -34,7 +34,8 @@ def pytorch_funcify_Shape_i(op, **kwargs): ...@@ -34,7 +34,8 @@ def pytorch_funcify_Shape_i(op, **kwargs):
def pytorch_funcify_SpecifyShape(op, node, **kwargs): def pytorch_funcify_SpecifyShape(op, node, **kwargs):
def specifyshape(x, *shape): def specifyshape(x, *shape):
assert x.ndim == len(shape) assert x.ndim == len(shape)
for actual, expected in zip(x.shape, shape, strict=True): # strict=False because asserted above
for actual, expected in zip(x.shape, shape, strict=False):
if expected is None: if expected is None:
continue continue
if actual != expected: if actual != expected:
......
...@@ -190,8 +190,9 @@ def streamline( ...@@ -190,8 +190,9 @@ def streamline(
for x in no_recycling: for x in no_recycling:
x[0] = None x[0] = None
try: try:
# strict=False because we are in a hot loop
for thunk, node, old_storage in zip( for thunk, node, old_storage in zip(
thunks, order, post_thunk_old_storage, strict=True thunks, order, post_thunk_old_storage, strict=False
): ):
thunk() thunk()
for old_s in old_storage: for old_s in old_storage:
...@@ -206,7 +207,8 @@ def streamline( ...@@ -206,7 +207,8 @@ def streamline(
for x in no_recycling: for x in no_recycling:
x[0] = None x[0] = None
try: try:
for thunk, node in zip(thunks, order, strict=True): # strict=False because we are in a hot loop
for thunk, node in zip(thunks, order, strict=False):
thunk() thunk()
except Exception: except Exception:
raise_with_op(fgraph, node, thunk) raise_with_op(fgraph, node, thunk)
......
...@@ -1150,8 +1150,9 @@ class ScalarOp(COp): ...@@ -1150,8 +1150,9 @@ class ScalarOp(COp):
else: else:
variables = from_return_values(self.impl(*inputs)) variables = from_return_values(self.impl(*inputs))
assert len(variables) == len(output_storage) assert len(variables) == len(output_storage)
# strict=False because we are in a hot loop
for out, storage, variable in zip( for out, storage, variable in zip(
node.outputs, output_storage, variables, strict=True node.outputs, output_storage, variables, strict=False
): ):
dtype = out.dtype dtype = out.dtype
storage[0] = self._cast_scalar(variable, dtype) storage[0] = self._cast_scalar(variable, dtype)
...@@ -4328,7 +4329,8 @@ class Composite(ScalarInnerGraphOp): ...@@ -4328,7 +4329,8 @@ class Composite(ScalarInnerGraphOp):
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
outputs = self.py_perform_fn(*inputs) outputs = self.py_perform_fn(*inputs)
for storage, out_val in zip(output_storage, outputs, strict=True): # strict=False because we are in a hot loop
for storage, out_val in zip(output_storage, outputs, strict=False):
storage[0] = out_val storage[0] = out_val
def grad(self, inputs, output_grads): def grad(self, inputs, output_grads):
......
...@@ -93,7 +93,7 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -93,7 +93,7 @@ class ScalarLoop(ScalarInnerGraphOp):
) )
else: else:
update = outputs update = outputs
for i, u in zip(init[: len(update)], update, strict=True): for i, u in zip(init, update, strict=False):
if i.type != u.type: if i.type != u.type:
raise TypeError( raise TypeError(
"Init and update types must be the same: " "Init and update types must be the same: "
...@@ -207,7 +207,8 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -207,7 +207,8 @@ class ScalarLoop(ScalarInnerGraphOp):
for i in range(n_steps): for i in range(n_steps):
carry = inner_fn(*carry, *constant) carry = inner_fn(*carry, *constant)
for storage, out_val in zip(output_storage, carry, strict=True): # strict=False because we are in a hot loop
for storage, out_val in zip(output_storage, carry, strict=False):
storage[0] = out_val storage[0] = out_val
@property @property
......
...@@ -1278,8 +1278,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1278,8 +1278,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if len(self.inner_outputs) != len(other.inner_outputs): if len(self.inner_outputs) != len(other.inner_outputs):
return False return False
# strict=False because length already compared above
for self_in, other_in in zip( for self_in, other_in in zip(
self.inner_inputs, other.inner_inputs, strict=True self.inner_inputs, other.inner_inputs, strict=False
): ):
if self_in.type != other_in.type: if self_in.type != other_in.type:
return False return False
......
...@@ -3463,7 +3463,8 @@ class PermuteRowElements(Op): ...@@ -3463,7 +3463,8 @@ class PermuteRowElements(Op):
# Make sure the output is big enough # Make sure the output is big enough
out_s = [] out_s = []
for xdim, ydim in zip(x_s, y_s, strict=True): # strict=False because we are in a hot loop
for xdim, ydim in zip(x_s, y_s, strict=False):
if xdim == ydim: if xdim == ydim:
outdim = xdim outdim = xdim
elif xdim == 1: elif xdim == 1:
......
...@@ -342,16 +342,17 @@ class Blockwise(Op): ...@@ -342,16 +342,17 @@ class Blockwise(Op):
def _check_runtime_broadcast(self, node, inputs): def _check_runtime_broadcast(self, node, inputs):
batch_ndim = self.batch_ndim(node) batch_ndim = self.batch_ndim(node)
# strict=False because we are in a hot loop
for dims_and_bcast in zip( for dims_and_bcast in zip(
*[ *[
zip( zip(
input.shape[:batch_ndim], input.shape[:batch_ndim],
sinput.type.broadcastable[:batch_ndim], sinput.type.broadcastable[:batch_ndim],
strict=True, strict=False,
) )
for input, sinput in zip(inputs, node.inputs, strict=True) for input, sinput in zip(inputs, node.inputs, strict=False)
], ],
strict=True, strict=False,
): ):
if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast: if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
raise ValueError( raise ValueError(
...@@ -374,8 +375,9 @@ class Blockwise(Op): ...@@ -374,8 +375,9 @@ class Blockwise(Op):
if not isinstance(res, tuple): if not isinstance(res, tuple):
res = (res,) res = (res,)
# strict=False because we are in a hot loop
for node_out, out_storage, r in zip( for node_out, out_storage, r in zip(
node.outputs, output_storage, res, strict=True node.outputs, output_storage, res, strict=False
): ):
out_dtype = getattr(node_out, "dtype", None) out_dtype = getattr(node_out, "dtype", None)
if out_dtype and out_dtype != r.dtype: if out_dtype and out_dtype != r.dtype:
......
...@@ -737,8 +737,9 @@ class Elemwise(OpenMPOp): ...@@ -737,8 +737,9 @@ class Elemwise(OpenMPOp):
if nout == 1: if nout == 1:
variables = [variables] variables = [variables]
# strict=False because we are in a hot loop
for i, (variable, storage, nout) in enumerate( for i, (variable, storage, nout) in enumerate(
zip(variables, output_storage, node.outputs, strict=True) zip(variables, output_storage, node.outputs, strict=False)
): ):
storage[0] = variable = np.asarray(variable, dtype=nout.dtype) storage[0] = variable = np.asarray(variable, dtype=nout.dtype)
...@@ -753,12 +754,13 @@ class Elemwise(OpenMPOp): ...@@ -753,12 +754,13 @@ class Elemwise(OpenMPOp):
@staticmethod @staticmethod
def _check_runtime_broadcast(node, inputs): def _check_runtime_broadcast(node, inputs):
# strict=False because we are in a hot loop
for dims_and_bcast in zip( for dims_and_bcast in zip(
*[ *[
zip(input.shape, sinput.type.broadcastable, strict=False) zip(input.shape, sinput.type.broadcastable, strict=False)
for input, sinput in zip(inputs, node.inputs, strict=True) for input, sinput in zip(inputs, node.inputs, strict=False)
], ],
strict=True, strict=False,
): ):
if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast: if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
raise ValueError( raise ValueError(
......
...@@ -1862,7 +1862,8 @@ class CategoricalRV(RandomVariable): ...@@ -1862,7 +1862,8 @@ class CategoricalRV(RandomVariable):
# to `p.shape[:-1]` in the call to `vsearchsorted` below. # to `p.shape[:-1]` in the call to `vsearchsorted` below.
if len(size) < (p.ndim - 1): if len(size) < (p.ndim - 1):
raise ValueError("`size` is incompatible with the shape of `p`") raise ValueError("`size` is incompatible with the shape of `p`")
for s, ps in zip(reversed(size), reversed(p.shape[:-1]), strict=True): # strict=False because we are in a hot loop
for s, ps in zip(reversed(size), reversed(p.shape[:-1]), strict=False):
if s == 1 and ps != 1: if s == 1 and ps != 1:
raise ValueError("`size` is incompatible with the shape of `p`") raise ValueError("`size` is incompatible with the shape of `p`")
......
...@@ -44,7 +44,8 @@ def params_broadcast_shapes( ...@@ -44,7 +44,8 @@ def params_broadcast_shapes(
max_fn = maximum if use_pytensor else max max_fn = maximum if use_pytensor else max
rev_extra_dims: list[int] = [] rev_extra_dims: list[int] = []
for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=True): # strict=False because we are in a hot loop
for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False):
# We need this in order to use `len` # We need this in order to use `len`
param_shape = tuple(param_shape) param_shape = tuple(param_shape)
extras = tuple(param_shape[: (len(param_shape) - ndim_param)]) extras = tuple(param_shape[: (len(param_shape) - ndim_param)])
...@@ -63,11 +64,12 @@ def params_broadcast_shapes( ...@@ -63,11 +64,12 @@ def params_broadcast_shapes(
extra_dims = tuple(reversed(rev_extra_dims)) extra_dims = tuple(reversed(rev_extra_dims))
# strict=False because we are in a hot loop
bcast_shapes = [ bcast_shapes = [
(extra_dims + tuple(param_shape)[-ndim_param:]) (extra_dims + tuple(param_shape)[-ndim_param:])
if ndim_param > 0 if ndim_param > 0
else extra_dims else extra_dims
for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=True) for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False)
] ]
return bcast_shapes return bcast_shapes
...@@ -110,10 +112,11 @@ def broadcast_params( ...@@ -110,10 +112,11 @@ def broadcast_params(
use_pytensor = False use_pytensor = False
param_shapes = [] param_shapes = []
for p in params: for p in params:
# strict=False because we are in a hot loop
param_shape = tuple( param_shape = tuple(
1 if bcast else s 1 if bcast else s
for s, bcast in zip( for s, bcast in zip(
p.shape, getattr(p, "broadcastable", (False,) * p.ndim), strict=True p.shape, getattr(p, "broadcastable", (False,) * p.ndim), strict=False
) )
) )
use_pytensor |= isinstance(p, Variable) use_pytensor |= isinstance(p, Variable)
...@@ -124,9 +127,10 @@ def broadcast_params( ...@@ -124,9 +127,10 @@ def broadcast_params(
) )
broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to
# strict=False because we are in a hot loop
bcast_params = [ bcast_params = [
broadcast_to_fn(param, shape) broadcast_to_fn(param, shape)
for shape, param in zip(shapes, params, strict=True) for shape, param in zip(shapes, params, strict=False)
] ]
return bcast_params return bcast_params
......
...@@ -683,7 +683,7 @@ def local_subtensor_of_alloc(fgraph, node): ...@@ -683,7 +683,7 @@ def local_subtensor_of_alloc(fgraph, node):
# Slices to take from val # Slices to take from val
val_slices = [] val_slices = []
for i, (sl, dim) in enumerate(zip(slices, dims[: len(slices)], strict=True)): for i, (sl, dim) in enumerate(zip(slices, dims, strict=False)):
# If val was not copied over that dim, # If val was not copied over that dim,
# we need to take the appropriate subtensor on it. # we need to take the appropriate subtensor on it.
if i >= n_added_dims: if i >= n_added_dims:
......
...@@ -448,8 +448,9 @@ class SpecifyShape(COp): ...@@ -448,8 +448,9 @@ class SpecifyShape(COp):
raise AssertionError( raise AssertionError(
f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}." f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}."
) )
# strict=False because we are in a hot loop
if not all( if not all(
xs == s for xs, s in zip(x.shape, shape, strict=True) if s is not None xs == s for xs, s in zip(x.shape, shape, strict=False) if s is not None
): ):
raise AssertionError( raise AssertionError(
f"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}." f"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}."
...@@ -578,15 +579,12 @@ def specify_shape( ...@@ -578,15 +579,12 @@ def specify_shape(
x = ptb.as_tensor_variable(x) # type: ignore[arg-type,unused-ignore] x = ptb.as_tensor_variable(x) # type: ignore[arg-type,unused-ignore]
# The above is a type error in Python 3.9 but not 3.12. # The above is a type error in Python 3.9 but not 3.12.
# Thus we need to ignore unused-ignore on 3.12. # Thus we need to ignore unused-ignore on 3.12.
new_shape_info = any(
s != xts for (s, xts) in zip(shape, x.type.shape, strict=False) if s is not None
)
# If shape does not match x.ndim, we rely on the `Op` to raise a ValueError # If shape does not match x.ndim, we rely on the `Op` to raise a ValueError
if len(shape) != x.type.ndim: if not new_shape_info and len(shape) == x.type.ndim:
return _specify_shape(x, *shape)
new_shape_matches = all(
s == xts for (s, xts) in zip(shape, x.type.shape, strict=True) if s is not None
)
if new_shape_matches:
return x return x
return _specify_shape(x, *shape) return _specify_shape(x, *shape)
......
...@@ -248,9 +248,10 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): ...@@ -248,9 +248,10 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
" PyTensor C code does not support that.", " PyTensor C code does not support that.",
) )
# strict=False because we are in a hot loop
if not all( if not all(
ds == ts if ts is not None else True ds == ts if ts is not None else True
for ds, ts in zip(data.shape, self.shape, strict=True) for ds, ts in zip(data.shape, self.shape, strict=False)
): ):
raise TypeError( raise TypeError(
f"The type's shape ({self.shape}) is not compatible with the data's ({data.shape})" f"The type's shape ({self.shape}) is not compatible with the data's ({data.shape})"
...@@ -319,6 +320,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): ...@@ -319,6 +320,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
return False return False
def is_super(self, otype): def is_super(self, otype):
# strict=False because we are in a hot loop
if ( if (
isinstance(otype, type(self)) isinstance(otype, type(self))
and otype.dtype == self.dtype and otype.dtype == self.dtype
...@@ -327,7 +329,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): ...@@ -327,7 +329,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
# but not less # but not less
and all( and all(
sb == ob or sb is None sb == ob or sb is None
for sb, ob in zip(self.shape, otype.shape, strict=True) for sb, ob in zip(self.shape, otype.shape, strict=False)
) )
): ):
return True return True
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论