Unverified 提交 584c0c15 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: GitHub

Merge pull request #305 from brandonwillard/fix-split-errors

Fix Split.__str__ and adjust its error messages
上级 bd54469c
...@@ -1844,7 +1844,7 @@ class Split(COp): ...@@ -1844,7 +1844,7 @@ class Split(COp):
self.len_splits = int(len_splits) self.len_splits = int(len_splits)
def __str__(self): def __str__(self):
return "{self.__class__.__name__ }{{{self.len_splits}}}" return f"{self.__class__.__name__ }{{{self.len_splits}}}"
def make_node(self, x, axis, splits): def make_node(self, x, axis, splits):
"""WRITEME""" """WRITEME"""
...@@ -1853,9 +1853,9 @@ class Split(COp): ...@@ -1853,9 +1853,9 @@ class Split(COp):
splits = as_tensor_variable(splits) splits = as_tensor_variable(splits)
if splits.type not in int_vector_types: if splits.type not in int_vector_types:
raise TypeError("splits must have type tensor.lvector", splits.type) raise TypeError("`splits` parameter must be tensors of integer type")
if axis.type not in int_types: if axis.type not in int_types:
raise TypeError("axis must have type lscalar", axis.type) raise TypeError("`axis` parameter must be an integer scalar")
# # The following lines are necessary if we allow splits of zero # # The following lines are necessary if we allow splits of zero
# if isinstance(axis, Constant): # if isinstance(axis, Constant):
...@@ -1869,30 +1869,19 @@ class Split(COp): ...@@ -1869,30 +1869,19 @@ class Split(COp):
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
"""WRITEME"""
x, axis, splits = inputs x, axis, splits = inputs
try:
len_along_axis = x.shape[axis] len_along_axis = x.shape[axis]
except Exception:
raise ValueError(
f"Split.perform() with axis=({axis}) is invalid"
f" for x.shape==({x.shape})"
)
if len(splits) != self.len_splits:
raise ValueError(
"In Split.perform(), len(splits) != len_splits.",
(len(splits), self.len_splits),
)
if len(splits) != self.len_splits:
raise ValueError("Length of `splits` is not equal to `len_splits`")
if np.sum(splits) != len_along_axis: if np.sum(splits) != len_along_axis:
raise ValueError( raise ValueError(
f"The splits sum to {np.sum(splits)}, expected {len_along_axis}" f"The splits sum to {np.sum(splits)}; expected {len_along_axis}"
) )
if builtins.any([nb < 0 for nb in splits]): if builtins.any([nb < 0 for nb in splits]):
raise ValueError( raise ValueError(
"Split: you tried to make an ndarray with a " "Attempted to make an array with a " "negative number of elements"
"negative number of elements."
) )
# Checking is done, let's roll the splitting algorithm! # Checking is done, let's roll the splitting algorithm!
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论