提交 816d0bbc authored 作者: Frederic's avatar Frederic

pep8

上级 0a3c13f7
......@@ -10,6 +10,7 @@ from theano import gof, scalar
from theano.gradient import DisconnectedType
tensor = basic
class CumsumOp(theano.Op):
# See function cumsum for docstring
def __init__(self, axis=None):
......@@ -170,10 +171,11 @@ class CumprodOp(theano.Op):
# We need to reverse the gradients along ``self.axis``,
# compute cumsum, then reverse again
reverse_slicing = [slice(None,None,None)] * gi.ndim
reverse_slicing[self.axis] = slice(None,None,-1)
reverse_slicing = [slice(None, None, None)] * gi.ndim
reverse_slicing[self.axis] = slice(None, None, -1)
reverse_slicing = tuple(reverse_slicing)
return [cumsum((fx * gi)[reverse_slicing], self.axis)[reverse_slicing] / x]
return [cumsum((fx * gi)[reverse_slicing],
self.axis)[reverse_slicing] / x]
def infer_shape(self, node, shapes):
if self.axis is None:
......@@ -845,18 +847,17 @@ class FillDiagonalOffset(gof.Op):
neg_offset_flag = basic.lt(offset, 0)
min_wh = basic.minimum(width, height)
start = offset * pos_offset_flag + offset_abs * width \
* neg_offset_flag
num_of_step = basic.minimum( min_wh, width * pos_offset_flag
+ height * neg_offset_flag - offset_abs )
start = offset * pos_offset_flag + offset_abs * width * neg_offset_flag
num_of_step = basic.minimum(min_wh, width * pos_offset_flag +
height * neg_offset_flag - offset_abs)
step = a.shape[1] + 1
end = start + step * num_of_step
# input of slice should be integer
start = basic.cast(start,'int32')
step = basic.cast(step,'int32')
end = basic.cast(end,'int32')
start = basic.cast(start, 'int32')
step = basic.cast(step, 'int32')
end = basic.cast(end, 'int32')
wr_val = grad.flatten()[start:end:step].sum()
......@@ -865,10 +866,11 @@ class FillDiagonalOffset(gof.Op):
"offset is not defined for non-integer offset so"
" fill_diagonal_offset(a,val,offset+eps) is undefined")
return [wr_a, wr_val,wr_offset]
return [wr_a, wr_val, wr_offset]
fill_diagonal_offset_ = FillDiagonalOffset()
def fill_diagonal_offset(a, val, offset):
"""
Returns a copy of an array with all
......@@ -884,4 +886,3 @@ def fill_diagonal_offset(a, val, offset):
is filled with scalar 'val'. The output is unwrapped.
"""
return fill_diagonal_offset_(a, val, offset)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论