提交 4a7a9e72 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Remove a template argument from batch_gemm

This might in some case where the compiler got confused and could not infer the right type for the gemm function.
上级 96c91044
......@@ -2054,9 +2054,10 @@ class BatchedDot(Op):
def c_support_code(self):
batch_gemm_defn = """
template<typename dtype, typename function>
bool batch_gemm(function gemm, int type_size,
PyArrayObject* xs, PyArrayObject* ys, PyArrayObject* zs) {
template<typename dtype>
bool batch_gemm(void (*gemm)(char*, char*, const int*, const int*, const int*, const dtype*, const dtype*, const int*, const dtype*, const int*, const dtype*, dtype*, const int*),
int type_size, PyArrayObject* xs, PyArrayObject* ys,
PyArrayObject* zs) {
npy_intp *Nx = PyArray_DIMS(xs), *Sx = PyArray_STRIDES(xs);
npy_intp *Ny = PyArray_DIMS(ys), *Sy = PyArray_STRIDES(ys);
npy_intp *Nz = PyArray_DIMS(zs), *Sz = PyArray_STRIDES(zs);
......@@ -2323,7 +2324,7 @@ class BatchedDot(Op):
def c_code_cache_version(self):
from theano.tensor.blas_headers import blas_header_version
return (3, blas_header_version())
return (4, blas_header_version())
def grad(self, inp, grads):
x, y = inp
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论