提交 903a1183 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Some "safer" flake8 fixes.

上级 9ea09a84
......@@ -48,14 +48,6 @@ relies on the following elements to work properly :
"""
from __future__ import absolute_import, print_function, division
__docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
import copy
import itertools
import logging
......@@ -63,11 +55,10 @@ import time
from collections import OrderedDict
import numpy
from six import iteritems, integer_types
from six import iteritems, integer_types, raise_from
from six.moves import xrange
import theano
from theano.compat import exc_message
from theano.compile import function, In, Out
from theano.compile.mode import AddFeatureOptimizer
from theano import compile, config, gradient, gof, tensor
......@@ -84,6 +75,14 @@ from theano.compile.profiling import ScanProfileStats
from theano.scan_module import scan_utils
from theano.scan_module.scan_utils import safe_new, forced_replace
__docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
# Logging function for sending warning or info
_logger = logging.getLogger('theano.scan_module.scan_op')
......@@ -447,8 +446,8 @@ class Scan(PureOp):
new_inputs.append(outer_mitmot)
for k in xrange(len(itaps)):
if (inner_mitmot[ipos + k].type.dtype !=
outer_mitmot.type.dtype or
inner_mitmot[ipos + k].ndim != outer_mitmot.ndim - 1):
outer_mitmot.type.dtype or
inner_mitmot[ipos + k].ndim != outer_mitmot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ',
str(outer_mitmot),
......@@ -487,9 +486,9 @@ class Scan(PureOp):
new_inputs.append(outer_mitsot)
for k in xrange(len(itaps)):
if (inner_mitsots[ipos + k].type.dtype != \
outer_mitsot.type.dtype or
inner_mitsots[ipos + k].ndim != outer_mitsot.ndim - 1):
if (inner_mitsots[ipos + k].type.dtype !=
outer_mitsot.type.dtype or
inner_mitsots[ipos + k].ndim != outer_mitsot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ',
str(outer_mitsot),
......@@ -587,9 +586,8 @@ class Scan(PureOp):
# need to store. This input does not have the same dtype, nor is it the same
# type of tensor as the output, it is always a scalar int.
new_inputs += self.outer_nitsot(inputs)
for inner_nonseq, _outer_nonseq in zip(
self.inner_non_seqs(self.inputs),
self.outer_non_seqs(inputs)):
for inner_nonseq, _outer_nonseq in zip(self.inner_non_seqs(self.inputs),
self.outer_non_seqs(inputs)):
outer_nonseq = format(_outer_nonseq, as_var=inner_nonseq)
new_inputs.append(outer_nonseq)
if inner_nonseq.type != outer_nonseq.type:
......@@ -602,7 +600,7 @@ class Scan(PureOp):
# depicts the size in memory for that sequence. This feature is
# used by truncated BPTT and by scan space optimization
if (str(outer_nitsot.type.dtype)[:3] not in ('uin', 'int') or
outer_nitsot.ndim != 0):
outer_nitsot.ndim != 0):
raise ValueError('For output %s you need to provide a '
'scalar int !', str(outer_nitsot))
assert len(new_inputs) == len(inputs)
......@@ -630,9 +628,9 @@ class Scan(PureOp):
# Check if we are dealing with same type of objects
if not type(self) == type(other):
return False
if not 'destroy_map' in self.info:
if 'destroy_map' not in self.info:
self.info['destroy_map'] = OrderedDict()
if not 'destroy_map' in other.info:
if 'destroy_map' not in other.info:
other.info['destroy_map'] = OrderedDict()
keys_to_check = ['truncate_gradient', 'profile',
'n_seqs', 'tap_array',
......@@ -675,7 +673,7 @@ class Scan(PureOp):
self.destroy_map = OrderedDict()
if len(self.destroy_map.keys()) > 0:
# Check if all outputs are inplace
if (sorted(self.destroy_map.keys()) == \
if (sorted(self.destroy_map.keys()) ==
sorted(range(self.n_mit_mot +
self.n_mit_sot +
self.n_sit_sot))):
......@@ -840,8 +838,8 @@ class Scan(PureOp):
profile = None
if (theano.config.profile or
(isinstance(self.profile, (string_types, bool, integer_types))
and self.profile)):
(isinstance(self.profile, (string_types, bool, integer_types)) and
self.profile)):
if isinstance(self.profile, string_types):
profile = ScanProfileStats(name=self.profile)
else:
......@@ -1183,8 +1181,8 @@ class Scan(PureOp):
outs[idx][0] = args[self.seqs_arg_offset + idx]
elif (outs[idx][0] is not None and
outs[idx][0].shape[1:] == args[self.seqs_arg_offset +
idx].shape[1:]
and outs[idx][0].shape[0] >= store_steps[idx]):
idx].shape[1:] and
outs[idx][0].shape[0] >= store_steps[idx]):
# Put in the values of the initial state
outs[idx][0] = outs[idx][0][:store_steps[idx]]
if idx > self.n_mit_mot:
......@@ -1212,7 +1210,7 @@ class Scan(PureOp):
i = 0
cond = True
############## THE MAIN LOOP #########################
# ############# THE MAIN LOOP ##############
# for i in xrange(n_steps):
while (i < n_steps) and cond:
# sequences over which scan iterates
......@@ -1263,7 +1261,7 @@ class Scan(PureOp):
for idx in xrange(self.n_outs + self.n_nit_sot -
self.n_mit_mot):
if (store_steps[idx + self.n_mit_mot] == 1 or
self.vector_outs[idx + self.n_mit_mot]):
self.vector_outs[idx + self.n_mit_mot]):
output_storage[idx + offset].storage[0] = None
else:
_pos0 = idx + self.n_mit_mot
......@@ -1497,7 +1495,7 @@ class Scan(PureOp):
end = self.n_outs + self.n_nit_sot
for idx in xrange(begin, end):
if (store_steps[idx] < i - self.mintaps[idx] and
pos[idx] < store_steps[idx]):
pos[idx] < store_steps[idx]):
pdx = pos[idx]
if pdx >= store_steps[idx] // 2:
......@@ -2106,7 +2104,6 @@ class Scan(PureOp):
dc_dxts_idx += 1
dC_dinps_t = compute_all_gradients(known_grads)
# mask inputs that get no gradients
for dx in xrange(len(dC_dinps_t)):
if not dC_dinps_t[dx]:
......@@ -2704,7 +2701,7 @@ class Scan(PureOp):
e = e + self.n_mit_sot
ib = ie
ie = ie + int(numpy.sum([len(x) for x in
self.tap_array[self.n_mit_mot: \
self.tap_array[self.n_mit_mot:
self.n_mit_mot + self.n_mit_sot]]))
clean_eval_points = []
for inp, evp in zip(inputs[b:e], eval_points[b:e]):
......@@ -2827,7 +2824,7 @@ gof.ops_with_inner_function[Scan] = 'fn'
# TODO: move that to the new back-end and new profiling.py print_tips
#@theano.compile.profilemode.register_profiler_printer
# @theano.compile.profilemode.register_profiler_printer
def profile_printer(fct_name, compile_time, fct_call_time, fct_call,
apply_time, apply_cimpl, message, outputs_size,
other_time):
......@@ -2836,9 +2833,9 @@ def profile_printer(fct_name, compile_time, fct_call_time, fct_call,
apply_time.items()]):
print()
print('Scan overhead:')
print ('<Scan op time(s)> <sub scan fct time(s)> <sub scan op '
'time(s)> <sub scan fct time(% scan op time)> <sub scan '
'op time(% scan op time)> <node>')
print('<Scan op time(s)> <sub scan fct time(s)> <sub scan op '
'time(s)> <sub scan fct time(% scan op time)> <sub scan '
'op time(% scan op time)> <node>')
total_super_scan_time = 0
total_scan_fct_time = 0
total_scan_op_time = 0
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论