提交 2c125069 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 tensor/subtensor.py

上级 3c5b4282
from copy import copy from copy import copy
import os
import sys import sys
from textwrap import dedent from textwrap import dedent
import warnings import warnings
import logging import logging
_logger = logging.getLogger("theano.tensor.subtensor")
import numpy import numpy
from six.moves import xrange from six.moves import xrange
...@@ -32,6 +30,7 @@ if config.cxx: ...@@ -32,6 +30,7 @@ if config.cxx:
except ImportError: except ImportError:
pass pass
_logger = logging.getLogger("theano.tensor.subtensor")
# Do a lazy import of the sparse module # Do a lazy import of the sparse module
sparse_module_ref = None sparse_module_ref = None
...@@ -336,9 +335,9 @@ class Subtensor(Op): ...@@ -336,9 +335,9 @@ class Subtensor(Op):
theano.tensor.wscalar, theano.tensor.bscalar] theano.tensor.wscalar, theano.tensor.bscalar]
invalid_tensor_types = [theano.tensor.fscalar, theano.tensor.dscalar, invalid_tensor_types = [theano.tensor.fscalar, theano.tensor.dscalar,
theano.tensor.cscalar, theano.tensor.zscalar] theano.tensor.cscalar, theano.tensor.zscalar]
if (isinstance(entry, gof.Variable) if (isinstance(entry, gof.Variable) and
and (entry.type in invalid_scal_types (entry.type in invalid_scal_types or
or entry.type in invalid_tensor_types)): entry.type in invalid_tensor_types)):
raise TypeError("Expected an integer") raise TypeError("Expected an integer")
if isinstance(entry, gof.Variable) and entry.type in scal_types: if isinstance(entry, gof.Variable) and entry.type in scal_types:
...@@ -346,13 +345,13 @@ class Subtensor(Op): ...@@ -346,13 +345,13 @@ class Subtensor(Op):
elif isinstance(entry, gof.Type) and entry in scal_types: elif isinstance(entry, gof.Type) and entry in scal_types:
return entry return entry
if (isinstance(entry, gof.Variable) if (isinstance(entry, gof.Variable) and
and entry.type in tensor_types entry.type in tensor_types and
and numpy.all(entry.type.broadcastable)): numpy.all(entry.type.broadcastable)):
return scal.get_scalar_type(entry.type.dtype) return scal.get_scalar_type(entry.type.dtype)
elif (isinstance(entry, gof.Type) elif (isinstance(entry, gof.Type) and
and entry in tensor_types entry in tensor_types and
and numpy.all(entry.broadcastable)): numpy.all(entry.broadcastable)):
return scal.get_scalar_type(entry.dtype) return scal.get_scalar_type(entry.dtype)
elif slice_ok and isinstance(entry, slice): elif slice_ok and isinstance(entry, slice):
a = entry.start a = entry.start
...@@ -425,7 +424,8 @@ class Subtensor(Op): ...@@ -425,7 +424,8 @@ class Subtensor(Op):
conv(val.step)) conv(val.step))
else: else:
try: try:
return get_scalar_constant_value(val, return get_scalar_constant_value(
val,
only_process_constants=only_process_constants) only_process_constants=only_process_constants)
except theano.tensor.NotScalarConstantError: except theano.tensor.NotScalarConstantError:
if allow_partial: if allow_partial:
...@@ -477,8 +477,8 @@ class Subtensor(Op): ...@@ -477,8 +477,8 @@ class Subtensor(Op):
% (input.type, expected_type)) % (input.type, expected_type))
# infer the broadcasting pattern # infer the broadcasting pattern
padded = (self.get_constant_idx((None,)+inputs, allow_partial=True) padded = (self.get_constant_idx((None,) + inputs, allow_partial=True) +
+ [slice(None, None, None)] * (x.type.ndim - len(idx_list))) [slice(None, None, None)] * (x.type.ndim - len(idx_list)))
broadcastable = [] broadcastable = []
for i, (p, bc) in enumerate(izip(padded, x.type.broadcastable)): for i, (p, bc) in enumerate(izip(padded, x.type.broadcastable)):
if isinstance(p, slice): if isinstance(p, slice):
...@@ -528,9 +528,9 @@ class Subtensor(Op): ...@@ -528,9 +528,9 @@ class Subtensor(Op):
if isinstance(idx, slice): if isinstance(idx, slice):
# If it is the default (None, None, None) slice, or a variant, # If it is the default (None, None, None) slice, or a variant,
# the shape will be xl # the shape will be xl
if ((idx.start in [None, 0]) if ((idx.start in [None, 0]) and
and (idx.stop in [None, sys.maxsize]) (idx.stop in [None, sys.maxsize]) and
and (idx.step is None or idx.step == 1)): (idx.step is None or idx.step == 1)):
outshp.append(xl) outshp.append(xl)
else: else:
cnf = get_canonical_form_slice(idx, xl)[0] cnf = get_canonical_form_slice(idx, xl)[0]
...@@ -556,8 +556,7 @@ class Subtensor(Op): ...@@ -556,8 +556,7 @@ class Subtensor(Op):
first = x.zeros_like().astype(theano.config.floatX) first = x.zeros_like().astype(theano.config.floatX)
else: else:
first = IncSubtensor(self.idx_list)(x.zeros_like(), gz, *rest) first = IncSubtensor(self.idx_list)(x.zeros_like(), gz, *rest)
return ([first] return ([first] + [DisconnectedType()()] * len(rest))
+ [DisconnectedType()()] * len(rest))
def connection_pattern(self, node): def connection_pattern(self, node):
...@@ -1034,8 +1033,7 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False, ...@@ -1034,8 +1033,7 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
dim_offset = x.ndim - y.ndim dim_offset = x.ndim - y.ndim
for dim in xrange(y.ndim): for dim in xrange(y.ndim):
if (x.broadcastable[dim + dim_offset] if (x.broadcastable[dim + dim_offset] and not y.broadcastable[dim]):
and not y.broadcastable[dim]):
# It is acceptable to try to increment a subtensor with a # It is acceptable to try to increment a subtensor with a
# broadcastable dim with a tensor that is not broadcastable # broadcastable dim with a tensor that is not broadcastable
# on that dimension. However, its length must then be 1. # on that dimension. However, its length must then be 1.
...@@ -2133,9 +2131,9 @@ class AdvancedIncSubtensor(Op): ...@@ -2133,9 +2131,9 @@ class AdvancedIncSubtensor(Op):
return hash((type(self), self.inplace, self.set_instead_of_inc)) return hash((type(self), self.inplace, self.set_instead_of_inc))
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other) return (type(self) == type(other) and
and self.inplace == other.inplace self.inplace == other.inplace and
and self.set_instead_of_inc == other.set_instead_of_inc) self.set_instead_of_inc == other.set_instead_of_inc)
def __str__(self): def __str__(self):
return "%s{%s, %s}" % (self.__class__.__name__, return "%s{%s, %s}" % (self.__class__.__name__,
......
...@@ -57,7 +57,6 @@ whitelist_flake8 = [ ...@@ -57,7 +57,6 @@ whitelist_flake8 = [
"typed_list/tests/test_type.py", "typed_list/tests/test_type.py",
"typed_list/tests/test_opt.py", "typed_list/tests/test_opt.py",
"typed_list/tests/test_basic.py", "typed_list/tests/test_basic.py",
"tensor/subtensor.py",
"tensor/elemwise.py", "tensor/elemwise.py",
"tensor/xlogx.py", "tensor/xlogx.py",
"tensor/blas_headers.py", "tensor/blas_headers.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论