提交 ceb89e2d authored 作者: James Bergstra's avatar James Bergstra

Created new Blas optimization phase. Inserted it between stabilization and

specialization. It inserts a new op called PseudoGemm. This Op has the same signature as Gemm but does not work inplace. Another optimization comes later in the pipeline and swaps PseudoGemm for Gemm
上级 dcd0260f
......@@ -6,10 +6,9 @@ import numpy.distutils
from theano.configparser import config, AddConfigVar, StrParam
from theano.gof import (utils, Op, Apply, view_roots, PatternSub, DestroyHandler,
SeqOptimizer, local_optimizer, Optimizer, LocalOptimizer, OpKeyOptimizer,
InconsistencyError, toolbox)
InconsistencyError, toolbox, SequenceDB, EquilibriumOptimizer)
from theano.printing import pprint, FunctionPrinter
from theano.tensor.opt import register_specialize, out2in, insert_inplace_optimizer
# opt.py
from theano.compile.mode import optdb
import basic as T
......@@ -30,7 +29,6 @@ AddConfigVar('blas.ldflags',
"lib[s] to include for [Fortran] level-3 blas implementation",
StrParam(default_blas_ldflags()))
_logger = logging.getLogger('theano.tensor.blas')
_logger.setLevel(logging.WARN)
def debug(*msg): _logger.debug(' '.join(str(m) for m in msg))
......@@ -391,12 +389,22 @@ class Gemm(GemmRelated):
def c_code_cache_version(self):
return (1,) + self.build_gemm_version()
gemm = Gemm()
class PseudoGemm(Op):
# should be replaced by Gemm
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, *args):
inputs = [T.as_tensor_variable(i) for i in args]
return Apply(self, inputs, [inputs[0].type()])
def perform(self, node, (z, a, x, y, b), (zout, )):
zout[0] = a * numpy.dot(x,y) + b * z
gemm = PseudoGemm()
gemm_inplace = Gemm()
pprint.assign(gemm, FunctionPrinter('gemm'))
pprint.assign(gemm_inplace, FunctionPrinter('gemm_inplace'))
def res_is_a(node, op, maxclients=None):
if maxclients is not None:
retval = (len(node.clients) <= maxclients)
......@@ -597,6 +605,7 @@ class GemmOptimizer(Optimizer):
while did_something:
nodelist = list(env.toposort())
did_something = False
nodelist.reverse()
for node in nodelist:
new_outputs = _gemm_from_node(node)
if new_outputs:
......@@ -611,10 +620,6 @@ class GemmOptimizer(Optimizer):
#TODO: retry other applications of gemm (see comment in _gemm_from_node
pass
#neede to make the gemm optimisation(step 70) happen before the fusion of elemwise(step 71)
compile.optdb.register('inplace_gemm', GemmOptimizer(), 70.00, 'fast_run', 'inplace', 'gemm')
class Dot22(GemmRelated):
"""Compute a matrix-matrix product.
This is a specialization of the more general Dot()
......@@ -689,5 +694,34 @@ def local_dot_to_dot22(node):
info('Not optimizing dot with inputs', x, y, x.type, y.type)
else:
return False
register_specialize(local_dot_to_dot22)
@local_optimizer([gemm])
def local_inplace_gemm(node):
if node.op == gemm:
return [gemm_inplace(*node.inputs)]
#################################
#
# Set up the BlasOpt optimizer
#
#################################
blas_optdb = SequenceDB()
# run after numerical stability optimizations (1.5)
optdb.register('BlasOpt', blas_optdb, 1.7, 'fast_run')
# run before specialize (2.0) because specialize is basically a free-for-all that makes the
# graph crazy.
blas_optdb.register('local_dot_to_dot22',
EquilibriumOptimizer([local_dot_to_dot22], max_use_ratio=5),
0, 'fast_run')
blas_optdb.register('local_dot_to_gemm', GemmOptimizer(), 10, 'fast_run')
# After destroyhandler is in but before we try to make elemwise things inplace
# Try to make gemm inplace
# Also, need to make the gemm optimisation(step 70) happen before the fusion of elemwise(step 71)
optdb.register('InplaceBlasOpt',
EquilibriumOptimizer([local_inplace_gemm], max_use_ratio=5),
70.0, 'fast_run', 'inplace')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论