提交 10f86fe3 authored 作者: Frederic Bastien's avatar Frederic Bastien

make the Convolution use all needed compilation option when it use gemm.…

make the Convolution use all needed compilation option when it use gemm. Generate them only when we generate call to gemm.
上级 43947ed9
......@@ -848,9 +848,39 @@ class ConvOp(Op):
using namespace std;
""" + tensor.blas.blas_header_text()
def use_blas(self):
""" Return True if we will generate code that use gemm.
"""
#the gemm version only support that case
if self.out_mode == 'valid' and self.dx==0 and self.dy==0:
#We use a faster version in those case.
if (self.imshp != self.imshp_logical or self.kshp != self.kshp_logical
or self.unroll_patch or self.unroll_batch>0 or self.unroll_kern>0):
return False
return True
return False
def c_libraries(self):
return tensor.blas.ldflags()
if self.use_blas():
return tensor.blas.ldflags()
return []
def c_compile_args(self):
if self.use_blas():
return tensor.blas.ldflags(libs=False, flags=True)
return []
def c_lib_dirs(self):
if self.use_blas():
return tensor.blas.ldflags(libs=False, libs_dir=True)
return []
def c_header_dirs(self):
if self.use_blas():
return tensor.blas.ldflags(libs=False, include_dir=True)
return []
def c_code(self, node, name, (img2d, filtersflipped), (z, ), sub):
if node.inputs[0].type.dtype != node.inputs[1].type.dtype:
raise NotImplementedError()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论