提交 fa34fafa authored 作者: Sina Honari's avatar Sina Honari

improving the code

上级 b27b75ce
......@@ -62,20 +62,26 @@ class SortOp(theano.Op):
self, 0, axis,
"Currently, we only implement the gradient on sort for vector"
" matrix (and axis is None or 0) and tensor3")
if isinstance(axis, theano.Constant):
if isinstance(axis, theano.Constant) and axis.data is not None:
indices = self.__get_argsort_indices(a, axis)
inp_grad = output_grads[0][tuple(indices)]
elif (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)):
rev_idx = self.__get_argsort_indices(a, axis)
inp_grad = output_grads[0][rev_idx].reshape(a.shape)
if (isinstance(axis, theano.Constant) or (isinstance(axis, theano.tensor.TensorVariable) and axis.ndim==0)) and axis.data is not None:
indices = self.__get_argsort_indices(a, axis)
inp_grad = output_grads[0][tuple(indices)]
elif (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)):
rev_idx = self.__get_argsort_indices(a, axis)
inp_grad = output_grads[0][rev_idx].reshape(a.shape)
axis_grad = theano.gradient.grad_undefined(
self, 1, axis,
"sort is not defined for non-integer axes so"
" sort(x, axis+eps) is undefined")
return [inp_grad, axis_grad]
def __get_expanded_dim(self, a, axis, i):
    """Build an index range for dimension ``i`` of ``a``, shaped to broadcast.

    The result is ``arange(a.shape[i])`` reshaped so that its length is
    ``a.shape[i]`` along dimension ``i`` and 1 along every other dimension
    of ``a``. This emulates one component of
    ``numpy.ogrid[0: a.shape[0], 0: a.shape[1], 0: a.shape[2]]``.

    ``axis`` is accepted for signature compatibility but is not used here.
    """
    extent = a.shape[i]
    # Start from an all-ones shape and open up only the requested dimension.
    broadcast_shape = [1 for _ in range(a.ndim)]
    broadcast_shape[i] = extent
    dim_range = theano.tensor.arange(extent)
    return dim_range.reshape(broadcast_shape)
def __get_argsort_indices(self, a, axis):
"""Calculates indices which can be used to reverse
sorting operation of "a" tensor along "axis"
......@@ -94,11 +100,9 @@ class SortOp(theano.Op):
(isinstance(axis, theano.Constant) and axis.data is None)):
return rev_idx
indices = []
if axis.data >= 0:
axis_data = axis.data
else:
axis_data = a.ndim + axis.data
axis_data = theano.tensor.switch(theano.tensor.ge(axis.data, 0), axis.data, a.ndim + axis.data)
for i in range(a.ndim):
<<<<<<< HEAD
if i == axis_data:
indices.append(rev_idx)
else:
......@@ -106,6 +110,10 @@ class SortOp(theano.Op):
index_shape[i] = a.shape[i]
# it's a way to emulate numpy.ogrid[0: a.shape[0], 0: a.shape[1], 0: a.shape[2]]
indices.append(theano.tensor.arange(a.shape[i]).reshape(index_shape))
=======
index_val = theano.tensor.switch(theano.tensor.eq(i, axis_data), rev_idx, self.__get_expanded_dim(a, axis, i))
indices.append(index_val)
>>>>>>> b4e4ae5... improving the code
return indices
"""
def R_op(self, inputs, eval_points):
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论