提交 456cce1d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use static shape values in get_vector_length

上级 c2909c93
......@@ -80,8 +80,9 @@ def get_vector_length(v: TensorLike) -> int:
if v.type.ndim != 1:
raise TypeError(f"Argument must be a vector; got {v.type}")
if v.type.broadcastable[0]:
return 1
static_shape: Optional[int] = v.type.shape[0]
if static_shape is not None:
return static_shape
return _get_vector_length(getattr(v.owner, "op", v), v)
......
......@@ -1177,6 +1177,8 @@ def test_get_vector_length():
# Test `Alloc`s
assert 3 == get_vector_length(alloc(0, 3))
assert 5 == get_vector_length(tensor(np.float64, shape=(5,)))
class TestJoinAndSplit:
# Split is tested by each verify_grad method.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论