提交 bf4954f9 authored 作者: Tanjay94's avatar Tanjay94

Fixed GetItemList grad to support old scipy version.

上级 25d88718
...@@ -1071,10 +1071,11 @@ class GetItemListGrad(gof.op.Op): ...@@ -1071,10 +1071,11 @@ class GetItemListGrad(gof.op.Op):
y = scipy.sparse.csr_matrix((x.shape[0], x.shape[1])) y = scipy.sparse.csr_matrix((x.shape[0], x.shape[1]))
else: else:
y = scipy.sparse.csc_matrix((x.shape[0], x.shape[1])) y = scipy.sparse.csc_matrix((x.shape[0], x.shape[1]))
z = 0 c = 0
for i in indices: for a in range(0, gz.shape[0]):
y[i:i+1] = gz[z:z+1] for b in range(0, gz.shape[1]):
z = z+1 y[(indices[c], b)] = gz[(a, b)]
c = c + 1
out[0] = y out[0] = y
...@@ -1109,12 +1110,67 @@ class GetItem2Lists(gof.op.Op): ...@@ -1109,12 +1110,67 @@ class GetItem2Lists(gof.op.Op):
p.append(x[(ind1[i],ind2[i])]) p.append(x[(ind1[i],ind2[i])])
out[0] = numpy.asarray(p) out[0] = numpy.asarray(p)
def grad(self, inputs, g_outputs):
x, ind1, ind2 = inputs
gout, = g_outputs
return [GetItem2ListsGrad(self)(x, ind1, ind2, gout),
grad_undefined(self, 1, ind1, "No gradient for this input"),
grad_undefined(self, 1, ind2, "No gradient for this input")]
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
get_item_2lists = GetItem2Lists() get_item_2lists = GetItem2Lists()
class GetItem2ListsGrad(gof.op.Op):
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def infer_shape(self, node, shapes):
return [(shapes[0][0], shapes[0][1])]
def make_node(self, x, ind1, ind2, gz):
x = as_sparse_variable(x)
gz = as_sparse_variable(gz)
assert x.format in ["csr", "csc"]
assert gz.format in ["csr", "csc"]
ind1 = tensor.as_tensor_variable(index)
ind2 = tensor.as_tensor_variable(index)
assert ind1.ndim == 1
assert ind2.ndim == 1
return gof.Apply(self, [x, ind1, ind2, gz], [x.type()])
def perform(self, node, inp, (out, )):
x = inp[0]
ind1 = inp[1]
ind2 = inp[2]
gz = inp[3]
if x.format in ["csr"]:
y = scipy.sparse.csr_matrix((x.shape[0], x.shape[1]))
else:
y = scipy.sparse.csc_matrix((x.shape[0], x.shape[1]))
z = 0
for i in range(0, len(ind1)):
y[(ind1[z], ind2[z])] = gz[z]
z = z + 1
out[0] = y
def __str__(self):
return self.__class__.__name__
get_item_2lists_grad = GetItem2ListsGrad()
class GetItem2d(gof.op.Op): class GetItem2d(gof.op.Op):
# See doc in instance of this Op or function after this class definition. # See doc in instance of this Op or function after this class definition.
def __eq__(self, other): def __eq__(self, other):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论