提交 26cfa4a8 authored 作者: Frederic's avatar Frederic

Print the gemm warning only once per process, not once per compilation.

上级 977b8b75
...@@ -1350,6 +1350,7 @@ class GemmOptimizer(Optimizer): ...@@ -1350,6 +1350,7 @@ class GemmOptimizer(Optimizer):
"""Graph optimizer for inserting Gemm operations""" """Graph optimizer for inserting Gemm operations"""
def __init__(self): def __init__(self):
Optimizer.__init__(self) Optimizer.__init__(self)
self.warned = False
def add_requirements(self, env): def add_requirements(self, env):
env.extend(toolbox.ReplaceValidate()) env.extend(toolbox.ReplaceValidate())
...@@ -1398,7 +1399,7 @@ class GemmOptimizer(Optimizer): ...@@ -1398,7 +1399,7 @@ class GemmOptimizer(Optimizer):
zip(node.outputs, new_outputs), zip(node.outputs, new_outputs),
[old_dot22], [old_dot22],
reason='GemmOptimizer', reason='GemmOptimizer',
warn=nb_replacement_didn_t_remove == 0 warn=not self.warned
) )
did_something = True did_something = True
nb_replacement += 1 nb_replacement += 1
...@@ -1406,10 +1407,9 @@ class GemmOptimizer(Optimizer): ...@@ -1406,10 +1407,9 @@ class GemmOptimizer(Optimizer):
# TODO: retry other applications of gemm (see comment # TODO: retry other applications of gemm (see comment
# in _gemm_from_node) # in _gemm_from_node)
nb_inconsistency_replace += 1 nb_inconsistency_replace += 1
pass
except ReplacementDidntRemovedError, e: except ReplacementDidntRemovedError, e:
nb_replacement_didn_t_remove += 1 nb_replacement_didn_t_remove += 1
pass self.warned = True
nb_iter += 1 nb_iter += 1
return (self, nb_iter, nb_replacement, nb_replacement_didn_t_remove, return (self, nb_iter, nb_replacement, nb_replacement_didn_t_remove,
nb_inconsistency_make, nb_inconsistency_replace, nb_inconsistency_make, nb_inconsistency_replace,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论