提交 f6107d0e authored 作者: Frederic Bastien's avatar Frederic Bastien

have get_constant_value handle the case of Subtensor(Shape) when their is some…

have get_constant_value handle the case of Subtensor(Shape) when their is some dimensions broadcastable. This should fix the bug in the buildbot about the grad of the neibhbors op.
上级 a722d38d
...@@ -404,6 +404,15 @@ def get_constant_value(v): ...@@ -404,6 +404,15 @@ def get_constant_value(v):
ret = get_constant_value(ret) ret = get_constant_value(ret)
#MakeVector can cast implicitly its input in some case. #MakeVector can cast implicitly its input in some case.
return theano._asarray(ret, dtype=v.type.dtype) return theano._asarray(ret, dtype=v.type.dtype)
# This is needed when we take the grad as the Shape op
# are not already changed into MakeVector
if (v.owner.inputs[0].owner and
isinstance(v.owner.inputs[0].owner.op,
theano.tensor.Shape)):
if v.owner.inputs[0].owner.inputs[0].type.broadcastable[v.owner.op.idx_list[0]]:
return numpy.asarray(1)
raise TypeError(v) raise TypeError(v)
......
...@@ -3528,6 +3528,10 @@ class T_get_constant_value(unittest.TestCase): ...@@ -3528,6 +3528,10 @@ class T_get_constant_value(unittest.TestCase):
self.assertRaises(TypeError, get_constant_value, a[1]) self.assertRaises(TypeError, get_constant_value, a[1])
self.assertRaises(TypeError, get_constant_value, a[2]) self.assertRaises(TypeError, get_constant_value, a[2])
# Test the case SubTensor(Shape(v)) when the dimensions
# is broadcastable.
v = tensor.row()
assert get_constant_value(v.shape[0])==1
if __name__ == '__main__': if __name__ == '__main__':
if 1: if 1:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论