提交 906bbb75 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix check that was too restrictive

上级 39cf0c45
...@@ -5052,10 +5052,11 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False, ...@@ -5052,10 +5052,11 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
dim_offset = x.ndim - y.ndim dim_offset = x.ndim - y.ndim
if (x.broadcastable[dim + dim_offset] if (x.broadcastable[dim + dim_offset]
and not y.broadcastable[dim]): and not y.broadcastable[dim]):
raise TypeError(("Trying to increment a subtensor with " # It is acceptable to try to increment a subtensor with a
"broadcastable dimension %d, with a tensor not broadcastable " # a broadcastable dim with a tensor that is not broadcastable
"on corresponding dimension %d.") % (dim + dim_offset, dim), # on that dimension. However, its length must then be 1.
x.broadcastable, y.broadcastable) # We insert a Rebroadcast Op to make sure it is the case.
y = addbroadcast(y, dim)
if not x.owner: if not x.owner:
raise TypeError('x must be result of a subtensor operation') raise TypeError('x must be result of a subtensor operation')
......
...@@ -68,10 +68,23 @@ class Test_inc_subtensor(unittest.TestCase): ...@@ -68,10 +68,23 @@ class Test_inc_subtensor(unittest.TestCase):
a = tt.col() a = tt.col()
increment = tt.vector() increment = tt.vector()
self.assertRaises(TypeError, tt.set_subtensor, a[:], increment) # These symbolic graphs legitimate, as long as increment has exactly
self.assertRaises(TypeError, tt.set_subtensor, a[0], increment) # one element. So it should fail at runtime, not at compile time.
self.assertRaises(TypeError, tt.inc_subtensor, a[:], increment) rng = numpy.random.RandomState(utt.fetch_seed())
self.assertRaises(TypeError, tt.inc_subtensor, a[0], increment)
for op in (tt.set_subtensor, tt.inc_subtensor):
for base in (a[:], a[0]):
out = op(base, increment)
f = theano.function([a, increment], out)
# This one should work
f(rng.rand(3, 1), rng.rand(1))
# These ones should not
self.assertRaises(ValueError,
f, rng.rand(3, 1), rng.rand(2))
self.assertRaises(ValueError,
f, rng.rand(3, 1), rng.rand(3))
self.assertRaises(ValueError,
f, rng.rand(3, 1), rng.rand(0))
def test_simple_3d(self): def test_simple_3d(self):
"""Increments or sets part of a tensor by a scalar using full slice and """Increments or sets part of a tensor by a scalar using full slice and
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论