提交 5b1bb8a7 authored 作者: Frederic's avatar Frederic

Use only one pass over the toposort for some BLAS inplace optimizations.

上级 a5ffbdfd
......@@ -14,8 +14,7 @@ import theano.ifelse
from theano.compile import optdb
from theano.gof import (local_optimizer, EquilibriumDB, SequenceDB, ProxyDB,
Optimizer, toolbox, DestroyHandler,
EquilibriumOptimizer)
Optimizer, toolbox, DestroyHandler)
from theano.gof.python25 import all, any
from theano.sandbox.cuda.basic_ops import (
device_properties, gpu_eye,
......@@ -1199,12 +1198,9 @@ def local_inplace_ger(node):
# Also, the gemm optimisation (step 70) must happen before the fusion of
# elemwise (step 71).
optdb.register('InplaceGpuBlasOpt',
EquilibriumOptimizer([local_inplace_gemm,
local_inplace_gemv,
local_inplace_ger,
],
failure_callback=EquilibriumOptimizer.warn_inplace,
max_use_ratio=5),
tensor.opt.in2out(gof.LocalOptGroup(local_inplace_gemm,
local_inplace_gemv,
local_inplace_ger)),
70.0, 'fast_run', 'inplace', 'gpu')
......
......@@ -135,7 +135,7 @@ import numpy.distutils.__config__
from theano.configparser import config, AddConfigVar, StrParam
from theano.gof import (utils, Op, view_roots, DestroyHandler,
local_optimizer, Optimizer,
local_optimizer, Optimizer, LocalOptGroup,
InconsistencyError, toolbox, SequenceDB,
EquilibriumOptimizer, Apply,
ReplacementDidntRemovedError)
......@@ -147,7 +147,7 @@ import theano.scalar
from theano.tensor import basic as T
from theano.tensor.blas_headers import blas_header_text
from theano.tensor.blas_headers import blas_header_version
from theano.tensor.opt import local_dimshuffle_lift
from theano.tensor.opt import local_dimshuffle_lift, in2out
_logger = logging.getLogger('theano.tensor.blas')
......@@ -1777,10 +1777,9 @@ blas_optdb.register('local_gemm_to_gemv',
# Try to make gemm inplace.
# Also, the gemm optimisation (step 70) must happen before the
# fusion of elemwise (step 71).
blas_opt_inplace = EquilibriumOptimizer(
[local_inplace_gemm, local_inplace_gemv, local_inplace_ger],
failure_callback=EquilibriumOptimizer.warn_inplace,
max_use_ratio=5)
blas_opt_inplace = in2out(LocalOptGroup(local_inplace_gemm,
local_inplace_gemv,
local_inplace_ger))
optdb.register('InplaceBlasOpt',
blas_opt_inplace,
70.0, 'fast_run', 'inplace')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论