提交 66f23bc6 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

made copying the input work on both CPU and GPU

上级 3bd3ae11
...@@ -2419,6 +2419,9 @@ class GpuIncSubtensor(tensor.IncSubtensor, GpuOp): ...@@ -2419,6 +2419,9 @@ class GpuIncSubtensor(tensor.IncSubtensor, GpuOp):
rval = tensor.IncSubtensor.make_node(self, x, y, *inputs) rval = tensor.IncSubtensor.make_node(self, x, y, *inputs)
return Apply(self, [x, y] + rval.inputs[2:], [x.type()]) return Apply(self, [x, y] + rval.inputs[2:], [x.type()])
def copy_of_x(self, x):
return """(CudaNdarray*) CudaNdarray_Copy(%(x)s)""" % locals()
class GpuFlatten(tensor.Flatten, GpuOp): class GpuFlatten(tensor.Flatten, GpuOp):
""" """
......
...@@ -4554,6 +4554,9 @@ class IncSubtensor(Op): ...@@ -4554,6 +4554,9 @@ class IncSubtensor(Op):
view_ndim = (node.inputs[0].ndim - view_ndim = (node.inputs[0].ndim -
numpy.sum([not isinstance(idx, slice) numpy.sum([not isinstance(idx, slice)
for idx in self.idx_list])) for idx in self.idx_list]))
copy_of_x = self.copy_of_x(x)
copy_input_if_necessary = """ copy_input_if_necessary = """
if (%(inplace)s) if (%(inplace)s)
{ {
...@@ -4567,9 +4570,7 @@ class IncSubtensor(Op): ...@@ -4567,9 +4570,7 @@ class IncSubtensor(Op):
else else
{ {
if (%(z)s) Py_DECREF(%(z)s); if (%(z)s) Py_DECREF(%(z)s);
%(z)s = (PyArrayObject*)PyArray_FromAny(py_%(x)s, NULL, 0, 0, %(z)s = %(copy_of_x)s; }
NPY_ARRAY_ENSURECOPY, NULL);
}
""" % locals() """ % locals()
#Make a first view on the output, as we will write into it. #Make a first view on the output, as we will write into it.
...@@ -4637,6 +4638,27 @@ class IncSubtensor(Op): ...@@ -4637,6 +4638,27 @@ class IncSubtensor(Op):
else: else:
return () return ()
def copy_of_x(self, x):
"""
x: a string giving the name of a C variable pointing to an array
Returns C code expression to make a copy of x.
Base class uses PyArrayObject *, subclasses may override for
different types of arrays.
"""
# Parameters of PyArrary_FromAny are:
# array
# dtype: we pass NULL to say any dtype is acceptable, so the existing
# dtype will be copied
# min_depth: we pass 0 to have this parameter ignored
# max_depth: we pass 0 to have this parameter ignored
# requirements: here we pass NPY_ARRAY_ENSURECOPY to force a copy
# context: this is almost always NULL, I'm not sure what it's used for
return """(PyArrayObject*)PyArray_FromAny(py_%(x)s, NULL, 0, 0,
NPY_ARRAY_ENSURECOPY, NULL)""" % locals()
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
return [shapes[0]] return [shapes[0]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论