提交 4ccc4c1a authored 作者: Frederic's avatar Frederic

Allow AdvancedSubtensor1 support broadcastable vector for the indices.

上级 1617a923
......@@ -2322,7 +2322,7 @@ class GpuAdvancedSubtensor1(tensor.AdvancedSubtensor1, GpuOp):
ilist_ = tensor.as_tensor_variable(ilist)
if ilist_.type.dtype[:3] not in ('int', 'uin'):
raise TypeError('index must be integers')
if ilist_.type.broadcastable != (False,):
if ilist_.type.ndim != 1:
raise TypeError('index must be vector')
if x_.type.ndim == 0:
raise TypeError('cannot index into a scalar')
......
......@@ -6896,7 +6896,7 @@ class AdvancedSubtensor1(Op):
ilist_ = as_tensor_variable(ilist)
if ilist_.type.dtype[:3] not in ('int', 'uin'):
raise TypeError('index must be integers')
if ilist_.type.broadcastable != (False,):
if ilist_.type.ndim != 1:
raise TypeError('index must be vector')
if x_.type.ndim == 0:
raise TypeError('cannot index into a scalar')
......
......@@ -3168,6 +3168,23 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
self.assertTrue(numpy.allclose(f([0]), ones[0] * 5))
self.assertRaises(IndexError, f, [0, 1])
def test_adv_sub1_idx_broadcast(self):
# The idx can be a broadcastable vector.
ones = numpy.ones((4, 3), dtype=self.dtype)
n = self.shared(ones * 5)
idx = tensor.TensorType(dtype='int64', broadcastable=(True,))()
assert idx.type.broadcastable == (True,)
t = n[idx]
self.assertTrue(isinstance(t.owner.op, tensor.AdvancedSubtensor1))
f = self.function([idx], t, op=self.adv_sub1)
topo = f.maker.fgraph.toposort()
topo_ = [node for node in topo if not isinstance(node.op,
self.ignore_topo)]
assert len(topo_) == 1
self.assertTrue(isinstance(topo_[0].op, self.adv_sub1))
self.assertTrue(numpy.allclose(f([0]), ones[0] * 5))
def test_shape_i_const(self):
# Each axis is treated independently by shape_i/shape operators
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论