提交 ef21cb58 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

Add register_inplace function to sandbox.cuda.

上级 9f02e130
......@@ -12,7 +12,7 @@ import warnings
import theano
from theano.compat import get_unbound_function
from theano.compile import optdb
from theano.gof import EquilibriumDB, SequenceDB
from theano.gof import EquilibriumDB, SequenceDB, TopoOptimizer
from theano.gof.cmodule import get_lib_extension
from theano.gof.compilelock import get_lock, release_lock
from theano import config
......@@ -40,6 +40,17 @@ def register_opt(*tags, **kwargs):
return f
def register_inplace(*tags, **kwargs):
def f(local_opt):
name = (kwargs and kwargs.pop('name')) or local_opt.__name__
optdb.register(
name, TopoOptimizer(
local_opt, failure_callback=TopoOptimizer.warn_inplace),
60, 'fast_run', 'inplace', 'gpu', *tags)
return local_opt
return f
_logger_name = 'theano.sandbox.cuda'
_logger = logging.getLogger(_logger_name)
......
......@@ -33,7 +33,7 @@ from theano.sandbox.cuda.blas import (GpuConv, GpuDownsampleFactorMax,
from theano.sandbox.cuda.nnet import GpuSoftmax
from theano.sandbox.cuda.opt_util import (alpha_merge, output_merge,
pad_dims, unpad_dims)
from theano.sandbox.cuda import gpu_seqopt, register_opt
from theano.sandbox.cuda import gpu_seqopt, register_opt, register_inplace
from theano.sandbox.cuda.nvcc_compiler import NVCC_compiler
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论