提交 402dcf88 authored 作者: Frederic Bastien's avatar Frederic Bastien

Convert some function recursion by loop recursion.

上级 e7fe1324
...@@ -651,8 +651,8 @@ def get_scalar_constant_value(orig_v, elemwise=True, ...@@ -651,8 +651,8 @@ def get_scalar_constant_value(orig_v, elemwise=True,
# We put all the scalar Ops used by get_canonical_form_slice() # We put all the scalar Ops used by get_canonical_form_slice()
# to allow it to determine the broadcast pattern correctly. # to allow it to determine the broadcast pattern correctly.
elif isinstance(v.owner.op, (ScalarFromTensor, TensorFromScalar)): elif isinstance(v.owner.op, (ScalarFromTensor, TensorFromScalar)):
return get_scalar_constant_value(v.owner.inputs[0], v = v.owner.inputs[0]
max_recur=max_recur) continue
elif isinstance(v.owner.op, scal.ScalarOp): elif isinstance(v.owner.op, scal.ScalarOp):
if isinstance(v.owner.op, scal.Second): if isinstance(v.owner.op, scal.Second):
# We don't need both input to be constant for second # We don't need both input to be constant for second
...@@ -732,9 +732,8 @@ def get_scalar_constant_value(orig_v, elemwise=True, ...@@ -732,9 +732,8 @@ def get_scalar_constant_value(orig_v, elemwise=True,
for joined in v.owner.inputs[0].owner.inputs[1:]: for joined in v.owner.inputs[0].owner.inputs[1:]:
ll = get_vector_length(joined) ll = get_vector_length(joined)
if idx < length + ll: if idx < length + ll:
return get_scalar_constant_value( v = joined[idx - length]
joined[idx - length], continue
max_recur=max_recur)
length += ll length += ll
except TypeError: except TypeError:
pass pass
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论