提交 634dd67e authored 作者: Frederic's avatar Frederic

Small refactoring to speed up get_scalar_constant_value.

This make less condition being tested. It speed up the slow scan tests.
上级 21181c87
...@@ -508,6 +508,12 @@ class EmptyConstantError(NotScalarConstantError): ...@@ -508,6 +508,12 @@ class EmptyConstantError(NotScalarConstantError):
""" """
get_scalar_constant_value_elemwises = (
scal.Cast, scal.Switch,
scal.NEQ, scal.EQ,
scal.LT, scal.GT, scal.LE, scal.GE,
scal.Sub, scal.Add, scal.Mod, scal.Mul,
scal.IntDiv, scal.TrueDiv)
def get_scalar_constant_value(v): def get_scalar_constant_value(v):
"""return the constant scalar(0-D) value underlying variable `v` """return the constant scalar(0-D) value underlying variable `v`
...@@ -562,7 +568,7 @@ def get_scalar_constant_value(v): ...@@ -562,7 +568,7 @@ def get_scalar_constant_value(v):
compile.ops.OutputGuard, compile.ops.OutputGuard,
compile.DeepCopyOp)): compile.DeepCopyOp)):
return get_scalar_constant_value(v.owner.inputs[0]) return get_scalar_constant_value(v.owner.inputs[0])
if (isinstance(v.owner.op, theano.compile.ops.Shape_i) and elif (isinstance(v.owner.op, theano.compile.ops.Shape_i) and
isinstance(v.owner.inputs[0], Constant)): isinstance(v.owner.inputs[0], Constant)):
return v.owner.inputs[0].data.shape[v.owner.op.i] return v.owner.inputs[0].data.shape[v.owner.op.i]
# Don't act as the constant_folding optimization here as this # Don't act as the constant_folding optimization here as this
...@@ -570,26 +576,29 @@ def get_scalar_constant_value(v): ...@@ -570,26 +576,29 @@ def get_scalar_constant_value(v):
# mess with the stabilization optimization and be too slow. # mess with the stabilization optimization and be too slow.
# We put all the scalar Ops used by get_canonical_form_slice() # We put all the scalar Ops used by get_canonical_form_slice()
# to allow it to determine the broadcast pattern correctly. # to allow it to determine the broadcast pattern correctly.
if ((isinstance(v.owner.op, Elemwise) and elif isinstance(v.owner.op, scal.ScalarOp):
isinstance(v.owner.op.scalar_op, scal.Second)) or if isinstance(v.owner.op, scal.Second):
isinstance(v.owner.op, scal.Second)): # We don't need both input to be constant for second
# We don't need both input to be constant for second shape, val = v.owner.inputs
shape, val = v.owner.inputs return get_scalar_constant_value(val)
return get_scalar_constant_value(val) if isinstance(v.owner.op, get_scalar_constant_value_elemwises):
elemwises = (scal.Cast, scal.Switch, const = [get_scalar_constant_value(i)
scal.NEQ, scal.EQ, for i in v.owner.inputs]
scal.LT, scal.GT, scal.LE, scal.GE, ret = [[None]]
scal.Sub, scal.Add, scal.Mod, scal.Mul, v.owner.op.perform(v.owner, const, ret)
scal.IntDiv, scal.TrueDiv) return ret[0][0]
if (isinstance(v.owner.op, Elemwise) and elif isinstance(v.owner.op, Elemwise):
len(v.owner.outputs) == 1 and if isinstance(v.owner.op.scalar_op, scal.Second):
(isinstance(v.owner.op.scalar_op, elemwises) or # We don't need both input to be constant for second
isinstance(v.owner.op, elemwises))): shape, val = v.owner.inputs
const = [get_scalar_constant_value(i) for i in v.owner.inputs] return get_scalar_constant_value(val)
ret = [[None]] elif isinstance(v.owner.op.scalar_op,
v.owner.op.perform(v.owner, const, ret) get_scalar_constant_value_elemwises):
return ret[0][0] const = [get_scalar_constant_value(i) for i in v.owner.inputs]
if isinstance(v.owner.op, theano.tensor.subtensor.Subtensor) and v.ndim == 0: ret = [[None]]
v.owner.op.perform(v.owner, const, ret)
return ret[0][0]
elif isinstance(v.owner.op, theano.tensor.subtensor.Subtensor) and v.ndim == 0:
if isinstance(v.owner.inputs[0], TensorConstant): if isinstance(v.owner.inputs[0], TensorConstant):
cdata = tuple(v.owner.op.get_constant_idx(v.owner.inputs)) cdata = tuple(v.owner.op.get_constant_idx(v.owner.inputs))
try: try:
...@@ -626,7 +635,7 @@ def get_scalar_constant_value(v): ...@@ -626,7 +635,7 @@ def get_scalar_constant_value(v):
# join can cast implicitly its input in some case. # join can cast implicitly its input in some case.
return theano._asarray(ret, dtype=v.type.dtype) return theano._asarray(ret, dtype=v.type.dtype)
if (v.owner.inputs[0].owner and elif (v.owner.inputs[0].owner and
isinstance(v.owner.inputs[0].owner.op, isinstance(v.owner.inputs[0].owner.op,
theano.tensor.opt.MakeVector) and theano.tensor.opt.MakeVector) and
# MakeVector normally accept only scalar as input. # MakeVector normally accept only scalar as input.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论