提交 b2fb02b1 authored 作者: James Bergstra's avatar James Bergstra

sparse - fixed broadcastable pattern of specialized structured_dot ops

上级 c9da1d30
...@@ -810,7 +810,7 @@ class StructuredDotCSC(gof.Op): ...@@ -810,7 +810,7 @@ class StructuredDotCSC(gof.Op):
def make_node(self, a_val, a_ind, a_ptr, a_nrows, b): def make_node(self, a_val, a_ind, a_ptr, a_nrows, b):
dtype_out = scalar.upcast(a_val.type.dtype, b.type.dtype) dtype_out = scalar.upcast(a_val.type.dtype, b.type.dtype)
r = gof.Apply(self, [a_val, a_ind, a_ptr, a_nrows, b], r = gof.Apply(self, [a_val, a_ind, a_ptr, a_nrows, b],
[tensor.tensor(dtype_out, (False, False))]) [tensor.tensor(dtype_out, (False, b.type.broadcastable[1]))])
return r return r
def perform(self, node, (a_val, a_ind, a_ptr, a_nrows, b), (out,)): def perform(self, node, (a_val, a_ind, a_ptr, a_nrows, b), (out,)):
...@@ -976,7 +976,7 @@ class StructuredDotCSR(gof.Op): ...@@ -976,7 +976,7 @@ class StructuredDotCSR(gof.Op):
def make_node(self, a_val, a_ind, a_ptr, b): def make_node(self, a_val, a_ind, a_ptr, b):
self.dtype_out = scalar.upcast(a_val.type.dtype, b.type.dtype) self.dtype_out = scalar.upcast(a_val.type.dtype, b.type.dtype)
r = gof.Apply(self, [a_val, a_ind, a_ptr, b], r = gof.Apply(self, [a_val, a_ind, a_ptr, b],
[tensor.tensor(self.dtype_out, (False, False))]) [tensor.tensor(self.dtype_out, (False, b.type.broadcastable[1]))])
return r return r
def perform(self, node, (a_val, a_ind, a_ptr, b), (out,)): def perform(self, node, (a_val, a_ind, a_ptr, b), (out,)):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论