提交 3d6798fb authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

GpuArray_take1 requires contiguous inputs.

上级 3343d912
...@@ -19,7 +19,7 @@ except ImportError: ...@@ -19,7 +19,7 @@ except ImportError:
from .type import GpuArrayType, gpu_context_type from .type import GpuArrayType, gpu_context_type
from .basic_ops import (as_gpuarray_variable, HideC, GpuKernelBase, Kernel, from .basic_ops import (as_gpuarray_variable, HideC, GpuKernelBase, Kernel,
infer_context_name) infer_context_name, gpu_contiguous)
iadd_reg = {} iadd_reg = {}
...@@ -405,7 +405,7 @@ class GpuAdvancedSubtensor1(HideC, tensor.AdvancedSubtensor1): ...@@ -405,7 +405,7 @@ class GpuAdvancedSubtensor1(HideC, tensor.AdvancedSubtensor1):
""" """
def make_node(self, x, ilist): def make_node(self, x, ilist):
ctx_name = infer_context_name(x, ilist) ctx_name = infer_context_name(x, ilist)
x_ = as_gpuarray_variable(x, ctx_name) x_ = gpu_contiguous(as_gpuarray_variable(x, ctx_name))
ilist__ = tensor.as_tensor_variable(ilist) ilist__ = tensor.as_tensor_variable(ilist)
if ilist__.type.dtype not in tensor.integer_dtypes: if ilist__.type.dtype not in tensor.integer_dtypes:
...@@ -413,7 +413,7 @@ class GpuAdvancedSubtensor1(HideC, tensor.AdvancedSubtensor1): ...@@ -413,7 +413,7 @@ class GpuAdvancedSubtensor1(HideC, tensor.AdvancedSubtensor1):
if ilist__.type.dtype != 'int64': if ilist__.type.dtype != 'int64':
ilist__ = tensor.cast(ilist__, 'int64') ilist__ = tensor.cast(ilist__, 'int64')
ilist_ = as_gpuarray_variable(ilist__, ctx_name) ilist_ = gpu_contiguous(as_gpuarray_variable(ilist__, ctx_name))
if ilist_.type.dtype != 'int64': if ilist_.type.dtype != 'int64':
raise TypeError('index must be int64') raise TypeError('index must be int64')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论