提交 11ba80b9 作者: Arnaud Bergeron

Also use the other arg of sub as output.

上级 3f1364db
...@@ -321,7 +321,6 @@ CudaNdarray_HOST_STRIDES(%(out)s)[0], CudaNdarray_HOST_STRIDES(%(out)s)[1], ...@@ -321,7 +321,6 @@ CudaNdarray_HOST_STRIDES(%(out)s)[0], CudaNdarray_HOST_STRIDES(%(out)s)[1],
o, W, h, inputIdx, outputIdx = inputs o, W, h, inputIdx, outputIdx = inputs
go = grads[0] go = grads[0]
# might revise that interface to not have a huge output
Wgrad = sparse_block_outer_ss(W.zeros_like(), Wgrad = sparse_block_outer_ss(W.zeros_like(),
h, go, inputIdx, outputIdx) h, go, inputIdx, outputIdx)
hgrad = sparse_block_gemv_ss(h.zeros_like(), hgrad = sparse_block_gemv_ss(h.zeros_like(),
...@@ -682,9 +681,9 @@ GpuElemwise{mul}(lr, SparseBlockOuterSS) -> SparseBlockOuterSS(..., alpha=lr) ...@@ -682,9 +681,9 @@ GpuElemwise{mul}(lr, SparseBlockOuterSS) -> SparseBlockOuterSS(..., alpha=lr)
@opt.register_opt() @opt.register_opt()
@opt.local_optimizer([GpuElemwise]) @opt.local_optimizer([GpuElemwise])
def local_merge_blocksparse_beta(node): def local_merge_blocksparse_output(node):
if (isinstance(node.op, GpuElemwise) and if (isinstance(node.op, GpuElemwise) and
node.op.scalar_op == scalar.sub and (node.op.scalar_op == scalar.sub or node.op.scalar_op == scalar.add) and
node.nin == 2): node.nin == 2):
ger = grab_ger(node.inputs[0]) ger = grab_ger(node.inputs[0])
W = node.inputs[1] W = node.inputs[1]
...@@ -693,8 +692,11 @@ GpuElemwise{mul}(lr, SparseBlockOuterSS) -> SparseBlockOuterSS(..., alpha=lr) ...@@ -693,8 +692,11 @@ GpuElemwise{mul}(lr, SparseBlockOuterSS) -> SparseBlockOuterSS(..., alpha=lr)
W = node.inputs[0] W = node.inputs[0]
if ger is None: if ger is None:
return None return None
alpha = ger.inputs[5]
if node.op.scalar_op == scalar.sub:
alpha = -alpha
return [sparse_block_outer_ss(*([W] + ger.inputs[1:5] + return [sparse_block_outer_ss(*([W] + ger.inputs[1:5] +
[-ger.inputs[5]]))] [alpha]))]
def sparse_block_dot_SS(W, h, inputIdx, b, outputIdx): def sparse_block_dot_SS(W, h, inputIdx, b, outputIdx):
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论