提交 d3088260 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Update docs to reflect batches and add some fallback code to add batches of 1 to…

Update docs to reflect batches and add some fallback code to add batches of 1 to non-batched version.
上级 42f4cb3e
...@@ -690,13 +690,19 @@ def sparse_block_dot_SS(W, h, inputIdx, b, outputIdx): ...@@ -690,13 +690,19 @@ def sparse_block_dot_SS(W, h, inputIdx, b, outputIdx):
""" """
var: shape, comment var: shape, comment
W: (iBlocks, oBlocks, iSize, oSize), weight matrix W: (iBlocks, oBlocks, iSize, oSize), weight matrix
h: (iWin, iSize), input from lower layer (sparse) h: (batch, iWin, iSize), input from lower layer (sparse)
inputIdx: (iWin,), indexes of the input blocks inputIdx: (batch, iWin), indexes of the input blocks
b: (oBlocks, oSize), bias vector b: (oBlocks, oSize), bias vector
outputIdx: (oWin,), indexes of the output blocks outputIdx: (batch, oWin), indexes of the output blocks
returns (oBlocks, oSize), dot(W[i, j], h[i]) + b[j] returns (oBlocks, oSize), dot(W[i, j], h[i]) + b[j]
but b[j] is only added once but b[j] is only added once
""" """
assert inputIdx.ndim == h.ndim - 1
assert outputIdx.ndim == inputIdx.ndim
if h.ndim == 2:
h = h.dimshuffle('x', 0, 1)
inputIdx = inputIdx.dimshuffle('x', 0)
outputIdx = outputIdx.dimshuffle('x', 0)
return sparse_block_gemv_ss(b.take(outputIdx, axis=0), W, h, return sparse_block_gemv_ss(b.take(outputIdx, axis=0), W, h,
inputIdx, outputIdx) inputIdx, outputIdx)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论