Added more error checking to gemm; renamed iadd to add_inplace and…

Added more error checking to gemm; renamed the in-place ops from an `i` prefix to an `_inplace` suffix (iadd → add_inplace, and similarly for the other in-place ops).
上级 6db0342b
......@@ -623,23 +623,23 @@ class NumpyR(gof.PythonR):
# Arithmetic operator overloads. Each method forwards to the module-level op
# of the matching name; the reflected (__r*__) forms pass the operands in
# swapped order so that `self` ends up on the right-hand side.
# NOTE(review): every in-place method (__iadd__, __isub__, __imul__, __idiv__,
# __ipow__) is listed twice, first calling the i-prefixed op and then the
# *_inplace op. If both lines sit in one class body, the later definition is
# the one in effect -- confirm which of each pair is meant to survive.
def __add__(self, y): return add(self, y)
def __radd__(self, x): return add(x, self)
def __iadd__(self, y): return iadd(self, y)
def __iadd__(self, y): return add_inplace(self, y)
def __sub__(self, y): return sub(self, y)
def __rsub__(self, x): return sub(x, self)
def __isub__(self, y): return isub(self, y)
def __isub__(self, y): return sub_inplace(self, y)
def __mul__(self, y): return mul(self, y)
def __rmul__(self, x): return mul(x, self)
def __imul__(self, y): return imul(self, y)
def __imul__(self, y): return mul_inplace(self, y)
def __div__(self, y): return div(self, y)
def __rdiv__(self, x): return div(x, self)
def __idiv__(self, y): return idiv(self, y)
def __idiv__(self, y): return div_inplace(self, y)
def __pow__(self, y): return pow(self, y)
def __rpow__(self, x): return pow(x, self)
def __ipow__(self, y): return ipow(self, y)
def __ipow__(self, y): return pow_inplace(self, y)
# Unary minus forwards to the module-level neg op.
def __neg__(self): return neg(self)
......@@ -704,8 +704,8 @@ class add_elemwise(elemwise):
# Returns the C statement that stores x_i + y_i into z_i.
# NOTE(review): the parameter names match the identifiers used in the returned
# snippet, so the code generator presumably relies on them -- do not rename
# without checking the elemwise machinery.
def c_foreach((x_i, y_i), (z_i, )):
return "z_i = x_i + y_i;"
# In-place variants of add_elemwise, each obtained from inplace_version() and
# given numpy.ndarray.__iadd__ (wrapped by assert_same_shapes) as its Python
# implementation. Two names are bound: iadd_elemwise and add_elemwise_inplace.
iadd_elemwise = add_elemwise.inplace_version()
iadd_elemwise.set_impl(assert_same_shapes(numpy.ndarray.__iadd__))
add_elemwise_inplace = add_elemwise.inplace_version()
add_elemwise_inplace.set_impl(assert_same_shapes(numpy.ndarray.__iadd__))
class add_scalar(tensor_scalar_op):
......@@ -714,8 +714,8 @@ class add_scalar(tensor_scalar_op):
return gz, sum(gz)
c_expr = "x_i + a"
# In-place variants of add_scalar, each obtained from inplace_version() and
# given numpy.ndarray.__iadd__ (wrapped by tensor_scalar_impl) as its Python
# implementation. Two names are bound: iadd_scalar and add_scalar_inplace.
iadd_scalar = add_scalar.inplace_version()
iadd_scalar.set_impl(tensor_scalar_impl(numpy.ndarray.__iadd__))
add_scalar_inplace = add_scalar.inplace_version()
add_scalar_inplace.set_impl(tensor_scalar_impl(numpy.ndarray.__iadd__))
class twice(elemwise):
def grad(x, gz):
......@@ -737,8 +737,8 @@ class sub_elemwise(elemwise):
# Returns the C statement that stores x_i - y_i into z_i.
# NOTE(review): the parameter names match the identifiers used in the returned
# snippet, so the code generator presumably relies on them -- do not rename
# without checking the elemwise machinery.
def c_foreach((x_i, y_i), (z_i, )):
return "z_i = x_i - y_i;"
# In-place variants of sub_elemwise, each obtained from inplace_version() and
# given numpy.ndarray.__isub__ (wrapped by assert_same_shapes) as its Python
# implementation. Two names are bound: isub_elemwise and sub_elemwise_inplace.
isub_elemwise = sub_elemwise.inplace_version()
isub_elemwise.set_impl(assert_same_shapes(numpy.ndarray.__isub__))
sub_elemwise_inplace = sub_elemwise.inplace_version()
sub_elemwise_inplace.set_impl(assert_same_shapes(numpy.ndarray.__isub__))
def sub_scalar_r(x, a):
    """Subtract ``a`` from ``x`` by handing ``x`` and the negated ``a`` to add_scalar."""
    negated_a = -a
    return add_scalar(x, negated_a)
