提交 634dd67e authored 作者: Frederic's avatar Frederic

Small refactoring to speed up get_scalar_constant_value.

This make less condition being tested. It speed up the slow scan tests.
上级 21181c87
......@@ -508,6 +508,12 @@ class EmptyConstantError(NotScalarConstantError):
"""
get_scalar_constant_value_elemwises = (
scal.Cast, scal.Switch,
scal.NEQ, scal.EQ,
scal.LT, scal.GT, scal.LE, scal.GE,
scal.Sub, scal.Add, scal.Mod, scal.Mul,
scal.IntDiv, scal.TrueDiv)
def get_scalar_constant_value(v):
"""return the constant scalar(0-D) value underlying variable `v`
......@@ -562,7 +568,7 @@ def get_scalar_constant_value(v):
compile.ops.OutputGuard,
compile.DeepCopyOp)):
return get_scalar_constant_value(v.owner.inputs[0])
if (isinstance(v.owner.op, theano.compile.ops.Shape_i) and
elif (isinstance(v.owner.op, theano.compile.ops.Shape_i) and
isinstance(v.owner.inputs[0], Constant)):
return v.owner.inputs[0].data.shape[v.owner.op.i]
# Don't act as the constant_folding optimization here as this
......@@ -570,26 +576,29 @@ def get_scalar_constant_value(v):
# mess with the stabilization optimization and be too slow.
# We put all the scalar Ops used by get_canonical_form_slice()
# to allow it to determine the broadcast pattern correctly.
if ((isinstance(v.owner.op, Elemwise) and
isinstance(v.owner.op.scalar_op, scal.Second)) or
isinstance(v.owner.op, scal.Second)):
elif isinstance(v.owner.op, scal.ScalarOp):
if isinstance(v.owner.op, scal.Second):
# We don't need both input to be constant for second
shape, val = v.owner.inputs
return get_scalar_constant_value(val)
elemwises = (scal.Cast, scal.Switch,
scal.NEQ, scal.EQ,
scal.LT, scal.GT, scal.LE, scal.GE,
scal.Sub, scal.Add, scal.Mod, scal.Mul,
scal.IntDiv, scal.TrueDiv)
if (isinstance(v.owner.op, Elemwise) and
len(v.owner.outputs) == 1 and
(isinstance(v.owner.op.scalar_op, elemwises) or
isinstance(v.owner.op, elemwises))):
if isinstance(v.owner.op, get_scalar_constant_value_elemwises):
const = [get_scalar_constant_value(i)
for i in v.owner.inputs]
ret = [[None]]
v.owner.op.perform(v.owner, const, ret)
return ret[0][0]
elif isinstance(v.owner.op, Elemwise):
if isinstance(v.owner.op.scalar_op, scal.Second):
# We don't need both input to be constant for second
shape, val = v.owner.inputs
return get_scalar_constant_value(val)
elif isinstance(v.owner.op.scalar_op,
get_scalar_constant_value_elemwises):
const = [get_scalar_constant_value(i) for i in v.owner.inputs]
ret = [[None]]
v.owner.op.perform(v.owner, const, ret)
return ret[0][0]
if isinstance(v.owner.op, theano.tensor.subtensor.Subtensor) and v.ndim == 0:
elif isinstance(v.owner.op, theano.tensor.subtensor.Subtensor) and v.ndim == 0:
if isinstance(v.owner.inputs[0], TensorConstant):
cdata = tuple(v.owner.op.get_constant_idx(v.owner.inputs))
try:
......@@ -626,7 +635,7 @@ def get_scalar_constant_value(v):
# join can cast implicitly its input in some case.
return theano._asarray(ret, dtype=v.type.dtype)
if (v.owner.inputs[0].owner and
elif (v.owner.inputs[0].owner and
isinstance(v.owner.inputs[0].owner.op,
theano.tensor.opt.MakeVector) and
# MakeVector normally accept only scalar as input.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论