提交 231e51f0 authored 作者: Frederic's avatar Frederic

Make get_scalar_constant_value() recurse less.

上级 e8d45dc8
......@@ -532,7 +532,7 @@ get_scalar_constant_value_elemwises = (
scal.LT, scal.GT, scal.LE, scal.GE,
scal.Sub, scal.Add, scal.Mod, scal.Mul,
scal.IntDiv, scal.TrueDiv)
def get_scalar_constant_value(v):
def get_scalar_constant_value(orig_v):
"""return the constant scalar(0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, allocs, rebroadcasts, cast
......@@ -544,7 +544,8 @@ def get_scalar_constant_value(v):
:note: There may be another function similar to this one in the
code, but I'm not sure where it is.
"""
if True:
v = orig_v
while True:
if v is None:
# None is not a scalar (and many uses of this function seem to depend
# on passing it None)
......@@ -567,7 +568,8 @@ def get_scalar_constant_value(v):
if isinstance(v.owner.op, (Alloc, DimShuffle, Rebroadcast,
compile.ops.OutputGuard,
compile.DeepCopyOp)):
return get_scalar_constant_value(v.owner.inputs[0])
v = v.owner.inputs[0]
continue
elif isinstance(v.owner.op, theano.compile.ops.Shape_i):
if isinstance(v.owner.inputs[0], Constant):
return v.owner.inputs[0].data.shape[v.owner.op.i]
......@@ -580,7 +582,8 @@ def get_scalar_constant_value(v):
if isinstance(v.owner.op, scal.Second):
# We don't need both input to be constant for second
shape, val = v.owner.inputs
return get_scalar_constant_value(val)
v = val
continue
if isinstance(v.owner.op, get_scalar_constant_value_elemwises):
const = [get_scalar_constant_value(i)
for i in v.owner.inputs]
......@@ -591,7 +594,8 @@ def get_scalar_constant_value(v):
if isinstance(v.owner.op.scalar_op, scal.Second):
# We don't need both input to be constant for second
shape, val = v.owner.inputs
return get_scalar_constant_value(val)
v = val
continue
elif isinstance(v.owner.op.scalar_op,
get_scalar_constant_value_elemwises):
const = [get_scalar_constant_value(i) for i in v.owner.inputs]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论