提交 1f7a2686 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add error message in Numba implementation of SpecifyShape

上级 f15258d9
...@@ -36,11 +36,11 @@ def numba_funcify_SpecifyShape(op, node, **kwargs): ...@@ -36,11 +36,11 @@ def numba_funcify_SpecifyShape(op, node, **kwargs):
shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))] shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]
func_conditions = [ func_conditions = [
f"assert x.shape[{i}] == {shape_input_names}" f"assert x.shape[{i}] == {eval_dim_name}, f'SpecifyShape: dim {{{i}}} of input has shape {{x.shape[{i}]}}, expected {{{eval_dim_name}.item()}}.'"
for i, (shape_input, shape_input_names) in enumerate( for i, (node_dim_input, eval_dim_name) in enumerate(
zip(shape_inputs, shape_input_names, strict=True) zip(shape_inputs, shape_input_names, strict=True)
) )
if shape_input is not NoneConst if node_dim_input is not NoneConst
] ]
func = dedent( func = dedent(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论