提交 3cacf1f7 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 for tensor/sort.py

上级 d40861ec
......@@ -64,7 +64,7 @@ class SortOp(theano.Op):
" matrix (and axis is None or 0) and tensor3")
if a.ndim == 1:
idx = argsort(*inputs, kind=self.kind, order=self.order)
# rev_idx = numpy.where(idx[None, :]==numpy.arange(5)[:,None])[1]
# rev_idx = numpy.where(idx[None, :]==numpy.arange(5)[:,None])[1]
rev_idx = theano.tensor.eq(idx[None, :],
arange(a.shape[0])[:, None]).nonzero()[1]
inp_grad = output_grads[0][rev_idx]
......@@ -72,8 +72,9 @@ class SortOp(theano.Op):
if (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)):
idx = argsort(*inputs, kind=self.kind, order=self.order)
rev_idx = theano.tensor.eq(idx[None, :],
arange(a.shape[0]*a.shape[1])[:, None]).nonzero()[1]
rev_idx = theano.tensor.eq(
idx[None, :],
arange(a.shape[0] * a.shape[1])[:, None]).nonzero()[1]
inp_grad = output_grads[0][rev_idx].reshape(a.shape)
elif (axis == 0 or
(isinstance(axis, theano.Constant) and axis.data == 0)):
......@@ -178,8 +179,8 @@ class ArgSortOp(theano.Op):
return hash(type(self)) ^ hash(self.order) ^ hash(self.kind)
def __str__(self):
return (self.__class__.__name__
+ "{%s, %s}" % (self.kind, str(self.order)))
return (self.__class__.__name__ +
"{%s, %s}" % (self.kind, str(self.order)))
def make_node(self, input, axis=-1):
input = theano.tensor.as_tensor_variable(input)
......@@ -190,15 +191,14 @@ class ArgSortOp(theano.Op):
else:
axis = theano.tensor.as_tensor_variable(axis)
bcast = input.type.broadcastable
return theano.Apply(self, [input, axis],
[theano.tensor.TensorType(dtype="int64", broadcastable=bcast)()])
return theano.Apply(self, [input, axis], [theano.tensor.TensorType(
dtype="int64", broadcastable=bcast)()])
def perform(self, node, inputs, output_storage):
a = inputs[0]
axis = inputs[1]
z = output_storage[0]
z[0] = theano._asarray(
np.argsort(a, axis, self.kind, self.order),
z[0] = theano._asarray(np.argsort(a, axis, self.kind, self.order),
dtype=node.outputs[0].dtype)
def infer_shape(self, node, inputs_shapes):
......
......@@ -60,7 +60,6 @@ whitelist_flake8 = [
"tensor/blas_headers.py",
"tensor/type.py",
"tensor/fourier.py",
"tensor/sort.py",
"tensor/__init__.py",
"tensor/opt_uncanonicalize.py",
"tensor/blas.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论