提交 ae20b6ab authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Sparse comparison: Fix type violation in perform method

上级 d4ea7573
...@@ -973,7 +973,12 @@ class __ComparisonOpSS(Op): ...@@ -973,7 +973,12 @@ class __ComparisonOpSS(Op):
(out,) = outputs (out,) = outputs
assert psb._is_sparse(x) and psb._is_sparse(y) assert psb._is_sparse(x) and psb._is_sparse(y)
assert x.shape == y.shape assert x.shape == y.shape
out[0] = self.comparison(x, y).astype("uint8") # FIXME: Scipy csc > csc outputs csr format, but make_node assumes it will be the same as inputs
# Casting to respect make_node, but this is very inefficient
# TODO: Why not go with default bool?
out[0] = (
self.comparison(x, y).astype("uint8").asformat(node.outputs[0].type.format)
)
def infer_shape(self, fgraph, node, ins_shapes): def infer_shape(self, fgraph, node, ins_shapes):
return [ins_shapes[0]] return [ins_shapes[0]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论