......@@ -746,8 +746,8 @@ def sub_scalar_r(x, a):
def sub_scalar_l(x, a):
    """Compute ``a - x`` by handing the negated ``x`` and ``a`` to add_scalar."""
    negated_x = -x
    return add_scalar(negated_x, a)
def isub_scalar_r(x, a):
    """Subtract ``a`` from ``x`` by handing ``x`` and the negated ``a`` to iadd_scalar."""
    negated_a = -a
    return iadd_scalar(x, negated_a)
def sub_scalar_r_inplace(x, a):
    """Subtract ``a`` from ``x`` by handing ``x`` and the negated ``a`` to add_scalar_inplace."""
    negated_a = -a
    return add_scalar_inplace(x, negated_a)
## Element-wise multiplication ##
......@@ -759,8 +759,8 @@ class mul_elemwise(elemwise):
# Returns the C statement that stores x_i * y_i into z_i.
# NOTE(review): the parameter names match the identifiers used in the returned
# snippet, so the code generator presumably relies on them -- do not rename
# without checking the elemwise machinery.
def c_foreach((x_i, y_i), (z_i, )):
return "z_i = x_i * y_i;"
# In-place variants of mul_elemwise, each obtained from inplace_version() and
# given numpy.ndarray.__imul__ (wrapped by assert_same_shapes) as its Python
# implementation. Two names are bound: imul_elemwise and mul_elemwise_inplace.
imul_elemwise = mul_elemwise.inplace_version()
imul_elemwise.set_impl(assert_same_shapes(numpy.ndarray.__imul__))
mul_elemwise_inplace = mul_elemwise.inplace_version()
mul_elemwise_inplace.set_impl(assert_same_shapes(numpy.ndarray.__imul__))
class scale(tensor_scalar_op):
......@@ -769,8 +769,8 @@ class scale(tensor_scalar_op):
return scale(a, gz), sum(mul_elemwise(x, gz))
c_expr = "x_i * a"
# In-place variants of scale, each obtained from inplace_version() and given
# numpy.ndarray.__imul__ (wrapped by tensor_scalar_impl) as its Python
# implementation. Two names are bound: iscale and scale_inplace.
iscale = scale.inplace_version()
iscale.set_impl(tensor_scalar_impl(numpy.ndarray.__imul__))
scale_inplace = scale.inplace_version()
scale_inplace.set_impl(tensor_scalar_impl(numpy.ndarray.__imul__))
class sqr(elemwise):
......@@ -815,8 +815,8 @@ class div_elemwise(elemwise):
# Returns the C statement that stores x_i / y_i into z_i.
# NOTE(review): the parameter names match the identifiers used in the returned
# snippet, so the code generator presumably relies on them -- do not rename
# without checking the elemwise machinery.
def c_foreach((x_i, y_i), (z_i, )):
return "z_i = x_i / y_i;"
# In-place variants of div_elemwise, each obtained from inplace_version() and
# given numpy.ndarray.__idiv__ (wrapped by assert_same_shapes) as its Python
# implementation. Two names are bound: idiv_elemwise and div_elemwise_inplace.
idiv_elemwise = div_elemwise.inplace_version()
idiv_elemwise.set_impl(assert_same_shapes(numpy.ndarray.__idiv__))
div_elemwise_inplace = div_elemwise.inplace_version()
div_elemwise_inplace.set_impl(assert_same_shapes(numpy.ndarray.__idiv__))
def div_scalar_r(x, a):
    """Divide ``x`` by ``a``, expressed as scale(x, inv_elemwise(a))."""
    # inv_elemwise is presumably the element-wise reciprocal -- confirm.
    inverted_a = inv_elemwise(a)
    return scale(x, inverted_a)
