提交 e9842115 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

pep8 compatibility

上级 c4c6f139
...@@ -361,9 +361,9 @@ def scan(fn, ...@@ -361,9 +361,9 @@ def scan(fn,
if arg_info.get('taps', None) == [-1]: if arg_info.get('taps', None) == [-1]:
mintaps.append(1) mintaps.append(1)
lengths.append(scalar_shared(numpy.int64(0), lengths.append(scalar_shared(numpy.int64(0),
name = 'l%d' % pos)) name='l%d' % pos))
arg_info['initial'] = scan_utils.expand(tensor.unbroadcast( arg_info['initial'] = scan_utils.expand(tensor.unbroadcast(
tensor.shape_padleft(arg_info['initial']),0), T) tensor.shape_padleft(arg_info['initial']), 0), T)
elif arg_info.get('taps', None): elif arg_info.get('taps', None):
if numpy.any(numpy.array(arg_info.get('taps', [])) > 0): if numpy.any(numpy.array(arg_info.get('taps', [])) > 0):
# Make sure we do not have requests for future values of a # Make sure we do not have requests for future values of a
...@@ -372,14 +372,14 @@ def scan(fn, ...@@ -372,14 +372,14 @@ def scan(fn,
arg_info) arg_info)
mintap = abs(numpy.min(arg_info['taps'])) mintap = abs(numpy.min(arg_info['taps']))
lengths.append(scalar_shared(numpy.int64(0), lengths.append(scalar_shared(numpy.int64(0),
name = 'l%d' % pos)) name='l%d' % pos))
mintaps.append(mintap) mintaps.append(mintap)
arg_info['initial'] = scan_utils.expand( arg_info['initial'] = scan_utils.expand(
arg_info['initial'][:mintap], T) arg_info['initial'][:mintap], T)
else: else:
mintaps.append(0) mintaps.append(0)
lengths.append(scalar_shared(numpy.int64(0), lengths.append(scalar_shared(numpy.int64(0),
name = 'l%d' % pos)) name='l%d' % pos))
# 3. Generate arguments for the function passed to scan. This will # 3. Generate arguments for the function passed to scan. This will
# function will return the outputs that need to be computed at every # function will return the outputs that need to be computed at every
...@@ -509,7 +509,7 @@ def scan(fn, ...@@ -509,7 +509,7 @@ def scan(fn,
scan_outputs = [] scan_outputs = []
for pos in xrange(len(states_and_outputs)): for pos in xrange(len(states_and_outputs)):
out = scan_utils.ScanPermutation(mintaps[pos])( out = scan_utils.ScanPermutation(mintaps[pos])(
scan_outputs_update_rules[pos],t) scan_outputs_update_rules[pos], t)
scan_outputs.append(out[mintap:]) scan_outputs.append(out[mintap:])
# 5.6 Construct updates dictionary # 5.6 Construct updates dictionary
update_rules = scan_outputs_update_rules[len(states_and_outputs):] update_rules = scan_outputs_update_rules[len(states_and_outputs):]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论