提交 ceb89e2d authored 作者: James Bergstra's avatar James Bergstra

Created a new BLAS optimization phase, inserted between stabilization and specialization.

The phase inserts a new Op called PseudoGemm. This Op has the same signature as Gemm but does not work in place. Another optimization that runs later in the pipeline swaps PseudoGemm for Gemm.
上级 dcd0260f
...@@ -6,10 +6,9 @@ import numpy.distutils ...@@ -6,10 +6,9 @@ import numpy.distutils
from theano.configparser import config, AddConfigVar, StrParam from theano.configparser import config, AddConfigVar, StrParam
from theano.gof import (utils, Op, Apply, view_roots, PatternSub, DestroyHandler, from theano.gof import (utils, Op, Apply, view_roots, PatternSub, DestroyHandler,
SeqOptimizer, local_optimizer, Optimizer, LocalOptimizer, OpKeyOptimizer, SeqOptimizer, local_optimizer, Optimizer, LocalOptimizer, OpKeyOptimizer,
InconsistencyError, toolbox) InconsistencyError, toolbox, SequenceDB, EquilibriumOptimizer)
from theano.printing import pprint, FunctionPrinter from theano.printing import pprint, FunctionPrinter
from theano.tensor.opt import register_specialize, out2in, insert_inplace_optimizer from theano.compile.mode import optdb
# opt.py
import basic as T import basic as T
...@@ -30,7 +29,6 @@ AddConfigVar('blas.ldflags', ...@@ -30,7 +29,6 @@ AddConfigVar('blas.ldflags',
"lib[s] to include for [Fortran] level-3 blas implementation", "lib[s] to include for [Fortran] level-3 blas implementation",
StrParam(default_blas_ldflags())) StrParam(default_blas_ldflags()))
_logger = logging.getLogger('theano.tensor.blas') _logger = logging.getLogger('theano.tensor.blas')
_logger.setLevel(logging.WARN) _logger.setLevel(logging.WARN)
def debug(*msg): _logger.debug(' '.join(str(m) for m in msg)) def debug(*msg): _logger.debug(' '.join(str(m) for m in msg))
...@@ -391,12 +389,22 @@ class Gemm(GemmRelated): ...@@ -391,12 +389,22 @@ class Gemm(GemmRelated):
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) + self.build_gemm_version() return (1,) + self.build_gemm_version()
gemm = Gemm() class PseudoGemm(Op):
# Placeholder op: it should be replaced by Gemm before the graph is executed.
def __eq__(self, other):
    """Return True exactly when `other` has the same type as this op."""
    own_type = type(self)
    other_type = type(other)
    return own_type == other_type
def __hash__(self):
    """Hash on the op's type, matching the type-based `__eq__`."""
    op_type = type(self)
    return hash(op_type)
def make_node(self, *args):
    """Build the Apply node for this op.

    Every positional argument is converted with `T.as_tensor_variable`.
    The node has a single output, created from the type of the first
    converted input.
    """
    tensor_inputs = []
    for arg in args:
        tensor_inputs.append(T.as_tensor_variable(arg))
    # The output variable takes its type from the first input.
    output = tensor_inputs[0].type()
    return Apply(self, tensor_inputs, [output])
def perform(self, node, inputs, output_storage):
    """Compute ``a * dot(x, y) + b * z`` and store it in the output cell.

    Parameters:
        node: the Apply node being evaluated (not used here).
        inputs: sequence of exactly five values ``(z, a, x, y, b)``.
        output_storage: sequence holding exactly one storage cell; the
            result is written to index 0 of that cell.

    Raises:
        ValueError: if `inputs` does not have five items or
            `output_storage` does not have exactly one.

    The result is built from a fresh expression, so `z` itself is not
    written to.
    """
    # Unpack explicitly instead of using tuple parameters in the signature:
    # tuple parameters were removed by PEP 3113 and are a SyntaxError on
    # Python 3. Positional callers see no difference.
    z, a, x, y, b = inputs
    zout, = output_storage
    zout[0] = a * numpy.dot(x, y) + b * z
