提交 4a7a9e72 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Remove a template argument from batch_gemm

This might in some case where the compiler got confused and could not infer the right type for the gemm function.
上级 96c91044
...@@ -2054,9 +2054,10 @@ class BatchedDot(Op): ...@@ -2054,9 +2054,10 @@ class BatchedDot(Op):
def c_support_code(self): def c_support_code(self):
batch_gemm_defn = """ batch_gemm_defn = """
template<typename dtype, typename function> template<typename dtype>
bool batch_gemm(function gemm, int type_size, bool batch_gemm(void (*gemm)(char*, char*, const int*, const int*, const int*, const dtype*, const dtype*, const int*, const dtype*, const int*, const dtype*, dtype*, const int*),
PyArrayObject* xs, PyArrayObject* ys, PyArrayObject* zs) { int type_size, PyArrayObject* xs, PyArrayObject* ys,
PyArrayObject* zs) {
npy_intp *Nx = PyArray_DIMS(xs), *Sx = PyArray_STRIDES(xs); npy_intp *Nx = PyArray_DIMS(xs), *Sx = PyArray_STRIDES(xs);
npy_intp *Ny = PyArray_DIMS(ys), *Sy = PyArray_STRIDES(ys); npy_intp *Ny = PyArray_DIMS(ys), *Sy = PyArray_STRIDES(ys);
npy_intp *Nz = PyArray_DIMS(zs), *Sz = PyArray_STRIDES(zs); npy_intp *Nz = PyArray_DIMS(zs), *Sz = PyArray_STRIDES(zs);
...@@ -2323,7 +2324,7 @@ class BatchedDot(Op): ...@@ -2323,7 +2324,7 @@ class BatchedDot(Op):
def c_code_cache_version(self): def c_code_cache_version(self):
from theano.tensor.blas_headers import blas_header_version from theano.tensor.blas_headers import blas_header_version
return (3, blas_header_version()) return (4, blas_header_version())
def grad(self, inp, grads): def grad(self, inp, grads):
x, y = inp x, y = inp
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论