提交 b2fb02b1 authored 作者: James Bergstra's avatar James Bergstra

sparse - fixed broadcastable pattern of specialized structured_dot ops

上级 c9da1d30
......@@ -810,7 +810,7 @@ class StructuredDotCSC(gof.Op):
def make_node(self, a_val, a_ind, a_ptr, a_nrows, b):
dtype_out = scalar.upcast(a_val.type.dtype, b.type.dtype)
r = gof.Apply(self, [a_val, a_ind, a_ptr, a_nrows, b],
[tensor.tensor(dtype_out, (False, False))])
[tensor.tensor(dtype_out, (False, b.type.broadcastable[1]))])
return r
def perform(self, node, (a_val, a_ind, a_ptr, a_nrows, b), (out,)):
......@@ -976,7 +976,7 @@ class StructuredDotCSR(gof.Op):
def make_node(self, a_val, a_ind, a_ptr, b):
self.dtype_out = scalar.upcast(a_val.type.dtype, b.type.dtype)
r = gof.Apply(self, [a_val, a_ind, a_ptr, b],
[tensor.tensor(self.dtype_out, (False, False))])
[tensor.tensor(self.dtype_out, (False, b.type.broadcastable[1]))])
return r
def perform(self, node, (a_val, a_ind, a_ptr, b), (out,)):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论