提交 ef21cb58 作者: Gijs van Tulder

Add register_inplace function to sandbox.cuda.

上级 9f02e130
...@@ -12,7 +12,7 @@ import warnings ...@@ -12,7 +12,7 @@ import warnings
import theano import theano
from theano.compat import get_unbound_function from theano.compat import get_unbound_function
from theano.compile import optdb from theano.compile import optdb
from theano.gof import EquilibriumDB, SequenceDB from theano.gof import EquilibriumDB, SequenceDB, TopoOptimizer
from theano.gof.cmodule import get_lib_extension from theano.gof.cmodule import get_lib_extension
from theano.gof.compilelock import get_lock, release_lock from theano.gof.compilelock import get_lock, release_lock
from theano import config from theano import config
...@@ -40,6 +40,17 @@ def register_opt(*tags, **kwargs): ...@@ -40,6 +40,17 @@ def register_opt(*tags, **kwargs):
return f return f
def register_inplace(*tags, **kwargs):
    """Return a decorator that registers a local optimizer as an in-place GPU opt.

    The decorated local optimizer is wrapped in a ``TopoOptimizer`` (with
    ``TopoOptimizer.warn_inplace`` as its failure callback) and registered
    in ``optdb`` at position 60 under the tags ``'fast_run'``,
    ``'inplace'`` and ``'gpu'``, followed by any extra ``tags``.

    Parameters
    ----------
    *tags
        Additional tags passed through to ``optdb.register``.
    **kwargs
        Only ``name`` is consulted: the name to register the optimizer
        under. When it is missing or falsy, ``local_opt.__name__`` is used.
        Other keyword arguments are ignored.

    Returns
    -------
    callable
        A decorator that registers ``local_opt`` and returns it unchanged.
    """
    def f(local_opt):
        # Use a default on pop: the previous `kwargs.pop('name')` raised
        # KeyError whenever keyword arguments were supplied without 'name'.
        # The key is still popped (not just read) so the contents of
        # `kwargs` after the call stay the same as before.
        name = kwargs.pop('name', None) or local_opt.__name__
        optdb.register(
            name, TopoOptimizer(
                local_opt, failure_callback=TopoOptimizer.warn_inplace),
            60, 'fast_run', 'inplace', 'gpu', *tags)
        return local_opt
    return f
_logger_name = 'theano.sandbox.cuda' _logger_name = 'theano.sandbox.cuda'
_logger = logging.getLogger(_logger_name) _logger = logging.getLogger(_logger_name)
......
...@@ -33,7 +33,7 @@ from theano.sandbox.cuda.blas import (GpuConv, GpuDownsampleFactorMax, ...@@ -33,7 +33,7 @@ from theano.sandbox.cuda.blas import (GpuConv, GpuDownsampleFactorMax,
from theano.sandbox.cuda.nnet import GpuSoftmax from theano.sandbox.cuda.nnet import GpuSoftmax
from theano.sandbox.cuda.opt_util import (alpha_merge, output_merge, from theano.sandbox.cuda.opt_util import (alpha_merge, output_merge,
pad_dims, unpad_dims) pad_dims, unpad_dims)
from theano.sandbox.cuda import gpu_seqopt, register_opt from theano.sandbox.cuda import gpu_seqopt, register_opt, register_inplace
from theano.sandbox.cuda.nvcc_compiler import NVCC_compiler from theano.sandbox.cuda.nvcc_compiler import NVCC_compiler
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论