提交 de92edc0 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove impossible condition in get_scalar_constant_value logic

上级 84af50df
......@@ -441,23 +441,9 @@ def get_scalar_constant_value(
and isinstance(v.owner.inputs[0].owner.op, Join)
and len(v.owner.op.idx_list) == 1
):
# Ensure the Join is joining only scalar variables (so that
# the constant value can be found at the same index as the
# one used in the sub-tensor).
if builtins.all(
var.ndim == 0 for var in v.owner.inputs[0].owner.inputs[1:]
):
idx = v.owner.op.idx_list[0]
if isinstance(idx, Type):
idx = get_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur
)
# Note the '+ 1' is because the first argument to Join
# is the axis.
ret = v.owner.inputs[0].owner.inputs[idx + 1]
ret = get_scalar_constant_value(ret, max_recur=max_recur)
# join can cast implicitly its input in some case.
return _asarray(ret, dtype=v.type.dtype)
# Ensure the Join is joining only (effectively) scalar
# variables (so that the constant value can be found at the
# same index as the one used in the sub-tensor).
if builtins.all(
var.ndim == 1 for var in v.owner.inputs[0].owner.inputs[1:]
):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论