提交 8910e6eb authored 作者: lamblin's avatar lamblin

Merge pull request #1261 from nouiz/adv_sub1

Adv sub1: allow broadcasted index vector.
...@@ -2322,7 +2322,7 @@ class GpuAdvancedSubtensor1(tensor.AdvancedSubtensor1, GpuOp): ...@@ -2322,7 +2322,7 @@ class GpuAdvancedSubtensor1(tensor.AdvancedSubtensor1, GpuOp):
ilist_ = tensor.as_tensor_variable(ilist) ilist_ = tensor.as_tensor_variable(ilist)
if ilist_.type.dtype[:3] not in ('int', 'uin'): if ilist_.type.dtype[:3] not in ('int', 'uin'):
raise TypeError('index must be integers') raise TypeError('index must be integers')
if ilist_.type.broadcastable != (False,): if ilist_.type.ndim != 1:
raise TypeError('index must be vector') raise TypeError('index must be vector')
if x_.type.ndim == 0: if x_.type.ndim == 0:
raise TypeError('cannot index into a scalar') raise TypeError('cannot index into a scalar')
......
...@@ -6896,7 +6896,7 @@ class AdvancedSubtensor1(Op): ...@@ -6896,7 +6896,7 @@ class AdvancedSubtensor1(Op):
ilist_ = as_tensor_variable(ilist) ilist_ = as_tensor_variable(ilist)
if ilist_.type.dtype[:3] not in ('int', 'uin'): if ilist_.type.dtype[:3] not in ('int', 'uin'):
raise TypeError('index must be integers') raise TypeError('index must be integers')
if ilist_.type.broadcastable != (False,): if ilist_.type.ndim != 1:
raise TypeError('index must be vector') raise TypeError('index must be vector')
if x_.type.ndim == 0: if x_.type.ndim == 0:
raise TypeError('cannot index into a scalar') raise TypeError('cannot index into a scalar')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论