提交 9400253c authored 作者: Sina Honari's avatar Sina Honari

removing axis=None from sort Op

上级 13627c5c
import numpy as np import numpy as np
import theano import theano
from theano.tensor import tensor
from theano.tensor.basic import mul, arange from theano.tensor.basic import mul, arange
...@@ -27,14 +24,8 @@ class SortOp(theano.Op): ...@@ -27,14 +24,8 @@ class SortOp(theano.Op):
def make_node(self, input, axis=-1): def make_node(self, input, axis=-1):
input = theano.tensor.as_tensor_variable(input) input = theano.tensor.as_tensor_variable(input)
if (axis is None or axis = theano.tensor.as_tensor_variable(axis)
(isinstance(axis, theano.Constant) and axis.data is None)): out_type = input.type()
axis = theano.Constant(theano.gof.generic, None)
# axis=None flattens the array before sorting
out_type = tensor(dtype=input.dtype, broadcastable=[False])
else:
axis = theano.tensor.as_tensor_variable(axis)
out_type = input.type()
return theano.Apply(self, [input, axis], [out_type]) return theano.Apply(self, [input, axis], [out_type])
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
...@@ -58,25 +49,12 @@ class SortOp(theano.Op): ...@@ -58,25 +49,12 @@ class SortOp(theano.Op):
def grad(self, inputs, output_grads): def grad(self, inputs, output_grads):
a, axis = inputs a, axis = inputs
inp_grad = theano.gradient.grad_not_implemented( indices = self.__get_argsort_indices(a, axis)
self, 0, axis, inp_grad = output_grads[0][tuple(indices)]
"Currently, we only implement the gradient on sort for vector"
" matrix (and axis is None or 0) and tensor3")
if ((isinstance(axis, theano.Constant) or (isinstance(axis,
theano.tensor.
TensorVariable) and
axis.ndim == 0)) and
axis.data is not None):
indices = self.__get_argsort_indices(a, axis)
inp_grad = output_grads[0][tuple(indices)]
elif (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)):
rev_idx = self.__get_argsort_indices(a, axis)
inp_grad = output_grads[0][rev_idx].reshape(a.shape)
axis_grad = theano.gradient.grad_undefined( axis_grad = theano.gradient.grad_undefined(
self, 1, axis, self, 1, axis,
"sort is not defined for non-integer axes so" "The gradient of sort is not defined "
" sort(x, axis+eps) is undefined") "with respect to the integer axes itself")
return [inp_grad, axis_grad] return [inp_grad, axis_grad]
def __get_expanded_dim(self, a, axis, i): def __get_expanded_dim(self, a, axis, i):
...@@ -101,9 +79,6 @@ class SortOp(theano.Op): ...@@ -101,9 +79,6 @@ class SortOp(theano.Op):
idx = argsort(a, axis, kind=self.kind, order=self.order) idx = argsort(a, axis, kind=self.kind, order=self.order)
# rev_idx is the reverse of previous argsort operation # rev_idx is the reverse of previous argsort operation
rev_idx = argsort(idx, axis, kind=self.kind, order=self.order) rev_idx = argsort(idx, axis, kind=self.kind, order=self.order)
if (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)):
return rev_idx
indices = [] indices = []
axis_data = theano.tensor.switch(theano.tensor.ge(axis.data, 0), axis_data = theano.tensor.switch(theano.tensor.ge(axis.data, 0),
axis.data, a.ndim + axis.data) axis.data, a.ndim + axis.data)
...@@ -147,6 +122,9 @@ def sort(a, axis=-1, kind='quicksort', order=None): ...@@ -147,6 +122,9 @@ def sort(a, axis=-1, kind='quicksort', order=None):
need to include all of the fields. need to include all of the fields.
""" """
if axis is None:
a = a.flatten()
axis = 0
return SortOp(kind, order)(a, axis) return SortOp(kind, order)(a, axis)
...@@ -172,13 +150,8 @@ class ArgSortOp(theano.Op): ...@@ -172,13 +150,8 @@ class ArgSortOp(theano.Op):
def make_node(self, input, axis=-1): def make_node(self, input, axis=-1):
input = theano.tensor.as_tensor_variable(input) input = theano.tensor.as_tensor_variable(input)
if (axis is None or axis = theano.tensor.as_tensor_variable(axis)
(isinstance(axis, theano.Constant) and axis.data is None)): bcast = input.type.broadcastable
axis = theano.Constant(theano.gof.generic, None)
bcast = [False]
else:
axis = theano.tensor.as_tensor_variable(axis)
bcast = input.type.broadcastable
return theano.Apply(self, [input, axis], [theano.tensor.TensorType( return theano.Apply(self, [input, axis], [theano.tensor.TensorType(
dtype="int64", broadcastable=bcast)()]) dtype="int64", broadcastable=bcast)()])
...@@ -233,4 +206,7 @@ def argsort(a, axis=-1, kind='quicksort', order=None): ...@@ -233,4 +206,7 @@ def argsort(a, axis=-1, kind='quicksort', order=None):
the same shape as a that index data along the given axis in sorted the same shape as a that index data along the given axis in sorted
order. order.
""" """
if axis is None:
a = a.flatten()
axis = 0
return ArgSortOp(kind, order)(a, axis) return ArgSortOp(kind, order)(a, axis)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论