提交 494a4c6f authored 作者: Frederic's avatar Frederic

get_scalar_constant_value() support more Elewmise op.

Needed to detect the right broadcast pattern.
上级 95938fe8
...@@ -559,26 +559,31 @@ def get_scalar_constant_value(v): ...@@ -559,26 +559,31 @@ def get_scalar_constant_value(v):
compile.ops.OutputGuard, compile.ops.OutputGuard,
compile.DeepCopyOp)): compile.DeepCopyOp)):
return get_scalar_constant_value(v.owner.inputs[0]) return get_scalar_constant_value(v.owner.inputs[0])
if isinstance(v.owner.op, Elemwise) and \
isinstance(v.owner.op.scalar_op, scal.Second):
shape, val = v.owner.inputs
return get_scalar_constant_value(val)
if isinstance(v.owner.op, scal.Second):
x, y = v.owner.inputs
return get_scalar_constant_value(y)
if (isinstance(v.owner.op, theano.compile.ops.Shape_i) and if (isinstance(v.owner.op, theano.compile.ops.Shape_i) and
isinstance(v.owner.inputs[0], Constant)): isinstance(v.owner.inputs[0], Constant)):
return v.owner.inputs[0].data.shape[v.owner.op.i] return v.owner.inputs[0].data.shape[v.owner.op.i]
# Don't act as the constant_folding optimization here as this # Don't act as the constant_folding optimization here as this
# fct is used too early in the optimization phase. This would # fct is used too early in the optimization phase. This would
# mess with the stabilization optimization. # mess with the stabilization optimization and be too slow.
if (isinstance(v.owner.op, Elemwise) and isinstance( # We put all the scalar Ops used by get_canonical_form_slice()
v.owner.op.scalar_op, scal.Cast)) or \ # to allow it to determine the broadcast pattern correctly.
isinstance(v.owner.op, scal.Cast): elemwises = (scal.Cast, scal.Switch,
const = get_scalar_constant_value(v.owner.inputs[0]) scal.NEQ, scal.EQ,
ret = [[None]] scal.LT, scal.GT, scal.LE, scal.GE,
v.owner.op.perform(v.owner, [const], ret) scal.Sub, scal.Add, scal.Mod, scal.Mul,
return ret[0][0] scal.IntDiv, scal.TrueDiv,
scal.Second)
if (isinstance(v.owner.op, Elemwise) and
len(v.owner.outputs) == 1 and
(isinstance(v.owner.op.scalar_op, elemwises) or
isinstance(v.owner.op, elemwises))):
try:
const = [get_scalar_constant_value(i) for i in v.owner.inputs]
ret = [[None]]
v.owner.op.perform(v.owner, const, ret)
return ret[0][0]
except NotScalarConstantError:
pass
if isinstance(v.owner.op, theano.tensor.subtensor.Subtensor) and v.ndim == 0: if isinstance(v.owner.op, theano.tensor.subtensor.Subtensor) and v.ndim == 0:
# This condition depends on Subtensor always embedding constant # This condition depends on Subtensor always embedding constant
# indices in the Op rather than making them inputs to the Apply # indices in the Op rather than making them inputs to the Apply
......
...@@ -5928,6 +5928,21 @@ class T_get_scalar_constant_value(unittest.TestCase): ...@@ -5928,6 +5928,21 @@ class T_get_scalar_constant_value(unittest.TestCase):
s = opt.Shape_i(1)(c) s = opt.Shape_i(1)(c)
assert get_scalar_constant_value(s) == 4 assert get_scalar_constant_value(s) == 4
def test_elemwise(self):
# We test only for a few elemwise, the list of all supported
# elemwise are in the fct.
c = theano.tensor.constant(numpy.random.rand())
s = c + 1
assert get_scalar_constant_value(s) == c.data + 1
s = c - 1
assert get_scalar_constant_value(s) == c.data - 1
s = c * 1.2
assert get_scalar_constant_value(s) == c.data * 1.2
s = c < 0.5
assert get_scalar_constant_value(s) == int(c.data < 0.5)
s = tensor.second(c, .4)
assert get_scalar_constant_value(s) == .4
class T_as_tensor_variable(unittest.TestCase): class T_as_tensor_variable(unittest.TestCase):
""" """
......
...@@ -2408,7 +2408,7 @@ def test_local_subtensor_of_alloc(): ...@@ -2408,7 +2408,7 @@ def test_local_subtensor_of_alloc():
for slices in slicess: for slices in slicess:
z = yx.__getitem__(slices) z = yx.__getitem__(slices)
f = theano.function([x], z) f = theano.function([x], z)
theano.printing.debugprint(f) # theano.printing.debugprint(f)
# if theano.config.mode != 'FAST_COMPILE': # if theano.config.mode != 'FAST_COMPILE':
# assert not any([isinstance(node.op, Subtensor) # assert not any([isinstance(node.op, Subtensor)
# for node in f.maker.fgraph.toposort()]) # for node in f.maker.fgraph.toposort()])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论