提交 acea5a6c authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Disable insertion of CGemv and CGer with complex.

上级 1ddd6f3a
...@@ -203,9 +203,12 @@ class CGer(BaseBLAS, Ger): ...@@ -203,9 +203,12 @@ class CGer(BaseBLAS, Ger):
@local_optimizer([ger, ger_destructive]) @local_optimizer([ger, ger_destructive])
def use_c_ger(node): def use_c_ger(node):
if node.op == ger: # Only float32 and float64 are supported for now.
if (node.op == ger and
node.outputs[0].dtype in ['float32', 'float64']):
return [CGer(False)(*node.inputs)] return [CGer(False)(*node.inputs)]
if node.op == ger_destructive: if (node.op == ger_destructive and
node.outputs[0].dtype in ['float32', 'float64']):
return [CGer(True)(*node.inputs)] return [CGer(True)(*node.inputs)]
@local_optimizer([CGer(False)]) @local_optimizer([CGer(False)])
...@@ -425,9 +428,12 @@ class CGemv(BaseBLAS, Gemv): ...@@ -425,9 +428,12 @@ class CGemv(BaseBLAS, Gemv):
@local_optimizer([gemv_inplace, gemv_no_inplace]) @local_optimizer([gemv_inplace, gemv_no_inplace])
def use_c_gemv(node): def use_c_gemv(node):
if node.op == gemv_no_inplace: # Only float32 and float64 are supported for now.
if (node.op == gemv_no_inplace and
node.outputs[0].dtype in ['float32', 'float64']):
return [CGemv(inplace=False)(*node.inputs)] return [CGemv(inplace=False)(*node.inputs)]
if node.op == gemv_inplace: if (node.op == gemv_inplace and
node.outputs[0].dtype in ['float32', 'float64']):
return [CGemv(inplace=True)(*node.inputs)] return [CGemv(inplace=True)(*node.inputs)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论