提交 f6d418c2 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Raise TypeError instead of ValueError

to be consistent with AdvancedIncSubtensor1
上级 54358255
...@@ -4732,14 +4732,14 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False, ...@@ -4732,14 +4732,14 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
x = as_tensor_variable(x) x = as_tensor_variable(x)
y = as_tensor_variable(y) y = as_tensor_variable(y)
if y.ndim > x.ndim: if y.ndim > x.ndim:
raise ValueError(("Trying to increment a %d-dimensional " raise TypeError(("Trying to increment a %d-dimensional "
"subtensor with a %d-dimensional value.") % (x.ndim, y.ndim)) "subtensor with a %d-dimensional value.") % (x.ndim, y.ndim))
for dim in range(y.ndim): for dim in range(y.ndim):
dim_offset = x.ndim - y.ndim dim_offset = x.ndim - y.ndim
if (x.broadcastable[dim + dim_offset] if (x.broadcastable[dim + dim_offset]
and not y.broadcastable[dim]): and not y.broadcastable[dim]):
raise ValueError(("Trying to increment a subtensor with " raise TypeError(("Trying to increment a subtensor with "
"broadcastable dimension %d, with a tensor not broadcastable " "broadcastable dimension %d, with a tensor not broadcastable "
"on corresponding dimension %d.") % (dim + dim_offset, dim), "on corresponding dimension %d.") % (dim + dim_offset, dim),
x.broadcastable, y.broadcastable) x.broadcastable, y.broadcastable)
......
...@@ -61,17 +61,17 @@ class Test_inc_subtensor(unittest.TestCase): ...@@ -61,17 +61,17 @@ class Test_inc_subtensor(unittest.TestCase):
increment = tt.matrix() increment = tt.matrix()
index = 0 index = 0
self.assertRaises(ValueError, tt.set_subtensor, a[index], increment) self.assertRaises(TypeError, tt.set_subtensor, a[index], increment)
self.assertRaises(ValueError, tt.inc_subtensor, a[index], increment) self.assertRaises(TypeError, tt.inc_subtensor, a[index], increment)
def test_wrong_broadcast(self): def test_wrong_broadcast(self):
a = tt.col() a = tt.col()
increment = tt.vector() increment = tt.vector()
self.assertRaises(ValueError, tt.set_subtensor, a[:], increment) self.assertRaises(TypeError, tt.set_subtensor, a[:], increment)
self.assertRaises(ValueError, tt.set_subtensor, a[0], increment) self.assertRaises(TypeError, tt.set_subtensor, a[0], increment)
self.assertRaises(ValueError, tt.inc_subtensor, a[:], increment) self.assertRaises(TypeError, tt.inc_subtensor, a[:], increment)
self.assertRaises(ValueError, tt.inc_subtensor, a[0], increment) self.assertRaises(TypeError, tt.inc_subtensor, a[0], increment)
def test_simple_3d(self): def test_simple_3d(self):
"""Increments or sets part of a tensor by a scalar using full slice and """Increments or sets part of a tensor by a scalar using full slice and
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论