提交 255ed794 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix some mistakes pointed out in review.

上级 7cb92b1e
...@@ -140,10 +140,9 @@ class GpuFromHost(Op): ...@@ -140,10 +140,9 @@ class GpuFromHost(Op):
if (%(name)s_tmp == NULL) { if (%(name)s_tmp == NULL) {
%(fail)s %(fail)s
} }
Py_DECREF(%(inp)s);
%(inp)s = %(name)s_tmp;
%(out)s = new_GpuArray((PyObject *)&GpuArrayType); %(out)s = new_GpuArray((PyObject *)&GpuArrayType);
if (%(out)s == NULL) { if (%(out)s == NULL) {
Py_DECREF(%(name)s_tmp);
%(fail)s %(fail)s
} }
%(name)serr = GpuArray_empty(&%(out)s->ga, compyte_get_ops("%(kind)s"), %(name)serr = GpuArray_empty(&%(out)s->ga, compyte_get_ops("%(kind)s"),
...@@ -152,13 +151,16 @@ class GpuFromHost(Op): ...@@ -152,13 +151,16 @@ class GpuFromHost(Op):
(size_t *)PyArray_DIMS(%(inp)s), (size_t *)PyArray_DIMS(%(inp)s),
GA_C_ORDER); GA_C_ORDER);
if (%(name)serr != GA_NO_ERROR) { if (%(name)serr != GA_NO_ERROR) {
Py_DECREF(%(name)s_tmp);
Py_DECREF(%(out)s); Py_DECREF(%(out)s);
%(fail)s %(fail)s
} }
%(name)serr = GpuArray_write(&%(out)s->ga, PyArray_DATA(%(inp)s), %(name)serr = GpuArray_write(&%(out)s->ga, PyArray_DATA(%(name)s_tmp),
PyArray_NBYTES(%(inp)s)); PyArray_NBYTES(%(name)s_tmp));
Py_DECREF(%(name)s_tmp);
if (%(name)serr != GA_NO_ERROR) { if (%(name)serr != GA_NO_ERROR) {
Py_DECREF(%(out)s); Py_DECREF(%(out)s);
PyErr_SetString(PyExc_RuntimeError, "Could not copy array data to device");
%(fail)s %(fail)s
} }
""" % {'name': name, 'kind': type.kind, 'ctx': hex(type.context), """ % {'name': name, 'kind': type.kind, 'ctx': hex(type.context),
......
...@@ -77,20 +77,20 @@ class GpuArrayType(Type): ...@@ -77,20 +77,20 @@ class GpuArrayType(Type):
" dimension.", shp, self.broadcastable) " dimension.", shp, self.broadcastable)
return data return data
@classmethod @staticmethod
def values_eq(cls, a, b): def values_eq(a, b):
if a.shape != b.shape: if a.shape != b.shape:
return False return False
if a.typecode != b.typecode: if a.typecode != b.typecode:
return False return False
return numpy.asarray(compare(a, '==', b)).all() return numpy.asarray(compare(a, '==', b)).all()
@classmethod @staticmethod
def values_eq_approx(cls, a, b): def values_eq_approx(a, b):
if a.shape != b.shape or a.dtype != b.dtype: if a.shape != b.shape or a.dtype != b.dtype:
return False return False
if 'int' in str(a.dtype): if 'int' in str(a.dtype):
return cls.values_eq(a, b) return GpuArrayType.values_eq(a, b)
else: else:
res = elemwise2(a, '', b, a, odtype=numpy.dtype('bool'), res = elemwise2(a, '', b, a, odtype=numpy.dtype('bool'),
op_tmpl="res[i] = ((%(a)s - %(b)s) <" \ op_tmpl="res[i] = ((%(a)s - %(b)s) <" \
...@@ -129,9 +129,11 @@ class GpuArrayType(Type): ...@@ -129,9 +129,11 @@ class GpuArrayType(Type):
return """ return """
%(name)s = NULL; %(name)s = NULL;
if (py_%(name)s == Py_None) { if (py_%(name)s == Py_None) {
PyErr_SetString(PyExc_ValueError, "expected an ndarray, not None"); PyErr_SetString(PyExc_ValueError, "expected a GpuArray, not None");
%(fail)s %(fail)s
} }
/* First check if we are the base type exactly (the most common case),
then do the full subclass check if needed. */
if (py_%(name)s->ob_type != &GpuArrayType && if (py_%(name)s->ob_type != &GpuArrayType &&
!PyObject_TypeCheck(py_%(name)s, &GpuArrayType)) { !PyObject_TypeCheck(py_%(name)s, &GpuArrayType)) {
PyErr_SetString(PyExc_ValueError, "expected a GpuArray"); PyErr_SetString(PyExc_ValueError, "expected a GpuArray");
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论