提交 cb526ca4 authored 作者: Sina Honari's avatar Sina Honari

pep8 and flake8 checking

上级 7b7d6d95
...@@ -62,11 +62,15 @@ class SortOp(theano.Op): ...@@ -62,11 +62,15 @@ class SortOp(theano.Op):
self, 0, axis, self, 0, axis,
"Currently, we only implement the gradient on sort for vector" "Currently, we only implement the gradient on sort for vector"
" matrix (and axis is None or 0) and tensor3") " matrix (and axis is None or 0) and tensor3")
if (isinstance(axis, theano.Constant) or (isinstance(axis, theano.tensor.TensorVariable) and axis.ndim==0)) and axis.data is not None: if ((isinstance(axis, theano.Constant) or (isinstance(axis,
theano.tensor.
TensorVariable)
and axis.ndim == 0)) and
axis.data is not None):
indices = self.__get_argsort_indices(a, axis) indices = self.__get_argsort_indices(a, axis)
inp_grad = output_grads[0][tuple(indices)] inp_grad = output_grads[0][tuple(indices)]
elif (axis is None or elif (axis is None or
(isinstance(axis, theano.Constant) and axis.data is None)): (isinstance(axis, theano.Constant) and axis.data is None)):
rev_idx = self.__get_argsort_indices(a, axis) rev_idx = self.__get_argsort_indices(a, axis)
inp_grad = output_grads[0][rev_idx].reshape(a.shape) inp_grad = output_grads[0][rev_idx].reshape(a.shape)
axis_grad = theano.gradient.grad_undefined( axis_grad = theano.gradient.grad_undefined(
...@@ -78,8 +82,9 @@ class SortOp(theano.Op): ...@@ -78,8 +82,9 @@ class SortOp(theano.Op):
def __get_expanded_dim(self, a, axis, i): def __get_expanded_dim(self, a, axis, i):
index_shape = [1] * a.ndim index_shape = [1] * a.ndim
index_shape[i] = a.shape[i] index_shape[i] = a.shape[i]
# it's a way to emulate numpy.ogrid[0: a.shape[0], 0: a.shape[1], 0: a.shape[2]] # it's a way to emulate
index_val = theano.tensor.arange(a.shape[i]).reshape(index_shape) # numpy.ogrid[0: a.shape[0], 0: a.shape[1], 0: a.shape[2]]
index_val = arange(a.shape[i]).reshape(index_shape)
return index_val return index_val
def __get_argsort_indices(self, a, axis): def __get_argsort_indices(self, a, axis):
...@@ -100,9 +105,13 @@ class SortOp(theano.Op): ...@@ -100,9 +105,13 @@ class SortOp(theano.Op):
(isinstance(axis, theano.Constant) and axis.data is None)): (isinstance(axis, theano.Constant) and axis.data is None)):
return rev_idx return rev_idx
indices = [] indices = []
axis_data = theano.tensor.switch(theano.tensor.ge(axis.data, 0), axis.data, a.ndim + axis.data) axis_data = theano.tensor.switch(theano.tensor.ge(axis.data, 0),
axis.data, a.ndim + axis.data)
for i in range(a.ndim): for i in range(a.ndim):
index_val = theano.tensor.switch(theano.tensor.eq(i, axis_data), rev_idx, self.__get_expanded_dim(a, axis, i)) index_val = theano.tensor.switch(theano.tensor.eq(i, axis_data),
rev_idx,
self.__get_expanded_dim(a,
axis, i))
indices.append(index_val) indices.append(index_val)
return indices return indices
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论