提交 02e1e3f8 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #4621 from abergeron/incsub_0

Add support for no indices in advincsub1 for gpuarrray.
...@@ -673,6 +673,15 @@ def local_gpua_advanced_incsubtensor(node, context_name): ...@@ -673,6 +673,15 @@ def local_gpua_advanced_incsubtensor(node, context_name):
set_instead_of_inc=set_instead_of_inc) set_instead_of_inc=set_instead_of_inc)
@register_inplace()
@local_optimizer([GpuAdvancedIncSubtensor1, GpuAdvancedIncSubtensor1_dev20])
def local_advincsub1_gpua_inplace(node):
if isinstance(node.op, (GpuAdvancedIncSubtensor1,
GpuAdvancedIncSubtensor1_dev20)):
if not node.op.inplace:
return [node.op.clone_inplace()(*node.inputs)]
@register_opt('fast_compile') @register_opt('fast_compile')
@op_lifter([tensor.CAReduce, tensor.Sum, tensor.elemwise.Prod]) @op_lifter([tensor.CAReduce, tensor.Sum, tensor.elemwise.Prod])
def local_gpua_careduce(node, context_name): def local_gpua_careduce(node, context_name):
......
...@@ -608,11 +608,6 @@ class GpuAdvancedIncSubtensor1(Op): ...@@ -608,11 +608,6 @@ class GpuAdvancedIncSubtensor1(Op):
} }
step[0] = 0; step[0] = 0;
num_indices = PyArray_SIZE(%(ind)s); num_indices = PyArray_SIZE(%(ind)s);
if ((num_indices - 1) > LONG_MAX) {
PyErr_Format(PyExc_AssertionError,
"num_indices %%lld exceeds LONG_MAX + 1", (long long)num_indices);
%(fail)s
}
if (!%(inplace)s) { if (!%(inplace)s) {
%(out)s = theano_try_copy(%(out)s, %(x)s); %(out)s = theano_try_copy(%(out)s, %(x)s);
if (%(out)s == NULL) if (%(out)s == NULL)
...@@ -622,6 +617,12 @@ class GpuAdvancedIncSubtensor1(Op): ...@@ -622,6 +617,12 @@ class GpuAdvancedIncSubtensor1(Op):
%(out)s = %(x)s; %(out)s = %(x)s;
Py_INCREF(%(out)s); Py_INCREF(%(out)s);
} }
if (num_indices != 0) {
if ((num_indices - 1) > LONG_MAX) {
PyErr_Format(PyExc_AssertionError,
"num_indices %%lld exceeds LONG_MAX + 1", (long long)num_indices);
%(fail)s
}
broadcast_y = PyGpuArray_DIM(%(y)s, 0) == 1; broadcast_y = PyGpuArray_DIM(%(y)s, 0) == 1;
for (j = 0; j < num_indices; j++) { for (j = 0; j < num_indices; j++) {
start[0] = *(dtype_%(ind)s *)PyArray_GETPTR1(%(ind)s, j); start[0] = *(dtype_%(ind)s *)PyArray_GETPTR1(%(ind)s, j);
...@@ -659,13 +660,14 @@ class GpuAdvancedIncSubtensor1(Op): ...@@ -659,13 +660,14 @@ class GpuAdvancedIncSubtensor1(Op):
if (ret != GA_NO_ERROR) if (ret != GA_NO_ERROR)
PyErr_SetString(PyExc_RuntimeError, "Failed to set/inc elements"); PyErr_SetString(PyExc_RuntimeError, "Failed to set/inc elements");
} }
}
""" % dict(x=inputs[0], y=inputs[1], ind=inputs[2], out=outputs[0], """ % dict(x=inputs[0], y=inputs[1], ind=inputs[2], out=outputs[0],
fail=sub['fail'], inplace=int(self.inplace), fail=sub['fail'], inplace=int(self.inplace),
nd=node.inputs[0].ndim, nd=node.inputs[0].ndim,
set_instead_of_inc=int(self.set_instead_of_inc)) set_instead_of_inc=int(self.set_instead_of_inc))
def c_code_cache_version(self): def c_code_cache_version(self):
return (0,) return (1,)
class GpuAdvancedIncSubtensor1_dev20(GpuKernelBase, HideC, class GpuAdvancedIncSubtensor1_dev20(GpuKernelBase, HideC,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论