提交 255ed794 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix some mistakes pointed out in review.

上级 7cb92b1e
......@@ -140,10 +140,9 @@ class GpuFromHost(Op):
if (%(name)s_tmp == NULL) {
%(fail)s
}
Py_DECREF(%(inp)s);
%(inp)s = %(name)s_tmp;
%(out)s = new_GpuArray((PyObject *)&GpuArrayType);
if (%(out)s == NULL) {
Py_DECREF(%(name)s_tmp);
%(fail)s
}
%(name)serr = GpuArray_empty(&%(out)s->ga, compyte_get_ops("%(kind)s"),
......@@ -152,13 +151,16 @@ class GpuFromHost(Op):
(size_t *)PyArray_DIMS(%(inp)s),
GA_C_ORDER);
if (%(name)serr != GA_NO_ERROR) {
Py_DECREF(%(name)s_tmp);
Py_DECREF(%(out)s);
%(fail)s
}
%(name)serr = GpuArray_write(&%(out)s->ga, PyArray_DATA(%(inp)s),
PyArray_NBYTES(%(inp)s));
%(name)serr = GpuArray_write(&%(out)s->ga, PyArray_DATA(%(name)s_tmp),
PyArray_NBYTES(%(name)s_tmp));
Py_DECREF(%(name)s_tmp);
if (%(name)serr != GA_NO_ERROR) {
Py_DECREF(%(out)s);
PyErr_SetString(PyExc_RuntimeError, "Could not copy array data to device");
%(fail)s
}
""" % {'name': name, 'kind': type.kind, 'ctx': hex(type.context),
......
......@@ -77,20 +77,20 @@ class GpuArrayType(Type):
" dimension.", shp, self.broadcastable)
return data
@classmethod
def values_eq(cls, a, b):
@staticmethod
def values_eq(a, b):
if a.shape != b.shape:
return False
if a.typecode != b.typecode:
return False
return numpy.asarray(compare(a, '==', b)).all()
@classmethod
def values_eq_approx(cls, a, b):
@staticmethod
def values_eq_approx(a, b):
if a.shape != b.shape or a.dtype != b.dtype:
return False
if 'int' in str(a.dtype):
return cls.values_eq(a, b)
return GpuArrayType.values_eq(a, b)
else:
res = elemwise2(a, '', b, a, odtype=numpy.dtype('bool'),
op_tmpl="res[i] = ((%(a)s - %(b)s) <" \
......@@ -129,9 +129,11 @@ class GpuArrayType(Type):
return """
%(name)s = NULL;
if (py_%(name)s == Py_None) {
PyErr_SetString(PyExc_ValueError, "expected an ndarray, not None");
PyErr_SetString(PyExc_ValueError, "expected a GpuArray, not None");
%(fail)s
}
/* First check if we are the base type exactly (the most common case),
then do the full subclass check if needed. */
if (py_%(name)s->ob_type != &GpuArrayType &&
!PyObject_TypeCheck(py_%(name)s, &GpuArrayType)) {
PyErr_SetString(PyExc_ValueError, "expected a GpuArray");
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论