提交 f3c99f51 authored 作者: Frederic's avatar Frederic

Make inplace opt for cger and cgemv faster.

上级 dbd301c1
...@@ -252,6 +252,8 @@ class CGer(BaseBLAS, Ger): ...@@ -252,6 +252,8 @@ class CGer(BaseBLAS, Ger):
def c_code_cache_version(self): def c_code_cache_version(self):
return (8, blas_header_version()) return (8, blas_header_version())
cger_inplace = CGer(True)
cger_no_inplace = CGer(False)
@local_optimizer([ger, ger_destructive]) @local_optimizer([ger, ger_destructive])
...@@ -269,8 +271,8 @@ def use_c_ger(node): ...@@ -269,8 +271,8 @@ def use_c_ger(node):
@local_optimizer([CGer(False)]) @local_optimizer([CGer(False)])
def make_c_ger_destructive(node): def make_c_ger_destructive(node):
if node.op == CGer(False): if node.op == cger_no_inplace:
return [CGer(True)(*node.inputs)] return [cger_inplace(*node.inputs)]
####### ####### ####### ####### ####### #######
...@@ -579,6 +581,8 @@ class CGemv(BaseBLAS, Gemv): ...@@ -579,6 +581,8 @@ class CGemv(BaseBLAS, Gemv):
def c_code_cache_version(self): def c_code_cache_version(self):
return (10, blas_header_version()) return (10, blas_header_version())
cgemv_inplace = CGemv(inplace=True)
cgemv_no_inplace = CGemv(inplace=False)
@local_optimizer([gemv_inplace, gemv_no_inplace]) @local_optimizer([gemv_inplace, gemv_no_inplace])
...@@ -596,8 +600,8 @@ def use_c_gemv(node): ...@@ -596,8 +600,8 @@ def use_c_gemv(node):
@local_optimizer([CGemv(inplace=False)]) @local_optimizer([CGemv(inplace=False)])
def make_c_gemv_destructive(node): def make_c_gemv_destructive(node):
if node.op == CGemv(inplace=False): if node.op == gemv_no_inplace:
return [CGemv(inplace=True)(*node.inputs)] return [gemv_inplace(*node.inputs)]
####### ####### ####### ####### ####### #######
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论