......@@ -824,8 +824,8 @@ def div_scalar_r(x, a):
def div_scalar_l(x, a):
    """Divide ``a`` by ``x``, expressed as scale(inv_elemwise(x), a)."""
    # inv_elemwise is presumably the element-wise reciprocal -- confirm.
    inverted_x = inv_elemwise(x)
    return scale(inverted_x, a)
def idiv_scalar_r(x, a):
    """Divide ``x`` by ``a``, expressed as iscale(x, inv_elemwise(a))."""
    # inv_elemwise is presumably the element-wise reciprocal -- confirm.
    inverted_a = inv_elemwise(a)
    return iscale(x, inverted_a)
def div_scalar_r_inplace(x, a):
    """Divide ``x`` by ``a``, expressed as scale_inplace(x, inv_elemwise(a))."""
    # inv_elemwise is presumably the element-wise reciprocal -- confirm.
    inverted_a = inv_elemwise(a)
    return scale_inplace(x, inverted_a)
......@@ -856,7 +856,7 @@ iinv_elemwise = inv_elemwise.inplace_version()
class blas_code :
@staticmethod
def gemm_xyz(a_init, b_init):
def gemm_xyz(check_ab, a_init, b_init):
mod = '%'
return """
const char * error_string = NULL;
......@@ -880,6 +880,8 @@ class blas_code :
if (_y->nd != 2) goto _dot_execute_fallback;
if (_z->nd != 2) goto _dot_execute_fallback;
%(check_ab)s
if ((_x->descr->type_num != PyArray_DOUBLE)
&& (_x->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
......@@ -930,6 +932,7 @@ class blas_code :
{
case PyArray_FLOAT:
{
#define REAL float
float a = %(a_init)s;
float b = %(b_init)s;
......@@ -949,10 +952,12 @@ class blas_code :
case 0x111: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: goto _dot_execute_fallback;
};
#undef REAL
}
break;
case PyArray_DOUBLE:
{
#define REAL double
double a = %(a_init)s;
double b = %(b_init)s;
......@@ -971,6 +976,7 @@ class blas_code :
case 0x111: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: goto _dot_execute_fallback;
};
#undef REAL
}
break;
}
......@@ -1019,10 +1025,8 @@ class blas_code :
"""
class dot(omega_op):
def impl(x, y):
z = numpy.dot(x,y)
#print z.dtype
return z
impl = numpy.dot
def grad(x, y, gz):
return dot(gz, transpose(y)), dot(transpose(x), gz)
def specs(x, y):
......@@ -1030,11 +1034,11 @@ class dot(omega_op):
shape = (x[2][0], y[2][1])
return (numpy.ndarray, upcast(x[1], y[1]), shape)
def c_headers(self):
return ["<gsl/gsl_cblas.h>"]
return _blas_headers
def c_libs(self):
return ["cblas", "atlas", "g2c"]
return _blas_libs
def c_impl((_x, _y), (_z, )):
return blas_code.gemm_xyz('1.0', '0.0')
return blas_code.gemm_xyz('', '1.0', '0.0')
class gemm(omega_op, inplace):
......@@ -1070,9 +1074,20 @@ class gemm(omega_op, inplace):
def c_libs(self):
return _blas_libs
def c_impl((_zin, _a, _x, _y, _b), (_z,)):
return blas_code.gemm_xyz(
'((_a->descr->type_num == PyArray_FLOAT) ? (float*)_a->data : (double*)_a->data)[0]',
'((_b->descr->type_num == PyArray_FLOAT) ? (float*)_b->data : (double*)_b->data)[0]')
check_ab = """
{
if ((_a->descr->type_num != PyArray_DOUBLE)
&& (_a->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
if ((_b->descr->type_num != PyArray_DOUBLE)
&& (_b->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
}
"""
return blas_code.gemm_xyz( check_ab,
'(_a->descr->type_num == PyArray_FLOAT) ? (REAL)(((float*)_a->data)[0]) : (REAL)(((double*)_a->data)[0])',
'(_b->descr->type_num == PyArray_FLOAT) ? (REAL)(((float*)_b->data)[0]) : (REAL)(((double*)_b->data)[0])')
## Transposition ##
......@@ -1186,8 +1201,8 @@ class pow_elemwise(elemwise):
def c_foreach((x_i, s_i), (z_i, )):
return "z_i = pow(x_i, s_i)"
# In-place variants of pow_elemwise, each obtained from inplace_version() and
# given numpy.ndarray.__ipow__ (wrapped by assert_same_shapes) as its Python
# implementation. Two names are bound: ipow_elemwise and pow_elemwise_inplace.
ipow_elemwise = pow_elemwise.inplace_version()
ipow_elemwise.set_impl(assert_same_shapes(numpy.ndarray.__ipow__))
pow_elemwise_inplace = pow_elemwise.inplace_version()
pow_elemwise_inplace.set_impl(assert_same_shapes(numpy.ndarray.__ipow__))
class pow_scalar_l(tensor_scalar_op):
......@@ -1202,8 +1217,8 @@ class pow_scalar_r(tensor_scalar_op):
return gz * s * (pow_scalar_r(x,s-1.0))
c_expr = "pow(x_i, a)"
# In-place variants of pow_scalar_r, each obtained from inplace_version() and
# given numpy.ndarray.__ipow__ (wrapped by tensor_scalar_impl) as its Python
# implementation. Two names are bound: ipow_scalar_r and pow_scalar_r_inplace.
ipow_scalar_r = pow_scalar_r.inplace_version()
ipow_scalar_r.set_impl(tensor_scalar_impl(numpy.ndarray.__ipow__))
pow_scalar_r_inplace = pow_scalar_r.inplace_version()
pow_scalar_r_inplace.set_impl(tensor_scalar_impl(numpy.ndarray.__ipow__))
......@@ -1273,19 +1288,19 @@ class get_slice(omega_op, view):
# Public arithmetic entry points, each built by scalar_switch from an elemwise
# op followed by one or two scalar variants.
# NOTE(review): scalar_switch presumably picks the elemwise op or a scalar
# variant depending on which operand is a scalar -- confirm against its
# definition. The in-place entries pass only one scalar variant.
# Both the i-prefixed names (iadd, isub, imul, idiv, ipow) and the *_inplace
# names are bound here, each from the correspondingly named ops.
# NOTE(review): `pow` rebinds the name of the Python builtin at module level.
add = scalar_switch(add_elemwise, add_scalar, add_scalar)
iadd = scalar_switch(iadd_elemwise, iadd_scalar)
add_inplace = scalar_switch(add_elemwise_inplace, add_scalar_inplace)
sub = scalar_switch(sub_elemwise, sub_scalar_r, sub_scalar_l)
isub = scalar_switch(isub_elemwise, isub_scalar_r)
sub_inplace = scalar_switch(sub_elemwise_inplace, sub_scalar_r_inplace)
mul = scalar_switch(mul_elemwise, scale, scale)
imul = scalar_switch(imul_elemwise, iscale)
mul_inplace = scalar_switch(mul_elemwise_inplace, scale_inplace)
div = scalar_switch(div_elemwise, div_scalar_r, div_scalar_l)
idiv = scalar_switch(idiv_elemwise, idiv_scalar_r)
div_inplace = scalar_switch(div_elemwise_inplace, div_scalar_r_inplace)
pow = scalar_switch(pow_elemwise, pow_scalar_r, pow_scalar_l)
ipow = scalar_switch(ipow_elemwise, ipow_scalar_r)
pow_inplace = scalar_switch(pow_elemwise_inplace, pow_scalar_r_inplace)
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论