提交 5b1bb8a7 authored 作者: Frederic's avatar Frederic

Use only one pass over the toposort for some blas inplace opt.

上级 a5ffbdfd
...@@ -14,8 +14,7 @@ import theano.ifelse ...@@ -14,8 +14,7 @@ import theano.ifelse
from theano.compile import optdb from theano.compile import optdb
from theano.gof import (local_optimizer, EquilibriumDB, SequenceDB, ProxyDB, from theano.gof import (local_optimizer, EquilibriumDB, SequenceDB, ProxyDB,
Optimizer, toolbox, DestroyHandler, Optimizer, toolbox, DestroyHandler)
EquilibriumOptimizer)
from theano.gof.python25 import all, any from theano.gof.python25 import all, any
from theano.sandbox.cuda.basic_ops import ( from theano.sandbox.cuda.basic_ops import (
device_properties, gpu_eye, device_properties, gpu_eye,
...@@ -1199,12 +1198,9 @@ def local_inplace_ger(node): ...@@ -1199,12 +1198,9 @@ def local_inplace_ger(node):
# Also, need to make the gemm optimisation(step 70) happen before the fusion of # Also, need to make the gemm optimisation(step 70) happen before the fusion of
# elemwise(step 71) # elemwise(step 71)
optdb.register('InplaceGpuBlasOpt', optdb.register('InplaceGpuBlasOpt',
EquilibriumOptimizer([local_inplace_gemm, tensor.opt.in2out(gof.LocalOptGroup(local_inplace_gemm,
local_inplace_gemv, local_inplace_gemv,
local_inplace_ger, local_inplace_ger)),
],
failure_callback=EquilibriumOptimizer.warn_inplace,
max_use_ratio=5),
70.0, 'fast_run', 'inplace', 'gpu') 70.0, 'fast_run', 'inplace', 'gpu')
......
...@@ -135,7 +135,7 @@ import numpy.distutils.__config__ ...@@ -135,7 +135,7 @@ import numpy.distutils.__config__
from theano.configparser import config, AddConfigVar, StrParam from theano.configparser import config, AddConfigVar, StrParam
from theano.gof import (utils, Op, view_roots, DestroyHandler, from theano.gof import (utils, Op, view_roots, DestroyHandler,
local_optimizer, Optimizer, local_optimizer, Optimizer, LocalOptGroup,
InconsistencyError, toolbox, SequenceDB, InconsistencyError, toolbox, SequenceDB,
EquilibriumOptimizer, Apply, EquilibriumOptimizer, Apply,
ReplacementDidntRemovedError) ReplacementDidntRemovedError)
...@@ -147,7 +147,7 @@ import theano.scalar ...@@ -147,7 +147,7 @@ import theano.scalar
from theano.tensor import basic as T from theano.tensor import basic as T
from theano.tensor.blas_headers import blas_header_text from theano.tensor.blas_headers import blas_header_text
from theano.tensor.blas_headers import blas_header_version from theano.tensor.blas_headers import blas_header_version
from theano.tensor.opt import local_dimshuffle_lift from theano.tensor.opt import local_dimshuffle_lift, in2out
_logger = logging.getLogger('theano.tensor.blas') _logger = logging.getLogger('theano.tensor.blas')
...@@ -1777,10 +1777,9 @@ blas_optdb.register('local_gemm_to_gemv', ...@@ -1777,10 +1777,9 @@ blas_optdb.register('local_gemm_to_gemv',
# Try to make gemm inplace # Try to make gemm inplace
# Also, need to make the gemm optimisation(step 70) happen before the # Also, need to make the gemm optimisation(step 70) happen before the
# fusion of elemwise(step 71) # fusion of elemwise(step 71)
blas_opt_inplace = EquilibriumOptimizer( blas_opt_inplace = in2out(LocalOptGroup(local_inplace_gemm,
[local_inplace_gemm, local_inplace_gemv, local_inplace_ger], local_inplace_gemv,
failure_callback=EquilibriumOptimizer.warn_inplace, local_inplace_ger))
max_use_ratio=5)
optdb.register('InplaceBlasOpt', optdb.register('InplaceBlasOpt',
blas_opt_inplace, blas_opt_inplace,
70.0, 'fast_run', 'inplace') 70.0, 'fast_run', 'inplace')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论