提交 54358255 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Checks the dimensions and bcast in inc_subtensor

上级 5a421ca3
...@@ -4727,6 +4727,26 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False, ...@@ -4727,6 +4727,26 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
>>> new_r = inc_subtensor(r[10:], 5) >>> new_r = inc_subtensor(r[10:], 5)
""" """
# First of all, y cannot have a higher dimension than x,
# nor have non-broadcastable dimensions where x is broadcastable.
x = as_tensor_variable(x)
y = as_tensor_variable(y)
if y.ndim > x.ndim:
raise ValueError(("Trying to increment a %d-dimensional "
"subtensor with a %d-dimensional value.") % (x.ndim, y.ndim))
for dim in range(y.ndim):
dim_offset = x.ndim - y.ndim
if (x.broadcastable[dim + dim_offset]
and not y.broadcastable[dim]):
raise ValueError(("Trying to increment a subtensor with "
"broadcastable dimension %d, with a tensor not broadcastable "
"on corresponding dimension %d.") % (dim + dim_offset, dim),
x.broadcastable, y.broadcastable)
if not x.owner:
raise TypeError('x must be result of a subtensor operation')
# retrieve idx_list from x.owner # retrieve idx_list from x.owner
if isinstance(x.owner.op, Subtensor): if isinstance(x.owner.op, Subtensor):
if tolerate_inplace_aliasing: if tolerate_inplace_aliasing:
......
...@@ -56,6 +56,23 @@ class Test_inc_subtensor(unittest.TestCase): ...@@ -56,6 +56,23 @@ class Test_inc_subtensor(unittest.TestCase):
self.assertTrue(numpy.array_equal(result, expected_result)) self.assertTrue(numpy.array_equal(result, expected_result))
def test_wrong_dims(self):
a = tt.matrix()
increment = tt.matrix()
index = 0
self.assertRaises(ValueError, tt.set_subtensor, a[index], increment)
self.assertRaises(ValueError, tt.inc_subtensor, a[index], increment)
def test_wrong_broadcast(self):
a = tt.col()
increment = tt.vector()
self.assertRaises(ValueError, tt.set_subtensor, a[:], increment)
self.assertRaises(ValueError, tt.set_subtensor, a[0], increment)
self.assertRaises(ValueError, tt.inc_subtensor, a[:], increment)
self.assertRaises(ValueError, tt.inc_subtensor, a[0], increment)
def test_simple_3d(self): def test_simple_3d(self):
"""Increments or sets part of a tensor by a scalar using full slice and """Increments or sets part of a tensor by a scalar using full slice and
a partial slice depending on a scalar. a partial slice depending on a scalar.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论