提交 6e0f1fe4 authored 作者: Frederic's avatar Frederic

Make GpuAdvancedIncSubtensor1 work for set_subtensor.

上级 3de42af7
...@@ -1944,6 +1944,36 @@ class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp): ...@@ -1944,6 +1944,36 @@ class GpuAdvancedIncSubtensor1(tensor.AdvancedIncSubtensor1, GpuOp):
#def perform(self, node, inp, out_): #def perform(self, node, inp, out_):
# CudaNdarray_Subscript() don't support Advanced slicing. # CudaNdarray_Subscript() don't support Advanced slicing.
# so we use the parent version that loop on each indices. # so we use the parent version that loop on each indices.
def perform(self, node, inp, out_):
# TODO opt to make this inplace
x, y, idx = inp
out, = out_
if not self.inplace:
x = x.copy()
if self.set_instead_of_inc:
# CudaNdarray __setitem__ don't do broadcast nor support
# list of index.
assert y.ndim <= x.ndim # Should be guaranteed by `make_node`
if y.ndim == x.ndim:
assert len(y) == len(idx)
for (j, i) in enumerate(idx):
x[i] = y[j]
else:
for i in idx:
x[i] = y
else:
# If `y` has as many dimensions as `x`, then we want to iterate
# jointly on `x` and `y`. Otherwise, it means `y` should be
# broadcasted to fill all relevant rows of `x`.
assert y.ndim <= x.ndim # Should be guaranteed by `make_node`
if y.ndim == x.ndim:
assert len(y) == len(idx)
for (j, i) in enumerate(idx):
x[i] += y[j]
else:
for i in idx:
x[i] += y
out[0] = x
class GpuIncSubtensor(tensor.IncSubtensor, GpuOp): class GpuIncSubtensor(tensor.IncSubtensor, GpuOp):
......
...@@ -765,8 +765,6 @@ def local_gpu_advanced_incsubtensor1(node): ...@@ -765,8 +765,6 @@ def local_gpu_advanced_incsubtensor1(node):
'either set the `warn.gpu_set_subtensor1` config ' 'either set the `warn.gpu_set_subtensor1` config '
'option to False, or `warn.ignore_bug_before` to at ' 'option to False, or `warn.ignore_bug_before` to at '
'least \'0.6\'.', stacklevel=1) 'least \'0.6\'.', stacklevel=1)
if set_instead_of_inc:
return
gpu_op = GpuAdvancedIncSubtensor1( gpu_op = GpuAdvancedIncSubtensor1(
set_instead_of_inc=set_instead_of_inc) set_instead_of_inc=set_instead_of_inc)
...@@ -799,8 +797,7 @@ def local_gpu_advanced_incsubtensor1(node): ...@@ -799,8 +797,7 @@ def local_gpu_advanced_incsubtensor1(node):
'either set the `warn.gpu_set_subtensor1` config ' 'either set the `warn.gpu_set_subtensor1` config '
'option to False, or `warn.ignore_bug_before` to at ' 'option to False, or `warn.ignore_bug_before` to at '
'least \'0.6\'.', stacklevel=1) 'least \'0.6\'.', stacklevel=1)
if set_instead_of_inc:
return
gpu_op = GpuAdvancedIncSubtensor1( gpu_op = GpuAdvancedIncSubtensor1(
set_instead_of_inc=set_instead_of_inc) set_instead_of_inc=set_instead_of_inc)
return [host_from_gpu(gpu_op(gpu_x, gpu_y, *coords))] return [host_from_gpu(gpu_op(gpu_x, gpu_y, *coords))]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论