提交 2ca0426d authored 作者: Gijs van Tulder's avatar Gijs van Tulder

Try to get_scalar_constant_value for Assert nodes.

上级 89849eb7
...@@ -488,6 +488,13 @@ def get_scalar_constant_value(orig_v, elemwise=True, ...@@ -488,6 +488,13 @@ def get_scalar_constant_value(orig_v, elemwise=True,
elif isinstance(v.owner.op, (ScalarFromTensor, TensorFromScalar)): elif isinstance(v.owner.op, (ScalarFromTensor, TensorFromScalar)):
v = v.owner.inputs[0] v = v.owner.inputs[0]
continue continue
elif isinstance(v.owner.op, theano.tensor.opt.Assert):
# check if all conditions are constant and true
cond = [get_scalar_constant_value(c, max_recur=max_recur)
for c in v.owner.inputs[1:]]
if builtins.all([0 == c.ndim and c != 0 for c in cond]):
v = v.owner.inputs[0]
continue
elif isinstance(v.owner.op, scal.ScalarOp): elif isinstance(v.owner.op, scal.ScalarOp):
if isinstance(v.owner.op, scal.Second): if isinstance(v.owner.op, scal.Second):
# We don't need both input to be constant for second # We don't need both input to be constant for second
......
...@@ -7101,6 +7101,28 @@ class T_get_scalar_constant_value(unittest.TestCase): ...@@ -7101,6 +7101,28 @@ class T_get_scalar_constant_value(unittest.TestCase):
s = tensor.second(c, .4) s = tensor.second(c, .4)
assert numpy.allclose(get_scalar_constant_value(s), .4) assert numpy.allclose(get_scalar_constant_value(s), .4)
def test_assert(self):
# Make sure we still get the constant value if it is wrapped in
# an Assert.
c = theano.tensor.constant(2)
x = theano.tensor.scalar()
# condition is always True
a = opt.Assert()(c, c > 1)
assert get_scalar_constant_value(a) == 2
# condition is always False
a = opt.Assert()(c, c > 2)
self.assertRaises(
tensor.NotScalarConstantError,
get_scalar_constant_value, a)
# condition is not constant
a = opt.Assert()(c, c > x)
self.assertRaises(
tensor.NotScalarConstantError,
get_scalar_constant_value, a)
def test_second(self): def test_second(self):
# Second should apply when the value is constant but not the shape # Second should apply when the value is constant but not the shape
c = theano.tensor.constant(numpy.random.rand()) c = theano.tensor.constant(numpy.random.rand())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论