提交 58741fb3 authored 作者: Frederic's avatar Frederic

re-enable get_scalar_constant_value() on Second when only the value is constant.

上级 1e480107
...@@ -567,6 +567,12 @@ def get_scalar_constant_value(v): ...@@ -567,6 +567,12 @@ def get_scalar_constant_value(v):
# mess with the stabilization optimization and be too slow. # mess with the stabilization optimization and be too slow.
# We put all the scalar Ops used by get_canonical_form_slice() # We put all the scalar Ops used by get_canonical_form_slice()
# to allow it to determine the broadcast pattern correctly. # to allow it to determine the broadcast pattern correctly.
if ((isinstance(v.owner.op, Elemwise) and
isinstance(v.owner.op.scalar_op, scal.Second)) or
isinstance(v.owner.op, scal.Second)):
# We don't need both input to be constant for second
shape, val = v.owner.inputs
return get_scalar_constant_value(val)
elemwises = (scal.Cast, scal.Switch, elemwises = (scal.Cast, scal.Switch,
scal.NEQ, scal.EQ, scal.NEQ, scal.EQ,
scal.LT, scal.GT, scal.LE, scal.GE, scal.LT, scal.GT, scal.LE, scal.GE,
......
...@@ -5943,6 +5943,13 @@ class T_get_scalar_constant_value(unittest.TestCase): ...@@ -5943,6 +5943,13 @@ class T_get_scalar_constant_value(unittest.TestCase):
s = tensor.second(c, .4) s = tensor.second(c, .4)
assert get_scalar_constant_value(s) == .4 assert get_scalar_constant_value(s) == .4
def test_second(self):
#Second should apply when the value is constant but not the shape
c = theano.tensor.constant(numpy.random.rand())
shp = theano.tensor.vector()
s = theano.tensor.second(shp, c)
assert get_scalar_constant_value(s) == c.data
class T_as_tensor_variable(unittest.TestCase): class T_as_tensor_variable(unittest.TestCase):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论