提交 456cce1d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use static shape values in get_vector_length

上级 c2909c93
...@@ -80,8 +80,9 @@ def get_vector_length(v: TensorLike) -> int: ...@@ -80,8 +80,9 @@ def get_vector_length(v: TensorLike) -> int:
if v.type.ndim != 1: if v.type.ndim != 1:
raise TypeError(f"Argument must be a vector; got {v.type}") raise TypeError(f"Argument must be a vector; got {v.type}")
if v.type.broadcastable[0]: static_shape: Optional[int] = v.type.shape[0]
return 1 if static_shape is not None:
return static_shape
return _get_vector_length(getattr(v.owner, "op", v), v) return _get_vector_length(getattr(v.owner, "op", v), v)
......
...@@ -1177,6 +1177,8 @@ def test_get_vector_length(): ...@@ -1177,6 +1177,8 @@ def test_get_vector_length():
# Test `Alloc`s # Test `Alloc`s
assert 3 == get_vector_length(alloc(0, 3)) assert 3 == get_vector_length(alloc(0, 3))
assert 5 == get_vector_length(tensor(np.float64, shape=(5,)))
class TestJoinAndSplit: class TestJoinAndSplit:
# Split is tested by each verify_grad method. # Split is tested by each verify_grad method.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论