提交 e2fa5912 authored 作者: Frederic Bastien's avatar Frederic Bastien

Get some inplace stats

上级 5cbbcb72
...@@ -5,6 +5,7 @@ Tensor optimizations addressing the ops in basic.py. ...@@ -5,6 +5,7 @@ Tensor optimizations addressing the ops in basic.py.
# TODO: intelligent merge for mul/add # TODO: intelligent merge for mul/add
# TODO: 0*x -> 0 # TODO: 0*x -> 0
from collections import defaultdict
import logging import logging
import itertools import itertools
import operator import operator
...@@ -187,6 +188,12 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -187,6 +188,12 @@ def inplace_elemwise_optimizer_op(OP):
# the solution is also applicable there. # the solution is also applicable there.
# We execute `validate` after this number of change. # We execute `validate` after this number of change.
prof = {'node_before': len(fgraph.apply_nodes),
'nb_call_replace': 0,
'nb_call_validate': 0,
'nb_inconsistent': 0,
'ndim': defaultdict(lambda: 0)}
check_each_change = config.tensor.insert_inplace_optimizer_validate_nb check_each_change = config.tensor.insert_inplace_optimizer_validate_nb
if check_each_change == -1: if check_each_change == -1:
if len(fgraph.apply_nodes) > 500: if len(fgraph.apply_nodes) > 500:
...@@ -332,14 +339,18 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -332,14 +339,18 @@ def inplace_elemwise_optimizer_op(OP):
new_node = new_outputs[0].owner new_node = new_outputs[0].owner
for r, new_r in zip(node.outputs, new_outputs): for r, new_r in zip(node.outputs, new_outputs):
prof['nb_call_replace'] += 1
fgraph.replace(r, new_r, fgraph.replace(r, new_r,
reason="inplace_elemwise_optimizer") reason="inplace_elemwise_optimizer")
nb_change_no_validate += 1 nb_change_no_validate += 1
prof['ndim'][node.ndim] += 1
if nb_change_no_validate >= check_each_change: if nb_change_no_validate >= check_each_change:
prof['nb_call_validate'] += 1
fgraph.validate() fgraph.validate()
chk = fgraph.checkpoint() chk = fgraph.checkpoint()
nb_change_no_validate = 0 nb_change_no_validate = 0
except (ValueError, InconsistencyError) as e: except (ValueError, InconsistencyError) as e:
prof['nb_inconsistent'] += 1
if check_each_change != 1 and not raised_warning: if check_each_change != 1 and not raised_warning:
print(("Some inplace optimization was not " print(("Some inplace optimization was not "
"performed due to unexpected error:"), "performed due to unexpected error:"),
...@@ -362,6 +373,7 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -362,6 +373,7 @@ def inplace_elemwise_optimizer_op(OP):
"performed due to unexpected error"), "performed due to unexpected error"),
file=sys.stderr) file=sys.stderr)
fgraph.revert(chk) fgraph.revert(chk)
print(prof)
return inplace_elemwise_optimizer return inplace_elemwise_optimizer
inplace_elemwise_optimizer = inplace_elemwise_optimizer_op(T.Elemwise) inplace_elemwise_optimizer = inplace_elemwise_optimizer_op(T.Elemwise)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论