提交 072d5d33 authored 作者: Frederic Bastien's avatar Frederic Bastien

Allow to give name and dtype when doing scalar constant.

上级 dd192e13
......@@ -277,10 +277,10 @@ def convert(x, dtype=None):
return x_
def constant(x):
x = convert(x)
def constant(x, name=None, dtype=None):
x = convert(x, dtype=dtype)
assert x.ndim == 0
return ScalarConstant(get_scalar_type(str(x.dtype)), x)
return ScalarConstant(get_scalar_type(str(x.dtype)), x, name=name)
class Scalar(Type):
......
......@@ -488,5 +488,14 @@ def test_grad_abs():
# in test_fusion, TestCompositeCodegen
def test_constant():
c = constant(2, name='a')
assert c.name == 'a'
assert c.dtype == 'int8'
c = constant(2, dtype='float32')
assert c.name is None
assert c.dtype == 'float32'
if __name__ == '__main__':
unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论