提交 1c3afdf6 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add inplace optimizations.

上级 6b162a92
......@@ -8,7 +8,7 @@ from theano.sandbox.cuda import cuda_available, GpuOp
if cuda_available:
from theano.sandbox.cuda import (basic_ops, CudaNdarrayType,
CudaNdarray)
CudaNdarray, opt)
import theano.misc.pycuda_init
from theano.misc.pycuda_init import pycuda_available
......@@ -76,6 +76,15 @@ class SparseBlockGemvSS(GpuOp):
if self.inplace:
self.destroy_map = {0: [0]}
def __eq__(self, other):
return type(self) == type(other) and self.inplace == other.inplace
def __hash__(self):
return hash(type(self)) ^ hash(self.inplace)
def __str__(self):
return "SparseBlockGemvSS%s" % ("{inplace}" if self.inplace else "")
def make_node(self, o, W, h, inputIdx, outputIdx):
o = basic_ops.as_cuda_ndarray_variable(o)
W = basic_ops.as_cuda_ndarray_variable(W)
......@@ -127,7 +136,7 @@ class SparseBlockGemvSS(GpuOp):
sparse_block_gemv_ss = SparseBlockGemvSS(False)
sparse_block_gemv_ss_outer = SparseBlockGemvSS(True)
sparse_block_gemv_ss_inplace = SparseBlockGemvSS(True)
class SparseBlockOuterSS(GpuOp):
......@@ -142,6 +151,9 @@ class SparseBlockOuterSS(GpuOp):
def __hash__(self):
return hash(type(self)) ^ hash(self.inplace)
def __str__(self):
return "SparseBlockOuterSS%s" % ("{inplace}" if self.inplace else "")
def make_node(self, o, x, y, xIdx, yIdx):
o = basic_ops.as_cuda_ndarray_variable(o)
x = basic_ops.as_cuda_ndarray_variable(x)
......@@ -166,10 +178,21 @@ class SparseBlockOuterSS(GpuOp):
out[0] = o
sparse_block_outer_ss = SparseBlockOuterSS()
sparse_block_outer_ss = SparseBlockOuterSS(False)
sparse_block_outer_ss_inplace = SparseBlockOuterSS(True)
@opt.register_opt()
@opt.local_optimizer([sparse_block_gemv_ss], inplace=True)
def local_inplace_blocksparse_gemv(node):
if node.op == sparse_block_gemv_ss:
return [sparse_block_gemv_ss_inplace(*node.inputs)]
#############################################################
# All code above this line is unused (except for the imports)
@opt.register_opt()
@opt.local_optimizer([sparse_block_outer_ss], inplace=True)
def local_inplace_blocksparse_outer(node):
if node.op == sparse_block_outer_ss:
return [sparse_block_outer_ss_inplace(*node.inputs)]
def sparse_block_dot_SS(W, h, inputIdx, b, outputIdx):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论