提交 f1f3e8fb authored 作者: Frederic's avatar Frederic

Finish GpuIncSubtensor c code.

上级 f251db8a
......@@ -2794,7 +2794,7 @@ class GpuIncSubtensor(tensor.IncSubtensor, GpuOp):
"""
return """CudaNdarray_CopyFromCudaNdarray(%(view)s, %(source)s)""" % locals()
def add_to_zview(self, x, fail):
def add_to_zview(self, name, x, fail):
return """
PyObject * add_result = CudaNdarray_inplace_add((PyObject *) zview,
......
import copy
import StringIO
import numpy
import theano
from theano import tensor, gof
from theano.tensor.subtensor import IncSubtensor, Subtensor, get_idx_list
from theano.gof.python25 import all, any
from theano.tensor.subtensor import IncSubtensor, Subtensor, get_idx_list
from theano.sandbox.cuda.nvcc_compiler import NVCC_compiler
try:
import pygpu
......@@ -16,6 +17,7 @@ except ImportError:
from theano.sandbox.gpuarray.type import GpuArrayType
from theano.sandbox.gpuarray.basic_ops import as_gpuarray_variable, HideC
from theano.sandbox.gpuarray.elemwise import GpuElemwise
class GpuSubtensor(HideC, Subtensor):
......@@ -156,7 +158,7 @@ class GpuSubtensor(HideC, Subtensor):
return (5,)
class GpuIncSubtensor(HideC, IncSubtensor):
class GpuIncSubtensor(IncSubtensor):
"""
Implement IncSubtensor on the gpu.
......@@ -166,15 +168,29 @@ class GpuIncSubtensor(HideC, IncSubtensor):
The helper methods like do_type_checking, copy_of_x, etc. specialize
the c_code for this Op.
"""
def c_headers(self):
return ['<compyte/numpy_compat.h>']
return self.iadd_node.op.c_headers()
def c_compiler(self):
return self.iadd_node.op.c_compiler()
def c_init_code(self):
return self.iadd_node.op.c_init_code()
def make_node(self, x, y, *inputs):
x = as_gpuarray_variable(x)
y = as_gpuarray_variable(y)
rval = tensor.IncSubtensor.make_node(self, x, y, *inputs)
return gof.Apply(self, [x, y] + rval.inputs[2:], [x.type()])
# We store a iadd_node in the op that contain the info needed
# for the inplace add.
cop = theano.tensor.inplace.add_inplace
gop = GpuElemwise(cop.scalar_op, copy.copy(cop.inplace_pattern),
"Gpu" + cop.name, cop.nfunc_spec)
xview = y.type()
iadd_node = gop(xview, y).owner
op = copy.copy(self)
op.iadd_node = iadd_node
return gof.Apply(op, [x, y] + rval.inputs[2:], [x.type()])
def perform(self, node, inputs, out_):
out, = out_
......@@ -281,11 +297,30 @@ class GpuIncSubtensor(HideC, IncSubtensor):
"""
return """GpuArray_move(&%(view)s->ga, &%(source)s->ga)""" % locals()
def add_to_zview(self, x, fail):
def c_support_code_apply(self, node, nodename):
gop = self.iadd_node.op
sub_name = nodename + "_add_to_zview"
ret = gop.c_support_code_apply(self.iadd_node, sub_name)
ret += """
PyGpuArrayObject* inc_sub_iadd_%(nodename)s(PyGpuArrayObject* dst,
PyGpuArrayObject* src){
PyGpuArrayObject* ret = NULL;
""" % locals()
#def c_code(self, node, name, inputs, outputs, sub):
inputs = ["dst", "src"]
outputs = ["ret"]
sub = {"fail": "return NULL;"}
ret += gop.c_code(self.iadd_node, sub_name, inputs, outputs, sub)
ret += """
return dst;
}
"""
return ret
def add_to_zview(self, nodename, x, fail):
#TODO
return """
PyObject * add_result = CudaNdarray_inplace_add((PyObject *) zview,
(PyObject *) py_%(x)s);
PyGpuArrayObject * add_result = inc_sub_iadd_%(nodename)s(zview, %(x)s);
if (! add_result )
{
......@@ -299,8 +334,8 @@ class GpuIncSubtensor(HideC, IncSubtensor):
""" % locals()
def c_code_cache_version(self):
return ()
parent_version = super(GpuIncSubtensor, self).c_code_cache_version()
if parent_version:
return parent_version + (0,)
return ()
elemwise_version = self.iadd_node.c_code_cache_version()
if not parent_version or not elemwise_version:
return
return parent_version + elemwise_version + (0,)
......@@ -1232,7 +1232,7 @@ class IncSubtensor(Op):
copy_into = self.copy_into("zview", y)
add_to_zview = self.add_to_zview(y, fail)
add_to_zview = self.add_to_zview(name, y, fail)
make_modification = """
if (%(op_is_set)s)
......@@ -1330,7 +1330,7 @@ class IncSubtensor(Op):
"""
return """PyArray_CopyInto(%(view)s, %(source)s)""" % locals()
def add_to_zview(self, x, fail):
def add_to_zview(self, name, x, fail):
""" Return C code to add x to zview. Should DECREF zview if the
add fails."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论