提交 19c4443e authored 作者: James Bergstra's avatar James Bergstra

added a JOSEPHS_BUG_SOLVED global variable to tensor.blas that disables dot22…

added a JOSEPHS_BUG_SOLVED global variable to tensor.blas that disables dot22 and gemm optimizations
上级 2af541a0
...@@ -16,6 +16,7 @@ from .. import compile #to register the optimizer built by this file ...@@ -16,6 +16,7 @@ from .. import compile #to register the optimizer built by this file
from .blas_headers import cblas_header_text, blas_header_text from .blas_headers import cblas_header_text, blas_header_text
JOSEPHS_BUG_SOLVED = False
@utils.memoize @utils.memoize
...@@ -416,7 +417,8 @@ def local_dot_to_dot22(node): ...@@ -416,7 +417,8 @@ def local_dot_to_dot22(node):
return [_dot22(*node.inputs)] return [_dot22(*node.inputs)]
else: else:
return False return False
register_specialize(local_dot_to_dot22) if JOSEPHS_BUG_SOLVED:
register_specialize(local_dot_to_dot22)
def _is_a(node, op, maxclients=None): def _is_a(node, op, maxclients=None):
return node.owner \ return node.owner \
...@@ -527,7 +529,8 @@ def local_sub_to_gemm(node): ...@@ -527,7 +529,8 @@ def local_sub_to_gemm(node):
rval = beta_L_plus_alpha_M(sL, mL, -sR, mR) rval = beta_L_plus_alpha_M(sL, mL, -sR, mR)
return rval return rval
return False return False
register_specialize(local_sub_to_gemm) if JOSEPHS_BUG_SOLVED:
register_specialize(local_sub_to_gemm)
@local_optimizer([T.add]) @local_optimizer([T.add])
def local_add_to_gemm(node): def local_add_to_gemm(node):
...@@ -562,5 +565,6 @@ def local_add_to_gemm(node): ...@@ -562,5 +565,6 @@ def local_add_to_gemm(node):
[input for k, input in enumerate(node.inputs) if k not in (i,j)] [input for k, input in enumerate(node.inputs) if k not in (i,j)]
return [T.add( *(inputs_without_ij + rval))] return [T.add( *(inputs_without_ij + rval))]
return False return False
register_specialize(local_add_to_gemm) if JOSEPHS_BUG_SOLVED:
register_specialize(local_add_to_gemm)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论