提交 c9da1d30 authored 作者: James Bergstra's avatar James Bergstra

sparse - corrected broadcastable pattern of StructuredDot

上级 8cfc382a
......@@ -741,7 +741,9 @@ class StructuredDot(gof.Op):
if type(a) is not SparseVariable and type(a) is not SparseConstant:
raise TypeError('First argument must be of type SparseVariable or SparseConstant');
dtype_out = scalar.upcast(a.type.dtype, b.type.dtype)
return gof.Apply(self, [a,b], [tensor.tensor(dtype_out, (False, False))])
if b.type.ndim != 2:
raise NotImplementedError('non-matrix b')
return gof.Apply(self, [a,b], [tensor.tensor(dtype_out, (False, b.type.broadcastable[1]))])
def perform(self, node, (a,b), (out,)):
if a.shape[1] != b.shape[0]:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论