提交 b27b75ce authored 作者: Sina Honari's avatar Sina Honari

improvement

上级 18d155df
...@@ -62,13 +62,7 @@ class SortOp(theano.Op): ...@@ -62,13 +62,7 @@ class SortOp(theano.Op):
self, 0, axis, self, 0, axis,
"Currently, we only implement the gradient on sort for vector" "Currently, we only implement the gradient on sort for vector"
" matrix (and axis is None or 0) and tensor3") " matrix (and axis is None or 0) and tensor3")
if a.ndim == 1: if isinstance(axis, theano.Constant):
idx = argsort(*inputs, kind=self.kind, order=self.order)
# rev_idx = numpy.where(idx[None, :]==numpy.arange(5)[:,None])[1]
rev_idx = theano.tensor.eq(idx[None, :],
arange(a.shape[0])[:, None]).nonzero()[1]
inp_grad = output_grads[0][rev_idx]
elif isinstance(axis, theano.Constant):
if isinstance(axis, theano.Constant) and axis.data is not None: if isinstance(axis, theano.Constant) and axis.data is not None:
indices = self.__get_argsort_indices(a, axis) indices = self.__get_argsort_indices(a, axis)
inp_grad = output_grads[0][tuple(indices)] inp_grad = output_grads[0][tuple(indices)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论