提交 ed2cb406 authored 作者: James Bergstra's avatar James Bergstra

In sparse dot c_code, replaced a hard-coded 'double' with a more precise

dtype. This avoids annoying compiler warnings during tests.
上级 960a5469
...@@ -938,6 +938,9 @@ class StructuredDotCSC(gof.Op): ...@@ -938,6 +938,9 @@ class StructuredDotCSC(gof.Op):
"""% dict(locals(), **sub) """% dict(locals(), **sub)
return rval return rval
def c_code_cache_version(self):
return (1,)
sd_csc = StructuredDotCSC() sd_csc = StructuredDotCSC()
class StructuredDotCSR(gof.Op): class StructuredDotCSR(gof.Op):
...@@ -1052,7 +1055,7 @@ class StructuredDotCSR(gof.Op): ...@@ -1052,7 +1055,7 @@ class StructuredDotCSR(gof.Op):
for (npy_int32 k_idx = Dptr[m * Sptr]; k_idx < Dptr[(m+1) * Sptr]; ++k_idx) for (npy_int32 k_idx = Dptr[m * Sptr]; k_idx < Dptr[(m+1) * Sptr]; ++k_idx)
{ {
npy_int32 k = Dind[k_idx * Sind]; // col index of non-null value for row m npy_int32 k = Dind[k_idx * Sind]; // col index of non-null value for row m
const double Amk = Dval[k_idx * Sval]; // actual value at that location const dtype_%(a_val)s Amk = Dval[k_idx * Sval]; // actual value at that location
// get pointer to k-th row of dense matrix // get pointer to k-th row of dense matrix
const dtype_%(b)s* __restrict__ bk = (dtype_%(b)s*)(%(b)s->data + %(b)s->strides[0] * k); const dtype_%(b)s* __restrict__ bk = (dtype_%(b)s*)(%(b)s->data + %(b)s->strides[0] * k);
...@@ -1067,6 +1070,9 @@ class StructuredDotCSR(gof.Op): ...@@ -1067,6 +1070,9 @@ class StructuredDotCSR(gof.Op):
} }
"""% dict(locals(), **sub) """% dict(locals(), **sub)
def c_code_cache_version(self):
return (1,)
sd_csr = StructuredDotCSR() sd_csr = StructuredDotCSR()
# register a specialization to replace StructuredDot -> StructuredDotCSx # register a specialization to replace StructuredDot -> StructuredDotCSx
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论