提交 a05a98e0 authored 作者: goodfeli's avatar goodfeli

Merge pull request #705 from nouiz/opt

Opt
...@@ -147,7 +147,7 @@ class ReplaceValidate(History, Validator): ...@@ -147,7 +147,7 @@ class ReplaceValidate(History, Validator):
return chk return chk
def replace_all_validate_remove(self, env, replacements, def replace_all_validate_remove(self, env, replacements,
remove, reason=None): remove, reason=None, warn=True):
"""As replace_all_validate, revert the replacement if the ops """As replace_all_validate, revert the replacement if the ops
in the list remove are still in the graph. It also print a warning. in the list remove are still in the graph. It also print a warning.
...@@ -156,15 +156,16 @@ class ReplaceValidate(History, Validator): ...@@ -156,15 +156,16 @@ class ReplaceValidate(History, Validator):
for rm in remove: for rm in remove:
if rm in env.nodes or rm in env.variables: if rm in env.nodes or rm in env.variables:
env.revert(chk) env.revert(chk)
out = sys.stderr if warn:
print >> out, ( out = sys.stderr
"WARNING: An optimization wanted to replace a Variable" print >> out, (
" in the graph, but the replacement for it doesn't" "WARNING: An optimization wanted to replace a Variable"
" remove it. We disabled the optimization." " in the graph, but the replacement for it doesn't"
" Your function runs correctly, but it would be" " remove it. We disabled the optimization."
" appreciated if you submit this problem to the mailing" " Your function runs correctly, but it would be"
" list theano-users so that we can fix it.") " appreciated if you submit this problem to the"
print >> out, reason, replacements " mailing list theano-users so that we can fix it.")
print >> out, reason, replacements
raise ReplacementDidntRemovedError() raise ReplacementDidntRemovedError()
......
...@@ -127,6 +127,7 @@ import copy ...@@ -127,6 +127,7 @@ import copy
import logging import logging
import os import os
import sys import sys
import time
import numpy import numpy
import numpy.distutils import numpy.distutils
...@@ -1289,11 +1290,16 @@ def _gemm_from_node2(node): ...@@ -1289,11 +1290,16 @@ def _gemm_from_node2(node):
""" """
lst = [] lst = []
t0 = time.time()
_gemm_canonicalize(node.outputs[0], 1.0, lst, 0) _gemm_canonicalize(node.outputs[0], 1.0, lst, 0)
t1 = time.time()
#print "GEMM CANON", lst #print "GEMM CANON", lst
if len(lst) > 1: if len(lst) > 1:
lst = _factor_canonicalized(lst) lst = _factor_canonicalized(lst)
t2 = time.time()
rval = _gemm_from_factored_list(lst) rval = _gemm_from_factored_list(lst)
t3 = time.time()
# It can happen that _factor_canonicalized and # It can happen that _factor_canonicalized and
# _gemm_from_factored_list return a node with an incorrect # _gemm_from_factored_list return a node with an incorrect
...@@ -1305,7 +1311,9 @@ def _gemm_from_node2(node): ...@@ -1305,7 +1311,9 @@ def _gemm_from_node2(node):
# but never made it into a trac ticket. # but never made it into a trac ticket.
if rval and (rval[0][0].type == node.outputs[0].type): if rval and (rval[0][0].type == node.outputs[0].type):
return rval return rval, t1 - t0, t2 - t1, t3 - t2
return None, t1 - t0, 0, 0
class GemmOptimizer(Optimizer): class GemmOptimizer(Optimizer):
...@@ -1319,14 +1327,38 @@ class GemmOptimizer(Optimizer): ...@@ -1319,14 +1327,38 @@ class GemmOptimizer(Optimizer):
def apply(self, env): def apply(self, env):
did_something = True did_something = True
nb_iter = 0
nb_replacement = 0
nb_replacement_didn_t_remove = 0
nb_inconsistency_make = 0
nb_inconsistency_replace = 0
time_canonicalize = 0
time_factor_can = 0
time_factor_list = 0
time_toposort = 0
while did_something: while did_something:
t0 = time.time()
nodelist = list(env.toposort()) nodelist = list(env.toposort())
time_toposort += time.time() - t0
did_something = False did_something = False
nodelist.reverse() nodelist.reverse()
for node in nodelist: for node in nodelist:
if not (isinstance(node.op, T.Elemwise) and
isinstance(node.op.scalar_op,
(theano.scalar.Add, theano.scalar.Sub,
theano.scalar.Neg, theano.scalar.Mul))):
continue
if not node in env.nodes:
# This mean that we already removed this node from
# the graph
continue
try: try:
new_outputs = _gemm_from_node2(node) new_outputs, time1, time2, time3 = _gemm_from_node2(node)
time_canonicalize += time1
time_factor_can += time2
time_factor_list += time3
except InconsistencyError, e: except InconsistencyError, e:
nb_inconsistency_make += 1
continue continue
if new_outputs: if new_outputs:
new_outputs, old_dot22 = new_outputs new_outputs, old_dot22 = new_outputs
...@@ -1335,16 +1367,39 @@ class GemmOptimizer(Optimizer): ...@@ -1335,16 +1367,39 @@ class GemmOptimizer(Optimizer):
env.replace_all_validate_remove( env.replace_all_validate_remove(
zip(node.outputs, new_outputs), zip(node.outputs, new_outputs),
[old_dot22], [old_dot22],
reason='GemmOptimizer' reason='GemmOptimizer',
warn=nb_replacement_didn_t_remove == 0
) )
did_something = True did_something = True
break nb_replacement += 1
except InconsistencyError, e: except InconsistencyError, e:
# TODO: retry other applications of gemm (see comment # TODO: retry other applications of gemm (see comment
# in _gemm_from_node) # in _gemm_from_node)
nb_inconsistency_replace += 1
pass pass
except ReplacementDidntRemovedError, e: except ReplacementDidntRemovedError, e:
nb_replacement_didn_t_remove += 1
pass pass
nb_iter += 1
return (self, nb_iter, nb_replacement, nb_replacement_didn_t_remove,
nb_inconsistency_make, nb_inconsistency_replace,
time_canonicalize, time_factor_can,
time_factor_list, time_toposort)
@staticmethod
def print_profile(stream, prof, level=0):
blanc = (' ' * level)
#1946.912556s - ('gemm_optimizer', 'GemmOptimizer', 1)
print >> stream, blanc, "GemmOptimizer"
print >> stream, blanc, " nb_iter", prof[1]
print >> stream, blanc, " nb_replacement", prof[2]
print >> stream, blanc, " nb_replacement_didn_t_remove", prof[3]
print >> stream, blanc, " nb_inconsistency_make", prof[4]
print >> stream, blanc, " nb_inconsistency_replace", prof[5]
print >> stream, blanc, " time_canonicalize", prof[6]
print >> stream, blanc, " time_factor_can", prof[7]
print >> stream, blanc, " time_factor_list", prof[8]
print >> stream, blanc, " time_toposort", prof[9]
class Dot22(GemmRelated): class Dot22(GemmRelated):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论