提交 3cacf1f7 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 for tensor/sort.py

上级 d40861ec
...@@ -28,7 +28,7 @@ class SortOp(theano.Op): ...@@ -28,7 +28,7 @@ class SortOp(theano.Op):
def make_node(self, input, axis=-1): def make_node(self, input, axis=-1):
input = theano.tensor.as_tensor_variable(input) input = theano.tensor.as_tensor_variable(input)
if (axis is None or if (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)): (isinstance(axis, theano.Constant) and axis.data is None)):
axis = theano.Constant(theano.gof.generic, None) axis = theano.Constant(theano.gof.generic, None)
# axis=None flattens the array before sorting # axis=None flattens the array before sorting
out_type = tensor(dtype=input.dtype, broadcastable=[False]) out_type = tensor(dtype=input.dtype, broadcastable=[False])
...@@ -45,7 +45,7 @@ class SortOp(theano.Op): ...@@ -45,7 +45,7 @@ class SortOp(theano.Op):
def infer_shape(self, node, inputs_shapes): def infer_shape(self, node, inputs_shapes):
if (isinstance(node.inputs[1], theano.Constant) and if (isinstance(node.inputs[1], theano.Constant) and
node.inputs[1].data is None): node.inputs[1].data is None):
# That means axis = None, # That means axis = None,
# So the array is flattened before being sorted # So the array is flattened before being sorted
return [(mul(*inputs_shapes[0]),)] return [(mul(*inputs_shapes[0]),)]
...@@ -64,16 +64,17 @@ class SortOp(theano.Op): ...@@ -64,16 +64,17 @@ class SortOp(theano.Op):
" matrix (and axis is None or 0) and tensor3") " matrix (and axis is None or 0) and tensor3")
if a.ndim == 1: if a.ndim == 1:
idx = argsort(*inputs, kind=self.kind, order=self.order) idx = argsort(*inputs, kind=self.kind, order=self.order)
# rev_idx = numpy.where(idx[None, :]==numpy.arange(5)[:,None])[1] # rev_idx = numpy.where(idx[None, :]==numpy.arange(5)[:,None])[1]
rev_idx = theano.tensor.eq(idx[None, :], rev_idx = theano.tensor.eq(idx[None, :],
arange(a.shape[0])[:, None]).nonzero()[1] arange(a.shape[0])[:, None]).nonzero()[1]
inp_grad = output_grads[0][rev_idx] inp_grad = output_grads[0][rev_idx]
elif a.ndim == 2: elif a.ndim == 2:
if (axis is None or if (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)): (isinstance(axis, theano.Constant) and axis.data is None)):
idx = argsort(*inputs, kind=self.kind, order=self.order) idx = argsort(*inputs, kind=self.kind, order=self.order)
rev_idx = theano.tensor.eq(idx[None, :], rev_idx = theano.tensor.eq(
arange(a.shape[0]*a.shape[1])[:, None]).nonzero()[1] idx[None, :],
arange(a.shape[0] * a.shape[1])[:, None]).nonzero()[1]
inp_grad = output_grads[0][rev_idx].reshape(a.shape) inp_grad = output_grads[0][rev_idx].reshape(a.shape)
elif (axis == 0 or elif (axis == 0 or
(isinstance(axis, theano.Constant) and axis.data == 0)): (isinstance(axis, theano.Constant) and axis.data == 0)):
...@@ -85,7 +86,7 @@ class SortOp(theano.Op): ...@@ -85,7 +86,7 @@ class SortOp(theano.Op):
indices = self.__get_argsort_indices(a, axis) indices = self.__get_argsort_indices(a, axis)
inp_grad = output_grads[0][indices[0], indices[1], indices[2]] inp_grad = output_grads[0][indices[0], indices[1], indices[2]]
elif (axis is None or elif (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)): (isinstance(axis, theano.Constant) and axis.data is None)):
rev_idx = self.__get_argsort_indices(a, axis) rev_idx = self.__get_argsort_indices(a, axis)
inp_grad = output_grads[0][rev_idx].reshape(a.shape) inp_grad = output_grads[0][rev_idx].reshape(a.shape)
axis_grad = theano.gradient.grad_undefined( axis_grad = theano.gradient.grad_undefined(
...@@ -103,13 +104,13 @@ class SortOp(theano.Op): ...@@ -103,13 +104,13 @@ class SortOp(theano.Op):
list of lenght len(a.shape) otherwise list of lenght len(a.shape) otherwise
""" """
# The goal is to get gradient wrt input from gradient # The goal is to get gradient wrt input from gradient
# wrt sort(input, axis) # wrt sort(input, axis)
idx = argsort(a, axis, kind=self.kind, order=self.order) idx = argsort(a, axis, kind=self.kind, order=self.order)
# rev_idx is the reverse of previous argsort operation # rev_idx is the reverse of previous argsort operation
rev_idx = argsort(idx, axis, kind=self.kind, order=self.order) rev_idx = argsort(idx, axis, kind=self.kind, order=self.order)
if (axis is None or if (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)): (isinstance(axis, theano.Constant) and axis.data is None)):
return rev_idx return rev_idx
indices = [] indices = []
if axis.data >= 0: if axis.data >= 0:
...@@ -120,7 +121,7 @@ class SortOp(theano.Op): ...@@ -120,7 +121,7 @@ class SortOp(theano.Op):
if i == axis_data: if i == axis_data:
indices.append(rev_idx) indices.append(rev_idx)
else: else:
index_shape = [1] * a.ndim index_shape = [1] * a.ndim
index_shape[i] = a.shape[i] index_shape[i] = a.shape[i]
# it's a way to emulate numpy.ogrid[0: a.shape[0], 0: a.shape[1], 0: a.shape[2]] # it's a way to emulate numpy.ogrid[0: a.shape[0], 0: a.shape[1], 0: a.shape[2]]
indices.append(theano.tensor.arange(a.shape[i]).reshape(index_shape)) indices.append(theano.tensor.arange(a.shape[i]).reshape(index_shape))
...@@ -178,28 +179,27 @@ class ArgSortOp(theano.Op): ...@@ -178,28 +179,27 @@ class ArgSortOp(theano.Op):
return hash(type(self)) ^ hash(self.order) ^ hash(self.kind) return hash(type(self)) ^ hash(self.order) ^ hash(self.kind)
def __str__(self): def __str__(self):
return (self.__class__.__name__ return (self.__class__.__name__ +
+ "{%s, %s}" % (self.kind, str(self.order))) "{%s, %s}" % (self.kind, str(self.order)))
def make_node(self, input, axis=-1): def make_node(self, input, axis=-1):
input = theano.tensor.as_tensor_variable(input) input = theano.tensor.as_tensor_variable(input)
if (axis is None or if (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)): (isinstance(axis, theano.Constant) and axis.data is None)):
axis = theano.Constant(theano.gof.generic, None) axis = theano.Constant(theano.gof.generic, None)
bcast = [False] bcast = [False]
else: else:
axis = theano.tensor.as_tensor_variable(axis) axis = theano.tensor.as_tensor_variable(axis)
bcast = input.type.broadcastable bcast = input.type.broadcastable
return theano.Apply(self, [input, axis], return theano.Apply(self, [input, axis], [theano.tensor.TensorType(
[theano.tensor.TensorType(dtype="int64", broadcastable=bcast)()]) dtype="int64", broadcastable=bcast)()])
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
a = inputs[0] a = inputs[0]
axis = inputs[1] axis = inputs[1]
z = output_storage[0] z = output_storage[0]
z[0] = theano._asarray( z[0] = theano._asarray(np.argsort(a, axis, self.kind, self.order),
np.argsort(a, axis, self.kind, self.order), dtype=node.outputs[0].dtype)
dtype=node.outputs[0].dtype)
def infer_shape(self, node, inputs_shapes): def infer_shape(self, node, inputs_shapes):
if (isinstance(node.inputs[1], theano.Constant) and if (isinstance(node.inputs[1], theano.Constant) and
......
...@@ -60,7 +60,6 @@ whitelist_flake8 = [ ...@@ -60,7 +60,6 @@ whitelist_flake8 = [
"tensor/blas_headers.py", "tensor/blas_headers.py",
"tensor/type.py", "tensor/type.py",
"tensor/fourier.py", "tensor/fourier.py",
"tensor/sort.py",
"tensor/__init__.py", "tensor/__init__.py",
"tensor/opt_uncanonicalize.py", "tensor/opt_uncanonicalize.py",
"tensor/blas.py", "tensor/blas.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论