提交 fa34fafa authored 作者: Sina Honari's avatar Sina Honari

improving the code

上级 b27b75ce
...@@ -62,8 +62,7 @@ class SortOp(theano.Op): ...@@ -62,8 +62,7 @@ class SortOp(theano.Op):
self, 0, axis, self, 0, axis,
"Currently, we only implement the gradient on sort for vector" "Currently, we only implement the gradient on sort for vector"
" matrix (and axis is None or 0) and tensor3") " matrix (and axis is None or 0) and tensor3")
if isinstance(axis, theano.Constant): if (isinstance(axis, theano.Constant) or (isinstance(axis, theano.tensor.TensorVariable) and axis.ndim==0)) and axis.data is not None:
if isinstance(axis, theano.Constant) and axis.data is not None:
indices = self.__get_argsort_indices(a, axis) indices = self.__get_argsort_indices(a, axis)
inp_grad = output_grads[0][tuple(indices)] inp_grad = output_grads[0][tuple(indices)]
elif (axis is None or elif (axis is None or
...@@ -76,6 +75,13 @@ class SortOp(theano.Op): ...@@ -76,6 +75,13 @@ class SortOp(theano.Op):
" sort(x, axis+eps) is undefined") " sort(x, axis+eps) is undefined")
return [inp_grad, axis_grad] return [inp_grad, axis_grad]
def __get_expanded_dim(self, a, axis, i):
    """Return an index tensor spanning dimension *i* of *a*, shaped for broadcasting.

    Builds arange(a.shape[i]) reshaped so every dimension is singleton
    except the i-th — i.e. one component of what
    numpy.ogrid[0:a.shape[0], ..., 0:a.shape[ndim-1]] would produce.
    (*axis* is accepted for signature parity with the caller but unused.)
    """
    # Singleton everywhere except dim i, which keeps a's extent there.
    bcast_shape = [a.shape[i] if d == i else 1 for d in range(a.ndim)]
    return theano.tensor.arange(a.shape[i]).reshape(bcast_shape)
def __get_argsort_indices(self, a, axis): def __get_argsort_indices(self, a, axis):
"""Calculates indices which can be used to reverse """Calculates indices which can be used to reverse
sorting operation of "a" tensor along "axis" sorting operation of "a" tensor along "axis"
...@@ -94,11 +100,9 @@ class SortOp(theano.Op): ...@@ -94,11 +100,9 @@ class SortOp(theano.Op):
(isinstance(axis, theano.Constant) and axis.data is None)): (isinstance(axis, theano.Constant) and axis.data is None)):
return rev_idx return rev_idx
indices = [] indices = []
if axis.data >= 0: axis_data = theano.tensor.switch(theano.tensor.ge(axis.data, 0), axis.data, a.ndim + axis.data)
axis_data = axis.data
else:
axis_data = a.ndim + axis.data
for i in range(a.ndim): for i in range(a.ndim):
<<<<<<< HEAD
if i == axis_data: if i == axis_data:
indices.append(rev_idx) indices.append(rev_idx)
else: else:
...@@ -106,6 +110,10 @@ class SortOp(theano.Op): ...@@ -106,6 +110,10 @@ class SortOp(theano.Op):
index_shape[i] = a.shape[i] index_shape[i] = a.shape[i]
# it's a way to emulate numpy.ogrid[0: a.shape[0], 0: a.shape[1], 0: a.shape[2]] # it's a way to emulate numpy.ogrid[0: a.shape[0], 0: a.shape[1], 0: a.shape[2]]
indices.append(theano.tensor.arange(a.shape[i]).reshape(index_shape)) indices.append(theano.tensor.arange(a.shape[i]).reshape(index_shape))
=======
index_val = theano.tensor.switch(theano.tensor.eq(i, axis_data), rev_idx, self.__get_expanded_dim(a, axis, i))
indices.append(index_val)
>>>>>>> b4e4ae5... improving the code
return indices return indices
""" """
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论