提交 231e51f0 authored 作者: Frederic's avatar Frederic

Make get_scalar_constant_value() recurse less.

上级 e8d45dc8
...@@ -532,7 +532,7 @@ get_scalar_constant_value_elemwises = ( ...@@ -532,7 +532,7 @@ get_scalar_constant_value_elemwises = (
scal.LT, scal.GT, scal.LE, scal.GE, scal.LT, scal.GT, scal.LE, scal.GE,
scal.Sub, scal.Add, scal.Mod, scal.Mul, scal.Sub, scal.Add, scal.Mod, scal.Mul,
scal.IntDiv, scal.TrueDiv) scal.IntDiv, scal.TrueDiv)
def get_scalar_constant_value(v): def get_scalar_constant_value(orig_v):
"""return the constant scalar(0-D) value underlying variable `v` """return the constant scalar(0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, allocs, rebroadcasts, cast If v is the output of dimshuffles, fills, allocs, rebroadcasts, cast
...@@ -544,7 +544,8 @@ def get_scalar_constant_value(v): ...@@ -544,7 +544,8 @@ def get_scalar_constant_value(v):
:note: There may be another function similar to this one in the :note: There may be another function similar to this one in the
code, but I'm not sure where it is. code, but I'm not sure where it is.
""" """
if True: v = orig_v
while True:
if v is None: if v is None:
# None is not a scalar (and many uses of this function seem to depend # None is not a scalar (and many uses of this function seem to depend
# on passing it None) # on passing it None)
...@@ -567,7 +568,8 @@ def get_scalar_constant_value(v): ...@@ -567,7 +568,8 @@ def get_scalar_constant_value(v):
if isinstance(v.owner.op, (Alloc, DimShuffle, Rebroadcast, if isinstance(v.owner.op, (Alloc, DimShuffle, Rebroadcast,
compile.ops.OutputGuard, compile.ops.OutputGuard,
compile.DeepCopyOp)): compile.DeepCopyOp)):
return get_scalar_constant_value(v.owner.inputs[0]) v = v.owner.inputs[0]
continue
elif isinstance(v.owner.op, theano.compile.ops.Shape_i): elif isinstance(v.owner.op, theano.compile.ops.Shape_i):
if isinstance(v.owner.inputs[0], Constant): if isinstance(v.owner.inputs[0], Constant):
return v.owner.inputs[0].data.shape[v.owner.op.i] return v.owner.inputs[0].data.shape[v.owner.op.i]
...@@ -580,7 +582,8 @@ def get_scalar_constant_value(v): ...@@ -580,7 +582,8 @@ def get_scalar_constant_value(v):
if isinstance(v.owner.op, scal.Second): if isinstance(v.owner.op, scal.Second):
# We don't need both input to be constant for second # We don't need both input to be constant for second
shape, val = v.owner.inputs shape, val = v.owner.inputs
return get_scalar_constant_value(val) v = val
continue
if isinstance(v.owner.op, get_scalar_constant_value_elemwises): if isinstance(v.owner.op, get_scalar_constant_value_elemwises):
const = [get_scalar_constant_value(i) const = [get_scalar_constant_value(i)
for i in v.owner.inputs] for i in v.owner.inputs]
...@@ -591,7 +594,8 @@ def get_scalar_constant_value(v): ...@@ -591,7 +594,8 @@ def get_scalar_constant_value(v):
if isinstance(v.owner.op.scalar_op, scal.Second): if isinstance(v.owner.op.scalar_op, scal.Second):
# We don't need both input to be constant for second # We don't need both input to be constant for second
shape, val = v.owner.inputs shape, val = v.owner.inputs
return get_scalar_constant_value(val) v = val
continue
elif isinstance(v.owner.op.scalar_op, elif isinstance(v.owner.op.scalar_op,
get_scalar_constant_value_elemwises): get_scalar_constant_value_elemwises):
const = [get_scalar_constant_value(i) for i in v.owner.inputs] const = [get_scalar_constant_value(i) for i in v.owner.inputs]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论