提交 1744a5b1 authored 作者: Tanjay94's avatar Tanjay94

Modified small mistakes in GetItemList and GetItem2Lists.

上级 519f9d26
......@@ -1000,6 +1000,7 @@ csc_from_dense = SparseFromDense('csc')
:return: The same as `x` in a sparse matrix format.
"""
# Indexing
class GetItemList(gof.op.Op):
......@@ -1037,6 +1038,14 @@ class GetItemList(gof.op.Op):
return self.__class__.__name__
get_item_list = GetItemList()
"""Select row of sparse matrix,
returning them as a new sparse matrix.
:param x: Sparse matrix.
:param index: List of rows.
:return: The corresponding rows in `x`.
"""
class GetItemListGrad(gof.op.Op):
......@@ -1048,7 +1057,7 @@ class GetItemListGrad(gof.op.Op):
return hash(type(self))
def infer_shape(self, node, shapes):
return [(shapes[0][0], shapes[0][1])]
return [(shapes[0])]
def make_node(self, x, index, gz):
x = as_sparse_variable(x)
......@@ -1060,6 +1069,11 @@ class GetItemListGrad(gof.op.Op):
ind = tensor.as_tensor_variable(index)
assert ind.ndim == 1
scipy_ver = [int(n) for n in scipy.__version__.split('.')[:2]]
if not scipy_ver >= [0, 13]:
raise NotImplemented("Scipy version is to old")
return gof.Apply(self, [x, ind, gz], [x.type()])
def perform(self, node, inp, (out, )):
......@@ -1071,11 +1085,8 @@ class GetItemListGrad(gof.op.Op):
y = scipy.sparse.csr_matrix((x.shape[0], x.shape[1]))
else:
y = scipy.sparse.csc_matrix((x.shape[0], x.shape[1]))
c = 0
for a in range(0, gz.shape[0]):
for b in range(0, gz.shape[1]):
y[(indices[c], b)] = gz[(a, b)]
c = c + 1
for a in range(0, len(indices)):
y[indices[a]] = gz[a]
out[0] = y
......@@ -1107,7 +1118,9 @@ class GetItem2Lists(gof.op.Op):
ind1 = inp[1]
ind2 = inp[2]
out[0] = numpy.asarray(x[ind1, ind2]).flatten()
"""Here scipy returns the corresponding elements in a matrix which isn't what we are aiming for.
Using asarray and flatten, out[0] becomes an array.
"""
def grad(self, inputs, g_outputs):
x, ind1, ind2 = inputs
gout, = g_outputs
......@@ -1119,6 +1132,14 @@ class GetItem2Lists(gof.op.Op):
return self.__class__.__name__
get_item_2lists = GetItem2Lists()
"""Select elements of sparse matrix, returning them in a vector.
:param x: Sparse matrix.
:param index: List of two lists, first list indicating the row
of each element and second list indicating its column.
:return: The corresponding elements in `x`.
"""
class GetItem2ListsGrad(gof.op.Op):
......@@ -1130,7 +1151,7 @@ class GetItem2ListsGrad(gof.op.Op):
return hash(type(self))
def infer_shape(self, node, shapes):
return [(shapes[0][0], shapes[0][1])]
return [(shapes[0])]
def make_node(self, x, ind1, ind2, gz):
x = as_sparse_variable(x)
......@@ -1141,6 +1162,8 @@ class GetItem2ListsGrad(gof.op.Op):
ind2 = tensor.as_tensor_variable(ind2)
assert ind1.ndim == 1
assert ind2.ndim == 1
assert "int" in ind1.dtype
assert "int" in ind2.dtype
return gof.Apply(self, [x, ind1, ind2, gz], [x.type()])
......@@ -1155,9 +1178,8 @@ class GetItem2ListsGrad(gof.op.Op):
else:
y = scipy.sparse.csc_matrix((x.shape[0], x.shape[1]))
z = 0
for i in range(0, len(ind1)):
for z in range(0, len(ind1)):
y[(ind1[z], ind2[z])] = gz[z]
z = z + 1
out[0] = y
......
......@@ -2054,7 +2054,10 @@ class Test_getitem(unittest.TestCase):
x, x_val = sparse_random_inputs("csr", (4,5), out_dtype="float64")
verify_grad_sparse(op_with_fixed_index, x_val)
try:
verify_grad_sparse(op_with_fixed_index, x_val)
except NotImplementedError, e:
assert "Scipy version is to old" in str(e)
def test_GetItem2Lists(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论