Unverified 提交 71962518 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

Implement _get_vector_length for Alloc (#817)

上级 630b5574
......@@ -1669,6 +1669,14 @@ alloc = Alloc()
pprint.assign(alloc, printing.FunctionPrinter(["alloc"]))
@_get_vector_length.register(Alloc)
def _get_vector_length_Alloc(var_inst, var):
try:
return get_scalar_constant_value(var.owner.inputs[1])
except NotScalarConstantError:
raise ValueError(f"Length of {var} cannot be determined")
def full(shape, fill_value, dtype=None):
"""Return a new array of given shape and type, filled with `fill_value`.
......
......@@ -1172,6 +1172,9 @@ def test_get_vector_length():
assert np.allclose(f(4, 5), [5, 9, 4])
# Test `Alloc`s
assert 3 == get_vector_length(alloc(0, 3))
class TestJoinAndSplit:
# Split is tested by each verify_grad method.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论