# `gemm` is bound to the PseudoGemm instance, while the Gemm instance is
# exposed separately as `gemm_inplace`; local_inplace_gemm replaces nodes using
# the former with nodes using the latter.
gemm = PseudoGemm()
gemm_inplace = Gemm()
pprint.assign(gemm, FunctionPrinter('gemm')) pprint.assign(gemm, FunctionPrinter('gemm'))
pprint.assign(gemm_inplace, FunctionPrinter('gemm_inplace'))
def res_is_a(node, op, maxclients=None): def res_is_a(node, op, maxclients=None):
if maxclients is not None: if maxclients is not None:
retval = (len(node.clients) <= maxclients) retval = (len(node.clients) <= maxclients)
...@@ -597,6 +605,7 @@ class GemmOptimizer(Optimizer): ...@@ -597,6 +605,7 @@ class GemmOptimizer(Optimizer):
while did_something: while did_something:
nodelist = list(env.toposort()) nodelist = list(env.toposort())
did_something = False did_something = False
nodelist.reverse()
for node in nodelist: for node in nodelist:
new_outputs = _gemm_from_node(node) new_outputs = _gemm_from_node(node)
if new_outputs: if new_outputs:
...@@ -611,10 +620,6 @@ class GemmOptimizer(Optimizer): ...@@ -611,10 +620,6 @@ class GemmOptimizer(Optimizer):
#TODO: retry other applications of gemm (see comment in _gemm_from_node #TODO: retry other applications of gemm (see comment in _gemm_from_node
pass pass
# Needed to make the gemm optimisation (step 70) happen before the fusion of elemwise (step 71).
compile.optdb.register('inplace_gemm', GemmOptimizer(), 70.00, 'fast_run', 'inplace', 'gemm')
class Dot22(GemmRelated): class Dot22(GemmRelated):
"""Compute a matrix-matrix product. """Compute a matrix-matrix product.
This is a specialization of the more general Dot() This is a specialization of the more general Dot()
...@@ -689,5 +694,34 @@ def local_dot_to_dot22(node): ...@@ -689,5 +694,34 @@ def local_dot_to_dot22(node):
info('Not optimizing dot with inputs', x, y, x.type, y.type) info('Not optimizing dot with inputs', x, y, x.type, y.type)
else: else:
return False return False
register_specialize(local_dot_to_dot22)
@local_optimizer([gemm])
def local_inplace_gemm(node):
    """Swap a `gemm` node for a `gemm_inplace` node on the same inputs.

    Returns a one-element list holding the replacement output when
    `node.op` equals `gemm`, and None otherwise.
    """
    if not (node.op == gemm):
        return None
    return [gemm_inplace(*node.inputs)]
#################################
#
# Set up the BlasOpt optimizer
#
#################################
# Database that holds the BLAS optimizations registered below.
blas_optdb = SequenceDB()
# Register the BLAS phase at position 1.7 of the global optdb, i.e. after the
# numerical stability optimizations (1.5).
optdb.register('BlasOpt', blas_optdb, 1.7, 'fast_run')
# It must also run before specialize (2.0), because specialize is basically a
# free-for-all that makes the graph crazy.
# Inside the BLAS phase, local_dot_to_dot22 is registered at position 0 and
# GemmOptimizer at position 10 (presumably lower positions run first -- confirm
# against SequenceDB).
blas_optdb.register('local_dot_to_dot22',
EquilibriumOptimizer([local_dot_to_dot22], max_use_ratio=5),
0, 'fast_run')
blas_optdb.register('local_dot_to_gemm', GemmOptimizer(), 10, 'fast_run')
# Try to make gemm work in place. This runs after the destroy handler is in
# place, but before we try to make elemwise operations in place.
# The gemm optimisation (step 70) also needs to happen before the fusion of
# elemwise (step 71).
optdb.register('InplaceBlasOpt',
EquilibriumOptimizer([local_inplace_gemm], max_use_ratio=5),
70.0, 'fast_run', 'inplace')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论