提交 7f6676d6 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Cleanup Split methods

上级 d50db11c
...@@ -2226,21 +2226,21 @@ class Split(COp): ...@@ -2226,21 +2226,21 @@ class Split(COp):
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs_storage):
x, axis, splits = inputs x, axis, splits = inputs
if len(splits) != self.len_splits: if len(splits) != self.len_splits:
raise ValueError("Length of splits is not equal to n_splits") raise ValueError("Length of splits is not equal to n_splits")
if np.sum(splits) != x.shape[axis]: if splits.sum() != x.shape[axis]:
raise ValueError( raise ValueError(
f"Split sizes sum to {np.sum(splits)}; expected {x.shape[axis]}" f"Split sizes sum to {splits.sum()}; expected {x.shape[axis]}"
) )
if np.any(splits < 0): if (splits < 0).any():
raise ValueError("Split sizes cannot be negative") raise ValueError("Split sizes cannot be negative")
split_outs = np.split(x, np.cumsum(splits[:-1]), axis=axis) split_outs = np.split(x, np.cumsum(splits[:-1]), axis=axis)
for i, out in enumerate(split_outs): for out_storage, out in zip(outputs_storage, split_outs, strict=False):
outputs[i][0] = out out_storage[0] = out
def infer_shape(self, fgraph, node, in_shapes): def infer_shape(self, fgraph, node, in_shapes):
axis = node.inputs[1] axis = node.inputs[1]
...@@ -2254,10 +2254,10 @@ class Split(COp): ...@@ -2254,10 +2254,10 @@ class Split(COp):
out_shapes.append(temp) out_shapes.append(temp)
return out_shapes return out_shapes
def grad(self, inputs, g_outputs): def L_op(self, inputs, outputs, g_outputs):
"""Join the gradients along the axis that was used to split x.""" """Join the gradients along the axis that was used to split x."""
x, axis, n = inputs x, axis, n = inputs
outputs = self(*inputs, return_list=True)
# If all the output gradients are disconnected, then so are the inputs # If all the output gradients are disconnected, then so are the inputs
if builtins.all(isinstance(g.type, DisconnectedType) for g in g_outputs): if builtins.all(isinstance(g.type, DisconnectedType) for g in g_outputs):
return [ return [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论