提交 494a4c6f authored 作者: Frederic's avatar Frederic

get_scalar_constant_value() support more Elewmise op.

Needed to detect the right broadcast pattern.
上级 95938fe8
......@@ -559,26 +559,31 @@ def get_scalar_constant_value(v):
compile.ops.OutputGuard,
compile.DeepCopyOp)):
return get_scalar_constant_value(v.owner.inputs[0])
if isinstance(v.owner.op, Elemwise) and \
isinstance(v.owner.op.scalar_op, scal.Second):
shape, val = v.owner.inputs
return get_scalar_constant_value(val)
if isinstance(v.owner.op, scal.Second):
x, y = v.owner.inputs
return get_scalar_constant_value(y)
if (isinstance(v.owner.op, theano.compile.ops.Shape_i) and
isinstance(v.owner.inputs[0], Constant)):
return v.owner.inputs[0].data.shape[v.owner.op.i]
# Don't act as the constant_folding optimization here as this
# fct is used too early in the optimization phase. This would
# mess with the stabilization optimization.
if (isinstance(v.owner.op, Elemwise) and isinstance(
v.owner.op.scalar_op, scal.Cast)) or \
isinstance(v.owner.op, scal.Cast):
const = get_scalar_constant_value(v.owner.inputs[0])
ret = [[None]]
v.owner.op.perform(v.owner, [const], ret)
return ret[0][0]
# mess with the stabilization optimization and be too slow.
# We put all the scalar Ops used by get_canonical_form_slice()
# to allow it to determine the broadcast pattern correctly.
elemwises = (scal.Cast, scal.Switch,
scal.NEQ, scal.EQ,
scal.LT, scal.GT, scal.LE, scal.GE,
scal.Sub, scal.Add, scal.Mod, scal.Mul,
scal.IntDiv, scal.TrueDiv,
scal.Second)
if (isinstance(v.owner.op, Elemwise) and
len(v.owner.outputs) == 1 and
(isinstance(v.owner.op.scalar_op, elemwises) or
isinstance(v.owner.op, elemwises))):
try:
const = [get_scalar_constant_value(i) for i in v.owner.inputs]
ret = [[None]]
v.owner.op.perform(v.owner, const, ret)
return ret[0][0]
except NotScalarConstantError:
pass
if isinstance(v.owner.op, theano.tensor.subtensor.Subtensor) and v.ndim == 0:
# This condition depends on Subtensor always embedding constant
# indices in the Op rather than making them inputs to the Apply
......
......@@ -5928,6 +5928,21 @@ class T_get_scalar_constant_value(unittest.TestCase):
s = opt.Shape_i(1)(c)
assert get_scalar_constant_value(s) == 4
def test_elemwise(self):
# We test only for a few elemwise, the list of all supported
# elemwise are in the fct.
c = theano.tensor.constant(numpy.random.rand())
s = c + 1
assert get_scalar_constant_value(s) == c.data + 1
s = c - 1
assert get_scalar_constant_value(s) == c.data - 1
s = c * 1.2
assert get_scalar_constant_value(s) == c.data * 1.2
s = c < 0.5
assert get_scalar_constant_value(s) == int(c.data < 0.5)
s = tensor.second(c, .4)
assert get_scalar_constant_value(s) == .4
class T_as_tensor_variable(unittest.TestCase):
"""
......
......@@ -2408,7 +2408,7 @@ def test_local_subtensor_of_alloc():
for slices in slicess:
z = yx.__getitem__(slices)
f = theano.function([x], z)
theano.printing.debugprint(f)
# theano.printing.debugprint(f)
# if theano.config.mode != 'FAST_COMPILE':
# assert not any([isinstance(node.op, Subtensor)
# for node in f.maker.fgraph.toposort()])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论