提交 45af1986 authored 作者: Rami Al-Rfou's avatar Rami Al-Rfou

speed up sparse construct op by 9x

上级 ead77681
......@@ -7,6 +7,7 @@ import warnings
from itertools import izip
import numpy
from numpy.lib.stride_tricks import as_strided
import scipy.sparse as ssparse
#from copy import copy as python_copy
......@@ -6348,36 +6349,16 @@ class AdvancedSubtensor1(Op):
return [ilist + x[1:]]
class ConstructSparse(Op):
"""Increments a subtensor using advanced slicing (list of index)"""
def __init__(self, inplace=False, set_instead_of_inc=False):
self.inplace = inplace
self.set_instead_of_inc = set_instead_of_inc
if inplace:
self.destroy_map = {0: [0]}
self.row = []
self.col = []
self.data = []
print "Sparse Adv Indexing solution"
"""Construct a sparse matrix out of a list of 2-D matrix rows"""
def __hash__(self):
return hash((type(self), self.inplace, self.set_instead_of_inc))
return hash((type(self)))
def __eq__(self, other):
return (type(self) == type(other)
and self.inplace == other.inplace
and self.set_instead_of_inc == other.set_instead_of_inc)
return (type(self) == type(other))
def __str__(self):
if self.inplace:
msg = "inplace"
else:
msg = "no_inplace"
if self.set_instead_of_inc:
msg += ",set"
else:
msg += ",inc"
return self.__class__.__name__ + "{%s}" % msg
return self.__class__.__name__
def make_node(self, x, y, ilist):
......@@ -6408,19 +6389,14 @@ class ConstructSparse(Op):
def perform(self, node, inp, out_):
x, values, idx = inp
out, = out_
row_ = self.row
data_ = self.data
width = len(values[0])
row_ = []
data_ = []
col = range(width) * len(idx)
i = 0
for j in idx:
row_.extend([j]*width)
data_.extend(values[i])
i += 1
sparse_values = ssparse.coo_matrix((data_, (row_, col)), shape=x.shape, dtype=x.dtype)
out[0] = sparse_values.tocsc()
rows, cols = values.shape
assert rows == len(idx)
indptr = numpy.arange(cols+1) * rows
indices = as_strided(idx,
strides=(0, idx.strides[0]),
shape=(cols, idx.shape[0])).flatten()
data = values.T.flatten()
out[0] = ssparse.csc_matrix((data,indices,indptr), shape=x.shape, dtype=x.dtype)
def infer_shape(self, node, ishapes):
x, y, ilist = ishapes
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论