提交 c55401a0 authored 作者: James Bergstra's avatar James Bergstra

Added code to tensor.blas to prevent emitting c code when dtype is complex

上级 4d040292
...@@ -589,6 +589,9 @@ class Gemm(GemmRelated): ...@@ -589,6 +589,9 @@ class Gemm(GemmRelated):
""" """
def c_code(self, node, name, (_z, _a, _x, _y, _b), (_zout, ), sub): #DEBUG def c_code(self, node, name, (_z, _a, _x, _y, _b), (_zout, ), sub): #DEBUG
if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \
% self.__class__.__name__)
if not config.blas.ldflags: if not config.blas.ldflags:
return super(Gemm, self).c_code(node, name, (_z, _a, _x, _y, _b), (_zout, ), sub) return super(Gemm, self).c_code(node, name, (_z, _a, _x, _y, _b), (_zout, ), sub)
full_code = self.build_gemm_call() % dict(locals(), **sub) full_code = self.build_gemm_call() % dict(locals(), **sub)
...@@ -978,6 +981,9 @@ class Dot22(GemmRelated): ...@@ -978,6 +981,9 @@ class Dot22(GemmRelated):
double b = 0.0; double b = 0.0;
""" """
def c_code(self, node, name, (_x, _y), (_zout, ), sub): #DEBUG def c_code(self, node, name, (_x, _y), (_zout, ), sub): #DEBUG
if node.inputs[0].type.dtype.startswith('complex'):
raise utils.MethodNotDefined('%s.c_code' \
% self.__class__.__name__)
if len(self.c_libraries())<=0: if len(self.c_libraries())<=0:
return super(Dot22, self).c_code(node, name, (_x, _y), (_zout, ), sub) return super(Dot22, self).c_code(node, name, (_x, _y), (_zout, ), sub)
full_code = self.build_gemm_call() % dict(locals(), **sub) full_code = self.build_gemm_call() % dict(locals(), **sub)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论