提交 903a1183 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Some "safer" flake8 fixes.

上级 9ea09a84
...@@ -48,14 +48,6 @@ relies on the following elements to work properly : ...@@ -48,14 +48,6 @@ relies on the following elements to work properly :
""" """
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
__docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
import copy import copy
import itertools import itertools
import logging import logging
...@@ -63,11 +55,10 @@ import time ...@@ -63,11 +55,10 @@ import time
from collections import OrderedDict from collections import OrderedDict
import numpy import numpy
from six import iteritems, integer_types from six import iteritems, integer_types, raise_from
from six.moves import xrange from six.moves import xrange
import theano import theano
from theano.compat import exc_message
from theano.compile import function, In, Out from theano.compile import function, In, Out
from theano.compile.mode import AddFeatureOptimizer from theano.compile.mode import AddFeatureOptimizer
from theano import compile, config, gradient, gof, tensor from theano import compile, config, gradient, gof, tensor
...@@ -84,6 +75,14 @@ from theano.compile.profiling import ScanProfileStats ...@@ -84,6 +75,14 @@ from theano.compile.profiling import ScanProfileStats
from theano.scan_module import scan_utils from theano.scan_module import scan_utils
from theano.scan_module.scan_utils import safe_new, forced_replace from theano.scan_module.scan_utils import safe_new, forced_replace
__docformat__ = 'restructedtext en'
__authors__ = ("Razvan Pascanu "
"Frederic Bastien "
"James Bergstra "
"Pascal Lamblin ")
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
# Logging function for sending warning or info # Logging function for sending warning or info
_logger = logging.getLogger('theano.scan_module.scan_op') _logger = logging.getLogger('theano.scan_module.scan_op')
...@@ -447,8 +446,8 @@ class Scan(PureOp): ...@@ -447,8 +446,8 @@ class Scan(PureOp):
new_inputs.append(outer_mitmot) new_inputs.append(outer_mitmot)
for k in xrange(len(itaps)): for k in xrange(len(itaps)):
if (inner_mitmot[ipos + k].type.dtype != if (inner_mitmot[ipos + k].type.dtype !=
outer_mitmot.type.dtype or outer_mitmot.type.dtype or
inner_mitmot[ipos + k].ndim != outer_mitmot.ndim - 1): inner_mitmot[ipos + k].ndim != outer_mitmot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info' raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ', ' in scan nomenclature) ',
str(outer_mitmot), str(outer_mitmot),
...@@ -487,9 +486,9 @@ class Scan(PureOp): ...@@ -487,9 +486,9 @@ class Scan(PureOp):
new_inputs.append(outer_mitsot) new_inputs.append(outer_mitsot)
for k in xrange(len(itaps)): for k in xrange(len(itaps)):
if (inner_mitsots[ipos + k].type.dtype != \ if (inner_mitsots[ipos + k].type.dtype !=
outer_mitsot.type.dtype or outer_mitsot.type.dtype or
inner_mitsots[ipos + k].ndim != outer_mitsot.ndim - 1): inner_mitsots[ipos + k].ndim != outer_mitsot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info' raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ', ' in scan nomenclature) ',
str(outer_mitsot), str(outer_mitsot),
...@@ -587,9 +586,8 @@ class Scan(PureOp): ...@@ -587,9 +586,8 @@ class Scan(PureOp):
# need to store. This input does not have the same dtype, nor is it the same # need to store. This input does not have the same dtype, nor is it the same
# type of tensor as the output, it is always a scalar int. # type of tensor as the output, it is always a scalar int.
new_inputs += self.outer_nitsot(inputs) new_inputs += self.outer_nitsot(inputs)
for inner_nonseq, _outer_nonseq in zip( for inner_nonseq, _outer_nonseq in zip(self.inner_non_seqs(self.inputs),
self.inner_non_seqs(self.inputs), self.outer_non_seqs(inputs)):
self.outer_non_seqs(inputs)):
outer_nonseq = format(_outer_nonseq, as_var=inner_nonseq) outer_nonseq = format(_outer_nonseq, as_var=inner_nonseq)
new_inputs.append(outer_nonseq) new_inputs.append(outer_nonseq)
if inner_nonseq.type != outer_nonseq.type: if inner_nonseq.type != outer_nonseq.type:
...@@ -602,7 +600,7 @@ class Scan(PureOp): ...@@ -602,7 +600,7 @@ class Scan(PureOp):
# depicts the size in memory for that sequence. This feature is # depicts the size in memory for that sequence. This feature is
# used by truncated BPTT and by scan space optimization # used by truncated BPTT and by scan space optimization
if (str(outer_nitsot.type.dtype)[:3] not in ('uin', 'int') or if (str(outer_nitsot.type.dtype)[:3] not in ('uin', 'int') or
outer_nitsot.ndim != 0): outer_nitsot.ndim != 0):
raise ValueError('For output %s you need to provide a ' raise ValueError('For output %s you need to provide a '
'scalar int !', str(outer_nitsot)) 'scalar int !', str(outer_nitsot))
assert len(new_inputs) == len(inputs) assert len(new_inputs) == len(inputs)
...@@ -630,9 +628,9 @@ class Scan(PureOp): ...@@ -630,9 +628,9 @@ class Scan(PureOp):
# Check if we are dealing with same type of objects # Check if we are dealing with same type of objects
if not type(self) == type(other): if not type(self) == type(other):
return False return False
if not 'destroy_map' in self.info: if 'destroy_map' not in self.info:
self.info['destroy_map'] = OrderedDict() self.info['destroy_map'] = OrderedDict()
if not 'destroy_map' in other.info: if 'destroy_map' not in other.info:
other.info['destroy_map'] = OrderedDict() other.info['destroy_map'] = OrderedDict()
keys_to_check = ['truncate_gradient', 'profile', keys_to_check = ['truncate_gradient', 'profile',
'n_seqs', 'tap_array', 'n_seqs', 'tap_array',
...@@ -675,7 +673,7 @@ class Scan(PureOp): ...@@ -675,7 +673,7 @@ class Scan(PureOp):
self.destroy_map = OrderedDict() self.destroy_map = OrderedDict()
if len(self.destroy_map.keys()) > 0: if len(self.destroy_map.keys()) > 0:
# Check if all outputs are inplace # Check if all outputs are inplace
if (sorted(self.destroy_map.keys()) == \ if (sorted(self.destroy_map.keys()) ==
sorted(range(self.n_mit_mot + sorted(range(self.n_mit_mot +
self.n_mit_sot + self.n_mit_sot +
self.n_sit_sot))): self.n_sit_sot))):
...@@ -840,8 +838,8 @@ class Scan(PureOp): ...@@ -840,8 +838,8 @@ class Scan(PureOp):
profile = None profile = None
if (theano.config.profile or if (theano.config.profile or
(isinstance(self.profile, (string_types, bool, integer_types)) (isinstance(self.profile, (string_types, bool, integer_types)) and
and self.profile)): self.profile)):
if isinstance(self.profile, string_types): if isinstance(self.profile, string_types):
profile = ScanProfileStats(name=self.profile) profile = ScanProfileStats(name=self.profile)
else: else:
...@@ -1183,8 +1181,8 @@ class Scan(PureOp): ...@@ -1183,8 +1181,8 @@ class Scan(PureOp):
outs[idx][0] = args[self.seqs_arg_offset + idx] outs[idx][0] = args[self.seqs_arg_offset + idx]
elif (outs[idx][0] is not None and elif (outs[idx][0] is not None and
outs[idx][0].shape[1:] == args[self.seqs_arg_offset + outs[idx][0].shape[1:] == args[self.seqs_arg_offset +
idx].shape[1:] idx].shape[1:] and
and outs[idx][0].shape[0] >= store_steps[idx]): outs[idx][0].shape[0] >= store_steps[idx]):
# Put in the values of the initial state # Put in the values of the initial state
outs[idx][0] = outs[idx][0][:store_steps[idx]] outs[idx][0] = outs[idx][0][:store_steps[idx]]
if idx > self.n_mit_mot: if idx > self.n_mit_mot:
...@@ -1212,7 +1210,7 @@ class Scan(PureOp): ...@@ -1212,7 +1210,7 @@ class Scan(PureOp):
i = 0 i = 0
cond = True cond = True
############## THE MAIN LOOP ######################### # ############# THE MAIN LOOP ##############
# for i in xrange(n_steps): # for i in xrange(n_steps):
while (i < n_steps) and cond: while (i < n_steps) and cond:
# sequences over which scan iterates # sequences over which scan iterates
...@@ -1263,7 +1261,7 @@ class Scan(PureOp): ...@@ -1263,7 +1261,7 @@ class Scan(PureOp):
for idx in xrange(self.n_outs + self.n_nit_sot - for idx in xrange(self.n_outs + self.n_nit_sot -
self.n_mit_mot): self.n_mit_mot):
if (store_steps[idx + self.n_mit_mot] == 1 or if (store_steps[idx + self.n_mit_mot] == 1 or
self.vector_outs[idx + self.n_mit_mot]): self.vector_outs[idx + self.n_mit_mot]):
output_storage[idx + offset].storage[0] = None output_storage[idx + offset].storage[0] = None
else: else:
_pos0 = idx + self.n_mit_mot _pos0 = idx + self.n_mit_mot
...@@ -1497,7 +1495,7 @@ class Scan(PureOp): ...@@ -1497,7 +1495,7 @@ class Scan(PureOp):
end = self.n_outs + self.n_nit_sot end = self.n_outs + self.n_nit_sot
for idx in xrange(begin, end): for idx in xrange(begin, end):
if (store_steps[idx] < i - self.mintaps[idx] and if (store_steps[idx] < i - self.mintaps[idx] and
pos[idx] < store_steps[idx]): pos[idx] < store_steps[idx]):
pdx = pos[idx] pdx = pos[idx]
if pdx >= store_steps[idx] // 2: if pdx >= store_steps[idx] // 2:
...@@ -2106,7 +2104,6 @@ class Scan(PureOp): ...@@ -2106,7 +2104,6 @@ class Scan(PureOp):
dc_dxts_idx += 1 dc_dxts_idx += 1
dC_dinps_t = compute_all_gradients(known_grads) dC_dinps_t = compute_all_gradients(known_grads)
# mask inputs that get no gradients # mask inputs that get no gradients
for dx in xrange(len(dC_dinps_t)): for dx in xrange(len(dC_dinps_t)):
if not dC_dinps_t[dx]: if not dC_dinps_t[dx]:
...@@ -2704,7 +2701,7 @@ class Scan(PureOp): ...@@ -2704,7 +2701,7 @@ class Scan(PureOp):
e = e + self.n_mit_sot e = e + self.n_mit_sot
ib = ie ib = ie
ie = ie + int(numpy.sum([len(x) for x in ie = ie + int(numpy.sum([len(x) for x in
self.tap_array[self.n_mit_mot: \ self.tap_array[self.n_mit_mot:
self.n_mit_mot + self.n_mit_sot]])) self.n_mit_mot + self.n_mit_sot]]))
clean_eval_points = [] clean_eval_points = []
for inp, evp in zip(inputs[b:e], eval_points[b:e]): for inp, evp in zip(inputs[b:e], eval_points[b:e]):
...@@ -2827,7 +2824,7 @@ gof.ops_with_inner_function[Scan] = 'fn' ...@@ -2827,7 +2824,7 @@ gof.ops_with_inner_function[Scan] = 'fn'
# TODO: move that to the new back-end and new profiling.py print_tips # TODO: move that to the new back-end and new profiling.py print_tips
#@theano.compile.profilemode.register_profiler_printer # @theano.compile.profilemode.register_profiler_printer
def profile_printer(fct_name, compile_time, fct_call_time, fct_call, def profile_printer(fct_name, compile_time, fct_call_time, fct_call,
apply_time, apply_cimpl, message, outputs_size, apply_time, apply_cimpl, message, outputs_size,
other_time): other_time):
...@@ -2836,9 +2833,9 @@ def profile_printer(fct_name, compile_time, fct_call_time, fct_call, ...@@ -2836,9 +2833,9 @@ def profile_printer(fct_name, compile_time, fct_call_time, fct_call,
apply_time.items()]): apply_time.items()]):
print() print()
print('Scan overhead:') print('Scan overhead:')
print ('<Scan op time(s)> <sub scan fct time(s)> <sub scan op ' print('<Scan op time(s)> <sub scan fct time(s)> <sub scan op '
'time(s)> <sub scan fct time(% scan op time)> <sub scan ' 'time(s)> <sub scan fct time(% scan op time)> <sub scan '
'op time(% scan op time)> <node>') 'op time(% scan op time)> <node>')
total_super_scan_time = 0 total_super_scan_time = 0
total_scan_fct_time = 0 total_scan_fct_time = 0
total_scan_op_time = 0 total_scan_op_time = 0
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论