提交 78ac9eb5 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

pep8 for scan_op

上级 07f7da05
...@@ -433,7 +433,7 @@ class Scan(PureOp): ...@@ -433,7 +433,7 @@ class Scan(PureOp):
aux_txt += str(k) + ',' aux_txt += str(k) + ','
aux_txt += '},%s,%s}' aux_txt += '},%s,%s}'
else: else:
aux_txt +='{%s,%s}' aux_txt += '{%s,%s}'
aux_txt = aux_txt % (name, gpu_str, str(self.name)) aux_txt = aux_txt % (name, gpu_str, str(self.name))
return aux_txt return aux_txt
...@@ -1164,16 +1164,15 @@ class Scan(PureOp): ...@@ -1164,16 +1164,15 @@ class Scan(PureOp):
### GRAD FUNCTION ### GRAD FUNCTION
def grad(self, args, g_outs): def grad(self, args, g_outs):
# This discards information about whether incoming gradients are 0 # This discards information about whether incoming gradients are 0
# or disconnected from the cost # or disconnected from the cost
# TODO: upgrade scan op to report disconnection correctly # TODO: upgrade scan op to report disconnection correctly
def strip_disconnected( g ): def strip_disconnected(g):
if isinstance(g.type, DisconnectedType): if isinstance(g.type, DisconnectedType):
return None return None
return g return g
g_outs = [ strip_disconnected(g) for g in g_outs ] g_outs = [strip_disconnected(g) for g in g_outs]
# 1. forward pass - get the outputs after applying scan # 1. forward pass - get the outputs after applying scan
scan_outputs = self(*args) scan_outputs = self(*args)
...@@ -1538,12 +1537,12 @@ class Scan(PureOp): ...@@ -1538,12 +1537,12 @@ class Scan(PureOp):
gradients += [x[::-1] for x in outputs[:end]] gradients += [x[::-1] for x in outputs[:end]]
start = len(gradients) start = len(gradients)
gradients += [ gradients += [
grad_undefined(self, x + start, args[x+start], grad_undefined(self, x + start, args[x + start],
'Shared Variable with update') 'Shared Variable with update')
for x in xrange(self.n_shared_outs)] for x in xrange(self.n_shared_outs)]
start = len(gradients) start = len(gradients)
gradients += [ gradients += [
grad_undefined(self, x + start, args[x+start], grad_undefined(self, x + start, args[x + start],
'Dimension of memory buffer for output') 'Dimension of memory buffer for output')
for x in xrange(self.n_nit_sot)] for x in xrange(self.n_nit_sot)]
begin = end begin = end
...@@ -1569,7 +1568,8 @@ class Scan(PureOp): ...@@ -1569,7 +1568,8 @@ class Scan(PureOp):
rop_self_outputs = self_outputs rop_self_outputs = self_outputs
if self.info['n_shared_outs'] > 0: if self.info['n_shared_outs'] > 0:
rop_self_outputs = rop_self_outputs[:-self.info['n_shared_outs']] rop_self_outputs = rop_self_outputs[:-self.info['n_shared_outs']]
rop_outs = tensor.Rop(rop_self_outputs, rop_of_inputs, inner_eval_points) rop_outs = tensor.Rop(rop_self_outputs, rop_of_inputs,
inner_eval_points)
if type(rop_outs) not in (list, tuple): if type(rop_outs) not in (list, tuple):
rop_outs = [rop_outs] rop_outs = [rop_outs]
# Step 2. Figure out what corresponds to what in the scan # Step 2. Figure out what corresponds to what in the scan
...@@ -1760,7 +1760,7 @@ class Scan(PureOp): ...@@ -1760,7 +1760,7 @@ class Scan(PureOp):
b = e + self.n_nit_sot b = e + self.n_nit_sot
e = e + self.n_nit_sot * 2 e = e + self.n_nit_sot * 2
final_outs += outputs[b:e] final_outs += outputs[b:e]
final_outs += [None]*self.n_shared_outs final_outs += [None] * self.n_shared_outs
return final_outs return final_outs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论