提交 de92edc0 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove impossible condition in get_scalar_constant_value logic

上级 84af50df
...@@ -441,23 +441,9 @@ def get_scalar_constant_value( ...@@ -441,23 +441,9 @@ def get_scalar_constant_value(
and isinstance(v.owner.inputs[0].owner.op, Join) and isinstance(v.owner.inputs[0].owner.op, Join)
and len(v.owner.op.idx_list) == 1 and len(v.owner.op.idx_list) == 1
): ):
# Ensure the Join is joining only scalar variables (so that # Ensure the Join is joining only (effectively) scalar
# the constant value can be found at the same index as the # variables (so that the constant value can be found at the
# one used in the sub-tensor). # same index as the one used in the sub-tensor).
if builtins.all(
var.ndim == 0 for var in v.owner.inputs[0].owner.inputs[1:]
):
idx = v.owner.op.idx_list[0]
if isinstance(idx, Type):
idx = get_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur
)
# Note the '+ 1' is because the first argument to Join
# is the axis.
ret = v.owner.inputs[0].owner.inputs[idx + 1]
ret = get_scalar_constant_value(ret, max_recur=max_recur)
# join can cast implicitly its input in some case.
return _asarray(ret, dtype=v.type.dtype)
if builtins.all( if builtins.all(
var.ndim == 1 for var in v.owner.inputs[0].owner.inputs[1:] var.ndim == 1 for var in v.owner.inputs[0].owner.inputs[1:]
): ):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论