提交 0e8f48e5 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2136 from abergeron/new_block

Fix blocksparse optimization that could give bad results in some cases
...@@ -321,7 +321,6 @@ CudaNdarray_HOST_STRIDES(%(out)s)[0], CudaNdarray_HOST_STRIDES(%(out)s)[1], ...@@ -321,7 +321,6 @@ CudaNdarray_HOST_STRIDES(%(out)s)[0], CudaNdarray_HOST_STRIDES(%(out)s)[1],
o, W, h, inputIdx, outputIdx = inputs o, W, h, inputIdx, outputIdx = inputs
go = grads[0] go = grads[0]
# might revise that interface to not have a huge output
Wgrad = sparse_block_outer_ss(W.zeros_like(), Wgrad = sparse_block_outer_ss(W.zeros_like(),
h, go, inputIdx, outputIdx) h, go, inputIdx, outputIdx)
hgrad = sparse_block_gemv_ss(h.zeros_like(), hgrad = sparse_block_gemv_ss(h.zeros_like(),
...@@ -682,9 +681,10 @@ GpuElemwise{mul}(lr, SparseBlockOuterSS) -> SparseBlockOuterSS(..., alpha=lr) ...@@ -682,9 +681,10 @@ GpuElemwise{mul}(lr, SparseBlockOuterSS) -> SparseBlockOuterSS(..., alpha=lr)
@opt.register_opt() @opt.register_opt()
@opt.local_optimizer([GpuElemwise]) @opt.local_optimizer([GpuElemwise])
def local_merge_blocksparse_beta(node): def local_merge_blocksparse_output(node):
if (isinstance(node.op, GpuElemwise) and if (isinstance(node.op, GpuElemwise) and
node.op.scalar_op == scalar.sub and (node.op.scalar_op == scalar.sub or
node.op.scalar_op == scalar.add) and
node.nin == 2): node.nin == 2):
ger = grab_ger(node.inputs[0]) ger = grab_ger(node.inputs[0])
W = node.inputs[1] W = node.inputs[1]
...@@ -693,8 +693,14 @@ GpuElemwise{mul}(lr, SparseBlockOuterSS) -> SparseBlockOuterSS(..., alpha=lr) ...@@ -693,8 +693,14 @@ GpuElemwise{mul}(lr, SparseBlockOuterSS) -> SparseBlockOuterSS(..., alpha=lr)
W = node.inputs[0] W = node.inputs[0]
if ger is None: if ger is None:
return None return None
if node.op.scalar_op == scalar.sub:
alpha = -ger.inputs[5]
W = W - ger.inputs[0]
else:
alpha = ger.inputs[5]
W = W + ger.inputs[0]
return [sparse_block_outer_ss(*([W] + ger.inputs[1:5] + return [sparse_block_outer_ss(*([W] + ger.inputs[1:5] +
[-ger.inputs[5]]))] [alpha]))]
def sparse_block_dot_SS(W, h, inputIdx, b, outputIdx): def sparse_block_dot_SS(W, h, inputIdx, b, outputIdx):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论