提交 3b31bfa8 authored 作者: nouiz's avatar nouiz

Merge pull request #1246 from lamblin/fix_check_incsubtensor

Fix check that was too restrictive
......@@ -5052,10 +5052,11 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
dim_offset = x.ndim - y.ndim
if (x.broadcastable[dim + dim_offset]
and not y.broadcastable[dim]):
raise TypeError(("Trying to increment a subtensor with "
"broadcastable dimension %d, with a tensor not broadcastable "
"on corresponding dimension %d.") % (dim + dim_offset, dim),
x.broadcastable, y.broadcastable)
# It is acceptable to try to increment a subtensor with a
# a broadcastable dim with a tensor that is not broadcastable
# on that dimension. However, its length must then be 1.
# We insert a Rebroadcast Op to make sure it is the case.
y = addbroadcast(y, dim)
if not x.owner:
raise TypeError('x must be result of a subtensor operation')
......
......@@ -68,10 +68,23 @@ class Test_inc_subtensor(unittest.TestCase):
a = tt.col()
increment = tt.vector()
self.assertRaises(TypeError, tt.set_subtensor, a[:], increment)
self.assertRaises(TypeError, tt.set_subtensor, a[0], increment)
self.assertRaises(TypeError, tt.inc_subtensor, a[:], increment)
self.assertRaises(TypeError, tt.inc_subtensor, a[0], increment)
# These symbolic graphs legitimate, as long as increment has exactly
# one element. So it should fail at runtime, not at compile time.
rng = numpy.random.RandomState(utt.fetch_seed())
for op in (tt.set_subtensor, tt.inc_subtensor):
for base in (a[:], a[0]):
out = op(base, increment)
f = theano.function([a, increment], out)
# This one should work
f(rng.rand(3, 1), rng.rand(1))
# These ones should not
self.assertRaises(ValueError,
f, rng.rand(3, 1), rng.rand(2))
self.assertRaises(ValueError,
f, rng.rand(3, 1), rng.rand(3))
self.assertRaises(ValueError,
f, rng.rand(3, 1), rng.rand(0))
def test_simple_3d(self):
"""Increments or sets part of a tensor by a scalar using full slice and
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论