提交 a711ef41 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix the checks on the GPU versions of the ops (again).

上级 395c211e
...@@ -2512,7 +2512,7 @@ class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp): ...@@ -2512,7 +2512,7 @@ class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp):
if ilist_.type.dtype[:3] not in ('int', 'uin'): if ilist_.type.dtype[:3] not in ('int', 'uin'):
raise TypeError('index must be integers') raise TypeError('index must be integers')
if ilist_.type.broadcastable != (False,): if ilist_.type.ndim != 1:
raise TypeError('index must be vector') raise TypeError('index must be vector')
if x_.type.ndim == 0: if x_.type.ndim == 0:
raise TypeError('cannot index into a scalar') raise TypeError('cannot index into a scalar')
...@@ -2681,15 +2681,15 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1): ...@@ -2681,15 +2681,15 @@ class GpuAdvancedIncSubtensor1_dev20(GpuAdvancedIncSubtensor1):
if ilist_.type.dtype[:3] not in ('int', 'uin'): if ilist_.type.dtype[:3] not in ('int', 'uin'):
raise TypeError('index must be integers') raise TypeError('index must be integers')
if ilist_.type.broadcastable != (False,): if ilist_.type.ndim != 1:
raise TypeError('index must be vector') raise TypeError('index must be vector')
if x_.type.ndim == 0: if x_.type.ndim == 0:
raise TypeError('cannot index into a scalar') raise TypeError('cannot index into a scalar')
if x_.type.broadcastable[0]:
# the caller should have made a copy of x len(ilist) times
raise TypeError('cannot index into a broadcastable dimension')
return Apply(self, [x_, y_, ilist_], [x_.type()]) bcast = (ilist_.broadcastable[0],) + x_.broadcastable[1:]
return Apply(self, [x_, y_, ilist_],
[CudaNdarrayType(dtype=x_.dtype,
broadcastable=bcast)()])
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (2,)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论