提交 09dea9c8 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5306 from gvtulder/f-get_scalar_constant_value-for-assert

Try to get_scalar_constant_value for Assert nodes
......@@ -488,6 +488,13 @@ def get_scalar_constant_value(orig_v, elemwise=True,
elif isinstance(v.owner.op, (ScalarFromTensor, TensorFromScalar)):
v = v.owner.inputs[0]
continue
elif isinstance(v.owner.op, theano.tensor.opt.Assert):
# check if all conditions are constant and true
cond = [get_scalar_constant_value(c, max_recur=max_recur)
for c in v.owner.inputs[1:]]
if builtins.all([0 == c.ndim and c != 0 for c in cond]):
v = v.owner.inputs[0]
continue
elif isinstance(v.owner.op, scal.ScalarOp):
if isinstance(v.owner.op, scal.Second):
# We don't need both input to be constant for second
......
......@@ -7116,6 +7116,28 @@ class T_get_scalar_constant_value(unittest.TestCase):
s = tensor.second(c, .4)
assert numpy.allclose(get_scalar_constant_value(s), .4)
def test_assert(self):
# Make sure we still get the constant value if it is wrapped in
# an Assert.
c = theano.tensor.constant(2)
x = theano.tensor.scalar()
# condition is always True
a = opt.Assert()(c, c > 1)
assert get_scalar_constant_value(a) == 2
# condition is always False
a = opt.Assert()(c, c > 2)
self.assertRaises(
tensor.NotScalarConstantError,
get_scalar_constant_value, a)
# condition is not constant
a = opt.Assert()(c, c > x)
self.assertRaises(
tensor.NotScalarConstantError,
get_scalar_constant_value, a)
def test_second(self):
# Second should apply when the value is constant but not the shape
c = theano.tensor.constant(numpy.random.rand())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论