提交 76426562 authored 作者: Frederic's avatar Frederic

make opt faster by using in2out instead of equilibrium.

上级 11e10897
......@@ -147,7 +147,7 @@ import theano.scalar
from theano.tensor import basic as T
from theano.tensor.blas_headers import blas_header_text
from theano.tensor.blas_headers import blas_header_version
from theano.tensor.opt import local_dimshuffle_lift, in2out
from theano.tensor.opt import in2out, local_dimshuffle_lift
_logger = logging.getLogger('theano.tensor.blas')
......@@ -1758,8 +1758,8 @@ optdb.register('BlasOpt', blas_optdb, 1.7, 'fast_run')
# free-for-all that makes the graph crazy.
blas_optdb.register('local_dot_to_dot22',
EquilibriumOptimizer([local_dot_to_dot22], max_use_ratio=5),
0, 'fast_run')
in2out(local_dot_to_dot22),
0, 'fast_run')
blas_optdb.register('gemm_optimizer',
GemmOptimizer(),
10, 'fast_run')
......@@ -1983,8 +1983,8 @@ def local_dot22_to_dot22scalar(node):
#must happen after gemm as the gemm optimizer don't understant
#dot22scalar and gemm give more speed up then dot22scalar
blas_optdb.register('local_dot22_to_dot22scalar',
EquilibriumOptimizer([local_dot22_to_dot22scalar], max_use_ratio=5),
11, 'fast_run')
in2out(local_dot22_to_dot22scalar),
11, 'fast_run')
#from opt import register_specialize, register_canonicalize
......
from theano import config
from theano.tensor.opt import in2out
from theano.tensor.blas import ldflags, blas_header_text, blas_header_version
from theano.tensor.blas import blas_optdb, optdb, local_optimizer, EquilibriumOptimizer
from theano.tensor.blas import Ger, ger, ger_destructive
......@@ -609,21 +609,14 @@ def make_c_gemv_destructive(node):
####### ####### #######
blas_optdb.register('use_c_blas',
EquilibriumOptimizer([
use_c_ger,
use_c_gemv,
],
max_use_ratio=5),
20, 'fast_run', 'c_blas')
in2out(use_c_ger, use_c_gemv),
20, 'fast_run', 'c_blas')
#print 'BLAS_OPTDB'
#print blas_optdb
# this matches the InplaceBlasOpt defined in blas.py
optdb.register('c_blas_destructive',
EquilibriumOptimizer([
make_c_ger_destructive,
make_c_gemv_destructive,
],
failure_callback=EquilibriumOptimizer.warn_inplace,
max_use_ratio=5),
70.0, 'fast_run', 'inplace', 'c_blas')
in2out(make_c_ger_destructive,
make_c_gemv_destructive,
name="c_blas_destructive"),
70.0, 'fast_run', 'inplace', 'c_blas')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论