提交 d6a8421a authored 作者: Li Yao's avatar Li Yao

changes according to PEP8

上级 e836df32
...@@ -256,22 +256,24 @@ class EnsureSortedIndices(Op): ...@@ -256,22 +256,24 @@ class EnsureSortedIndices(Op):
""" """
def __init__(self, inplace): def __init__(self, inplace):
self.inplace=inplace self.inplace = inplace
if self.inplace: if self.inplace:
self.view_map = {0:[0]} self.view_map = {0:[0]}
def make_node(self, x): def make_node(self, x):
return gof.Apply(self, [x], [x.type()]) return gof.Apply(self, [x], [x.type()])
def perform(self,node, (x,), (z,)): def perform(self, node, inputs, output_storage):
x = inputs[0]
z = output_storage[0]
if self.inplace: if self.inplace:
x.sort_indices() x.sort_indices()
z[0] = x z[0] = x
else: else:
z[0] = x.sorted_indices() z[0] = x.sorted_indices()
def grad(self, (x,), (gz,)): def grad(self, inputs, output_grad):
return [gz] return [output_grad[0]]
def infer_shape(self, node, i0_shapes): def infer_shape(self, node, i0_shapes):
return i0_shapes return i0_shapes
......
...@@ -413,7 +413,7 @@ def test_diagonal(): ...@@ -413,7 +413,7 @@ def test_diagonal():
assert numpy.all(n == f(range(K)).toarray()) assert numpy.all(n == f(range(K)).toarray())
def test_EnsureSortedIndices(): def test_ensure_sorted_indices():
x = 2000 x = 2000
y = 2000 y = 2000
sparsity = 1000 sparsity = 1000
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论