提交 9400253c authored 作者: Sina Honari's avatar Sina Honari

removing axis=None from sort Op

上级 13627c5c
import numpy as np
import theano
from theano.tensor import tensor
from theano.tensor.basic import mul, arange
......@@ -27,14 +24,8 @@ class SortOp(theano.Op):
def make_node(self, input, axis=-1):
input = theano.tensor.as_tensor_variable(input)
if (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)):
axis = theano.Constant(theano.gof.generic, None)
# axis=None flattens the array before sorting
out_type = tensor(dtype=input.dtype, broadcastable=[False])
else:
axis = theano.tensor.as_tensor_variable(axis)
out_type = input.type()
axis = theano.tensor.as_tensor_variable(axis)
out_type = input.type()
return theano.Apply(self, [input, axis], [out_type])
def perform(self, node, inputs, output_storage):
......@@ -58,25 +49,12 @@ class SortOp(theano.Op):
def grad(self, inputs, output_grads):
a, axis = inputs
inp_grad = theano.gradient.grad_not_implemented(
self, 0, axis,
"Currently, we only implement the gradient on sort for vector"
" matrix (and axis is None or 0) and tensor3")
if ((isinstance(axis, theano.Constant) or (isinstance(axis,
theano.tensor.
TensorVariable) and
axis.ndim == 0)) and
axis.data is not None):
indices = self.__get_argsort_indices(a, axis)
inp_grad = output_grads[0][tuple(indices)]
elif (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)):
rev_idx = self.__get_argsort_indices(a, axis)
inp_grad = output_grads[0][rev_idx].reshape(a.shape)
indices = self.__get_argsort_indices(a, axis)
inp_grad = output_grads[0][tuple(indices)]
axis_grad = theano.gradient.grad_undefined(
self, 1, axis,
"sort is not defined for non-integer axes so"
" sort(x, axis+eps) is undefined")
"The gradient of sort is not defined "
"with respect to the integer axes itself")
return [inp_grad, axis_grad]
def __get_expanded_dim(self, a, axis, i):
......@@ -101,9 +79,6 @@ class SortOp(theano.Op):
idx = argsort(a, axis, kind=self.kind, order=self.order)
# rev_idx is the reverse of previous argsort operation
rev_idx = argsort(idx, axis, kind=self.kind, order=self.order)
if (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)):
return rev_idx
indices = []
axis_data = theano.tensor.switch(theano.tensor.ge(axis.data, 0),
axis.data, a.ndim + axis.data)
......@@ -147,6 +122,9 @@ def sort(a, axis=-1, kind='quicksort', order=None):
need to include all of the fields.
"""
if axis is None:
a = a.flatten()
axis = 0
return SortOp(kind, order)(a, axis)
......@@ -172,13 +150,8 @@ class ArgSortOp(theano.Op):
def make_node(self, input, axis=-1):
input = theano.tensor.as_tensor_variable(input)
if (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)):
axis = theano.Constant(theano.gof.generic, None)
bcast = [False]
else:
axis = theano.tensor.as_tensor_variable(axis)
bcast = input.type.broadcastable
axis = theano.tensor.as_tensor_variable(axis)
bcast = input.type.broadcastable
return theano.Apply(self, [input, axis], [theano.tensor.TensorType(
dtype="int64", broadcastable=bcast)()])
......@@ -233,4 +206,7 @@ def argsort(a, axis=-1, kind='quicksort', order=None):
the same shape as a that index data along the given axis in sorted
order.
"""
if axis is None:
a = a.flatten()
axis = 0
return ArgSortOp(kind, order)(a, axis)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论