提交 2c125069 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 tensor/subtensor.py

上级 3c5b4282
from copy import copy
import os
import sys
from textwrap import dedent
import warnings
import logging
_logger = logging.getLogger("theano.tensor.subtensor")
import numpy
from six.moves import xrange
......@@ -32,6 +30,7 @@ if config.cxx:
except ImportError:
pass
_logger = logging.getLogger("theano.tensor.subtensor")
# Do a lazy import of the sparse module
sparse_module_ref = None
......@@ -336,9 +335,9 @@ class Subtensor(Op):
theano.tensor.wscalar, theano.tensor.bscalar]
invalid_tensor_types = [theano.tensor.fscalar, theano.tensor.dscalar,
theano.tensor.cscalar, theano.tensor.zscalar]
if (isinstance(entry, gof.Variable)
and (entry.type in invalid_scal_types
or entry.type in invalid_tensor_types)):
if (isinstance(entry, gof.Variable) and
(entry.type in invalid_scal_types or
entry.type in invalid_tensor_types)):
raise TypeError("Expected an integer")
if isinstance(entry, gof.Variable) and entry.type in scal_types:
......@@ -346,13 +345,13 @@ class Subtensor(Op):
elif isinstance(entry, gof.Type) and entry in scal_types:
return entry
if (isinstance(entry, gof.Variable)
and entry.type in tensor_types
and numpy.all(entry.type.broadcastable)):
if (isinstance(entry, gof.Variable) and
entry.type in tensor_types and
numpy.all(entry.type.broadcastable)):
return scal.get_scalar_type(entry.type.dtype)
elif (isinstance(entry, gof.Type)
and entry in tensor_types
and numpy.all(entry.broadcastable)):
elif (isinstance(entry, gof.Type) and
entry in tensor_types and
numpy.all(entry.broadcastable)):
return scal.get_scalar_type(entry.dtype)
elif slice_ok and isinstance(entry, slice):
a = entry.start
......@@ -425,8 +424,9 @@ class Subtensor(Op):
conv(val.step))
else:
try:
return get_scalar_constant_value(val,
only_process_constants=only_process_constants)
return get_scalar_constant_value(
val,
only_process_constants=only_process_constants)
except theano.tensor.NotScalarConstantError:
if allow_partial:
return val
......@@ -477,8 +477,8 @@ class Subtensor(Op):
% (input.type, expected_type))
# infer the broadcasting pattern
padded = (self.get_constant_idx((None,)+inputs, allow_partial=True)
+ [slice(None, None, None)] * (x.type.ndim - len(idx_list)))
padded = (self.get_constant_idx((None,) + inputs, allow_partial=True) +
[slice(None, None, None)] * (x.type.ndim - len(idx_list)))
broadcastable = []
for i, (p, bc) in enumerate(izip(padded, x.type.broadcastable)):
if isinstance(p, slice):
......@@ -528,9 +528,9 @@ class Subtensor(Op):
if isinstance(idx, slice):
# If it is the default (None, None, None) slice, or a variant,
# the shape will be xl
if ((idx.start in [None, 0])
and (idx.stop in [None, sys.maxsize])
and (idx.step is None or idx.step == 1)):
if ((idx.start in [None, 0]) and
(idx.stop in [None, sys.maxsize]) and
(idx.step is None or idx.step == 1)):
outshp.append(xl)
else:
cnf = get_canonical_form_slice(idx, xl)[0]
......@@ -556,8 +556,7 @@ class Subtensor(Op):
first = x.zeros_like().astype(theano.config.floatX)
else:
first = IncSubtensor(self.idx_list)(x.zeros_like(), gz, *rest)
return ([first]
+ [DisconnectedType()()] * len(rest))
return ([first] + [DisconnectedType()()] * len(rest))
def connection_pattern(self, node):
......@@ -1034,8 +1033,7 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
dim_offset = x.ndim - y.ndim
for dim in xrange(y.ndim):
if (x.broadcastable[dim + dim_offset]
and not y.broadcastable[dim]):
if (x.broadcastable[dim + dim_offset] and not y.broadcastable[dim]):
# It is acceptable to try to increment a subtensor with a
# broadcastable dim with a tensor that is not broadcastable
# on that dimension. However, its length must then be 1.
......@@ -2133,9 +2131,9 @@ class AdvancedIncSubtensor(Op):
return hash((type(self), self.inplace, self.set_instead_of_inc))
def __eq__(self, other):
return (type(self) == type(other)
and self.inplace == other.inplace
and self.set_instead_of_inc == other.set_instead_of_inc)
return (type(self) == type(other) and
self.inplace == other.inplace and
self.set_instead_of_inc == other.set_instead_of_inc)
def __str__(self):
return "%s{%s, %s}" % (self.__class__.__name__,
......
......@@ -57,7 +57,6 @@ whitelist_flake8 = [
"typed_list/tests/test_type.py",
"typed_list/tests/test_opt.py",
"typed_list/tests/test_basic.py",
"tensor/subtensor.py",
"tensor/elemwise.py",
"tensor/xlogx.py",
"tensor/blas_headers.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论