提交 1d042d99 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add infer_shape for the blocksparse ops and pass through the broadcastable flags.

上级 2479cc7e
...@@ -91,10 +91,7 @@ class SparseBlockGemv(Op): ...@@ -91,10 +91,7 @@ class SparseBlockGemv(Op):
assert inputIdx.type.dtype in discrete_dtypes assert inputIdx.type.dtype in discrete_dtypes
assert outputIdx.type.dtype in discrete_dtypes assert outputIdx.type.dtype in discrete_dtypes
output = o.type.__class__(dtype=o.type.dtype, return Apply(self, [o, W, h, inputIdx, outputIdx], [o.type()])
broadcastable=(False,) * o.ndim)()
return Apply(self, [o, W, h, inputIdx, outputIdx], [output])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
o, W, h, iIdx, oIdx = inp[:5] o, W, h, iIdx, oIdx = inp[:5]
...@@ -111,6 +108,9 @@ class SparseBlockGemv(Op): ...@@ -111,6 +108,9 @@ class SparseBlockGemv(Op):
o[b, j, :] += numpy.dot(h[b, i], w) o[b, j, :] += numpy.dot(h[b, i], w)
out_[0][0] = o out_[0][0] = o
def infer_shape(self, node, input_shapes):
    """Report the output shape of this Op.

    The result has exactly the shape of the first input (the `o`
    tensor that the op accumulates into), so we simply echo that
    entry back as a one-element list.
    """
    o_shape = input_shapes[0]
    return [o_shape]
def grad(self, inputs, grads): def grad(self, inputs, grads):
o, W, h, inputIdx, outputIdx = inputs o, W, h, inputIdx, outputIdx = inputs
go = grads[0] go = grads[0]
...@@ -192,11 +192,11 @@ class SparseBlockOuter(Op): ...@@ -192,11 +192,11 @@ class SparseBlockOuter(Op):
if alpha is None: if alpha is None:
alpha = one alpha = one
output = o.type.__class__(dtype=o.type.dtype,
broadcastable=(False,) * o.ndim)()
return Apply(self, [o, x, y, xIdx, yIdx, alpha], return Apply(self, [o, x, y, xIdx, yIdx, alpha],
[output]) [o.type()])
def infer_shape(self, node, input_shapes):
    """Return the output shape: identical to the shape of input 0.

    The op writes its result into a tensor shaped like the first
    input (`o`), so that shape is forwarded unchanged.
    """
    return [input_shapes[0]]
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
o, x, y, xIdx, yIdx, alpha = inp[:6] o, x, y, xIdx, yIdx, alpha = inp[:6]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论