提交 59a91c4b authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Apply pyupgrade to theano.scalar

上级 e96f06da
...@@ -18,7 +18,6 @@ from itertools import chain ...@@ -18,7 +18,6 @@ from itertools import chain
from textwrap import dedent from textwrap import dedent
import numpy as np import numpy as np
import six
import theano import theano
from theano import config, gof, printing from theano import config, gof, printing
...@@ -41,8 +40,6 @@ class ComplexError(NotImplementedError): ...@@ -41,8 +40,6 @@ class ComplexError(NotImplementedError):
""" """
pass
class IntegerDivisionError(Exception): class IntegerDivisionError(Exception):
""" """
...@@ -50,8 +47,6 @@ class IntegerDivisionError(Exception): ...@@ -50,8 +47,6 @@ class IntegerDivisionError(Exception):
""" """
pass
def upcast(dtype, *dtypes): def upcast(dtype, *dtypes):
# This tries to keep data in floatX or lower precision, unless we # This tries to keep data in floatX or lower precision, unless we
...@@ -135,7 +130,7 @@ def as_scalar(x, name=None): ...@@ -135,7 +130,7 @@ def as_scalar(x, name=None):
raise TypeError("Cannot convert %s to Scalar" % x, type(x)) raise TypeError("Cannot convert %s to Scalar" % x, type(x))
class NumpyAutocaster(object): class NumpyAutocaster:
""" """
This class is used to cast python ints and floats to numpy arrays. This class is used to cast python ints and floats to numpy arrays.
...@@ -171,7 +166,7 @@ class NumpyAutocaster(object): ...@@ -171,7 +166,7 @@ class NumpyAutocaster(object):
def __call__(self, x): def __call__(self, x):
# Make sure we only deal with scalars. # Make sure we only deal with scalars.
assert isinstance(x, (six.integer_types, builtin_float)) or ( assert isinstance(x, (int, builtin_float)) or (
isinstance(x, np.ndarray) and x.ndim == 0 isinstance(x, np.ndarray) and x.ndim == 0
) )
...@@ -230,7 +225,7 @@ autocast_int = NumpyAutocaster(("int8", "int16", "int32", "int64")) ...@@ -230,7 +225,7 @@ autocast_int = NumpyAutocaster(("int8", "int16", "int32", "int64"))
autocast_float = NumpyAutocaster(("float16", "float32", "float64")) autocast_float = NumpyAutocaster(("float16", "float32", "float64"))
class autocast_float_as(object): class autocast_float_as:
""" """
Temporarily adjust autocasting behavior. Temporarily adjust autocasting behavior.
...@@ -280,7 +275,7 @@ def convert(x, dtype=None): ...@@ -280,7 +275,7 @@ def convert(x, dtype=None):
# In this case, this function should infer the dtype according to the # In this case, this function should infer the dtype according to the
# autocasting rules. See autocasting above. # autocasting rules. See autocasting above.
x_ = None x_ = None
if isinstance(x, six.integer_types): if isinstance(x, int):
try: try:
x_ = autocast_int(x) x_ = autocast_int(x)
except OverflowError: except OverflowError:
...@@ -442,7 +437,9 @@ class Scalar(Type): ...@@ -442,7 +437,9 @@ class Scalar(Type):
}[self.dtype] }[self.dtype]
except KeyError: except KeyError:
raise TypeError( raise TypeError(
"Unsupported dtype for %s: %s" % (self.__class__.__name__, self.dtype) "Unsupported dtype for {}: {}".format(
self.__class__.__name__, self.dtype
)
) )
def upcast(self, *others): def upcast(self, *others):
...@@ -1194,9 +1191,9 @@ class ScalarOp(Op): ...@@ -1194,9 +1191,9 @@ class ScalarOp(Op):
not in ["name", "_op_use_c_code", "bool", "output_types_preference"] not in ["name", "_op_use_c_code", "bool", "output_types_preference"]
] ]
if param: if param:
return "%s{%s}" % ( return "{}{{{}}}".format(
self.__class__.__name__, self.__class__.__name__,
", ".join("%s=%s" % (k, v) for k, v in param), ", ".join("{}={}".format(k, v) for k, v in param),
) )
else: else:
return self.__class__.__name__ return self.__class__.__name__
...@@ -1343,7 +1340,7 @@ class LogicalComparison(BinaryScalarOp): ...@@ -1343,7 +1340,7 @@ class LogicalComparison(BinaryScalarOp):
] ]
def c_code_cache_version(self): def c_code_cache_version(self):
super_version = super(LogicalComparison, self).c_code_cache_version() super_version = super().c_code_cache_version()
return super_version + (0,) return super_version + (0,)
...@@ -1376,7 +1373,7 @@ class FixedLogicalComparison(UnaryScalarOp): ...@@ -1376,7 +1373,7 @@ class FixedLogicalComparison(UnaryScalarOp):
return [x.zeros_like().astype(theano.config.floatX)] return [x.zeros_like().astype(theano.config.floatX)]
def c_code_cache_version(self): def c_code_cache_version(self):
super_version = super(FixedLogicalComparison, self).c_code_cache_version() super_version = super().c_code_cache_version()
return super_version + (0,) return super_version + (0,)
...@@ -1522,7 +1519,7 @@ class IsNan(FixedLogicalComparison): ...@@ -1522,7 +1519,7 @@ class IsNan(FixedLogicalComparison):
return "%(z)s = abs(isnan(%(x)s));" % locals() return "%(z)s = abs(isnan(%(x)s));" % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
scalarop_version = super(IsNan, self).c_code_cache_version() scalarop_version = super().c_code_cache_version()
return tuple(scalarop_version) + (3,) return tuple(scalarop_version) + (3,)
...@@ -1550,7 +1547,7 @@ class IsInf(FixedLogicalComparison): ...@@ -1550,7 +1547,7 @@ class IsInf(FixedLogicalComparison):
return "%(z)s = abs(isinf(%(x)s));" % locals() return "%(z)s = abs(isinf(%(x)s));" % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
scalarop_version = super(IsInf, self).c_code_cache_version() scalarop_version = super().c_code_cache_version()
return tuple(scalarop_version) + (3,) return tuple(scalarop_version) + (3,)
...@@ -1737,7 +1734,7 @@ class AND(BinaryBitOp): ...@@ -1737,7 +1734,7 @@ class AND(BinaryBitOp):
return "%(z)s = (%(x)s & %(y)s);" % locals() return "%(z)s = (%(x)s & %(y)s);" % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
super_version = super(AND, self).c_code_cache_version() super_version = super().c_code_cache_version()
return super_version + (3,) return super_version + (3,)
...@@ -2060,7 +2057,7 @@ class TrueDiv(BinaryScalarOp): ...@@ -2060,7 +2057,7 @@ class TrueDiv(BinaryScalarOp):
if all(t in discrete_types for t in types): if all(t in discrete_types for t in types):
return [get_scalar_type(config.floatX)] return [get_scalar_type(config.floatX)]
else: else:
return super(TrueDiv, self).output_types(types) return super().output_types(types)
def impl(self, x, y): def impl(self, x, y):
x = np.asarray(x) x = np.asarray(x)
...@@ -2544,12 +2541,12 @@ class Cast(UnaryScalarOp): ...@@ -2544,12 +2541,12 @@ class Cast(UnaryScalarOp):
def __init__(self, o_type, name=None): def __init__(self, o_type, name=None):
if not isinstance(o_type, Scalar): if not isinstance(o_type, Scalar):
raise TypeError(o_type) raise TypeError(o_type)
super(Cast, self).__init__(specific_out(o_type), name=name) super().__init__(specific_out(o_type), name=name)
self.o_type = o_type self.o_type = o_type
self.ctor = getattr(np, o_type.dtype) self.ctor = getattr(np, o_type.dtype)
def __str__(self): def __str__(self):
return "%s{%s}" % (self.__class__.__name__, self.o_type.dtype) return "{}{{{}}}".format(self.__class__.__name__, self.o_type.dtype)
def clone_float32(self): def clone_float32(self):
if self.o_type == float16: if self.o_type == float16:
...@@ -2575,8 +2572,8 @@ class Cast(UnaryScalarOp): ...@@ -2575,8 +2572,8 @@ class Cast(UnaryScalarOp):
(x,) = inputs (x,) = inputs
(z,) = outputs (z,) = outputs
if node.outputs[0].type == bool: if node.outputs[0].type == bool:
return "%s = (%s) ? 1 : 0;" % (z, x) return "{} = ({}) ? 1 : 0;".format(z, x)
return "%s = (%s)%s;" % (z, node.outputs[0].type.dtype_specs()[1], x) return "{} = ({}){};".format(z, node.outputs[0].type.dtype_specs()[1], x)
def grad(self, inputs, gout): def grad(self, inputs, gout):
(x,) = inputs (x,) = inputs
...@@ -2587,7 +2584,7 @@ class Cast(UnaryScalarOp): ...@@ -2587,7 +2584,7 @@ class Cast(UnaryScalarOp):
return [x.zeros_like().astype(theano.config.floatX)] return [x.zeros_like().astype(theano.config.floatX)]
def c_code_cache_version(self): def c_code_cache_version(self):
s = super(Cast, self).c_code_cache_version() s = super().c_code_cache_version()
if s: if s:
return (4,) + s return (4,) + s
else: else:
...@@ -2738,7 +2735,7 @@ class Sgn(UnaryScalarOp): ...@@ -2738,7 +2735,7 @@ class Sgn(UnaryScalarOp):
raise ComplexError("complex has no sgn") raise ComplexError("complex has no sgn")
def c_code_cache_version(self): def c_code_cache_version(self):
s = super(Sgn, self).c_code_cache_version() s = super().c_code_cache_version()
if s: if s:
return (4,) + s return (4,) + s
else: # if parent is unversioned, we are too else: # if parent is unversioned, we are too
...@@ -4108,7 +4105,7 @@ class Composite(ScalarOp): ...@@ -4108,7 +4105,7 @@ class Composite(ScalarOp):
This fct allow fix patch this. This fct allow fix patch this.
""" """
d = dict([(k, getattr(self, k)) for k in self.init_param]) d = {k: getattr(self, k) for k in self.init_param}
out = self.__class__(**d) out = self.__class__(**d)
if name: if name:
out.name = name out.name = name
...@@ -4163,7 +4160,7 @@ class Composite(ScalarOp): ...@@ -4163,7 +4160,7 @@ class Composite(ScalarOp):
i += 1 i += 1
name = "V%%(id)s_tmp%i" % i name = "V%%(id)s_tmp%i" % i
subd[output] = name subd[output] = name
_c_code += "%s %s;\n" % (output.type.dtype_specs()[1], name) _c_code += "{} {};\n".format(output.type.dtype_specs()[1], name)
s = node.op.c_code( s = node.op.c_code(
node, node,
self.nodenames[j], self.nodenames[j],
...@@ -4338,7 +4335,7 @@ class Composite(ScalarOp): ...@@ -4338,7 +4335,7 @@ class Composite(ScalarOp):
def make_node(self, *inputs): def make_node(self, *inputs):
if tuple([i.type for i in self.inputs]) == tuple([i.type for i in inputs]): if tuple([i.type for i in self.inputs]) == tuple([i.type for i in inputs]):
return super(Composite, self).make_node(*inputs) return super().make_node(*inputs)
else: else:
# Make a new op with the right input type. # Make a new op with the right input type.
assert len(inputs) == self.nin assert len(inputs) == self.nin
...@@ -4470,7 +4467,7 @@ class Composite(ScalarOp): ...@@ -4470,7 +4467,7 @@ class Composite(ScalarOp):
self.init_py_impls() self.init_py_impls()
class Compositef32(object): class Compositef32:
# This is a dict of scalar op classes that need special handling # This is a dict of scalar op classes that need special handling
special = {} special = {}
......
...@@ -40,7 +40,7 @@ class Erf(UnaryScalarOp): ...@@ -40,7 +40,7 @@ class Erf(UnaryScalarOp):
if imported_scipy_special: if imported_scipy_special:
return scipy.special.erf(x) return scipy.special.erf(x)
else: else:
super(Erf, self).impl(x) super().impl(x)
def L_op(self, inputs, outputs, grads): def L_op(self, inputs, outputs, grads):
(x,) = inputs (x,) = inputs
...@@ -77,7 +77,7 @@ class Erfc(UnaryScalarOp): ...@@ -77,7 +77,7 @@ class Erfc(UnaryScalarOp):
if imported_scipy_special: if imported_scipy_special:
return scipy.special.erfc(x) return scipy.special.erfc(x)
else: else:
super(Erfc, self).impl(x) super().impl(x)
def L_op(self, inputs, outputs, grads): def L_op(self, inputs, outputs, grads):
(x,) = inputs (x,) = inputs
...@@ -129,7 +129,7 @@ class Erfcx(UnaryScalarOp): ...@@ -129,7 +129,7 @@ class Erfcx(UnaryScalarOp):
if imported_scipy_special: if imported_scipy_special:
return scipy.special.erfcx(x) return scipy.special.erfcx(x)
else: else:
super(Erfcx, self).impl(x) super().impl(x)
def L_op(self, inputs, outputs, grads): def L_op(self, inputs, outputs, grads):
(x,) = inputs (x,) = inputs
...@@ -169,7 +169,7 @@ class Erfinv(UnaryScalarOp): ...@@ -169,7 +169,7 @@ class Erfinv(UnaryScalarOp):
if imported_scipy_special: if imported_scipy_special:
return scipy.special.erfinv(x) return scipy.special.erfinv(x)
else: else:
super(Erfinv, self).impl(x) super().impl(x)
def L_op(self, inputs, outputs, grads): def L_op(self, inputs, outputs, grads):
(x,) = inputs (x,) = inputs
...@@ -206,7 +206,7 @@ class Erfcinv(UnaryScalarOp): ...@@ -206,7 +206,7 @@ class Erfcinv(UnaryScalarOp):
if imported_scipy_special: if imported_scipy_special:
return scipy.special.erfcinv(x) return scipy.special.erfcinv(x)
else: else:
super(Erfcinv, self).impl(x) super().impl(x)
def L_op(self, inputs, outputs, grads): def L_op(self, inputs, outputs, grads):
(x,) = inputs (x,) = inputs
...@@ -247,7 +247,7 @@ class Gamma(UnaryScalarOp): ...@@ -247,7 +247,7 @@ class Gamma(UnaryScalarOp):
if imported_scipy_special: if imported_scipy_special:
return Gamma.st_impl(x) return Gamma.st_impl(x)
else: else:
super(Gamma, self).impl(x) super().impl(x)
def L_op(self, inputs, outputs, gout): def L_op(self, inputs, outputs, gout):
(x,) = inputs (x,) = inputs
...@@ -289,7 +289,7 @@ class GammaLn(UnaryScalarOp): ...@@ -289,7 +289,7 @@ class GammaLn(UnaryScalarOp):
if imported_scipy_special: if imported_scipy_special:
return GammaLn.st_impl(x) return GammaLn.st_impl(x)
else: else:
super(GammaLn, self).impl(x) super().impl(x)
def L_op(self, inputs, outputs, grads): def L_op(self, inputs, outputs, grads):
(x,) = inputs (x,) = inputs
...@@ -337,7 +337,7 @@ class Psi(UnaryScalarOp): ...@@ -337,7 +337,7 @@ class Psi(UnaryScalarOp):
if imported_scipy_special: if imported_scipy_special:
return Psi.st_impl(x) return Psi.st_impl(x)
else: else:
super(Psi, self).impl(x) super().impl(x)
def L_op(self, inputs, outputs, grads): def L_op(self, inputs, outputs, grads):
(x,) = inputs (x,) = inputs
...@@ -434,7 +434,7 @@ class TriGamma(UnaryScalarOp): ...@@ -434,7 +434,7 @@ class TriGamma(UnaryScalarOp):
if imported_scipy_special: if imported_scipy_special:
return TriGamma.st_impl(x) return TriGamma.st_impl(x)
else: else:
super(TriGamma, self).impl(x) super().impl(x)
def grad(self, inputs, outputs_gradients): def grad(self, inputs, outputs_gradients):
raise NotImplementedError() raise NotImplementedError()
...@@ -526,7 +526,7 @@ class Chi2SF(BinaryScalarOp): ...@@ -526,7 +526,7 @@ class Chi2SF(BinaryScalarOp):
if imported_scipy_special: if imported_scipy_special:
return Chi2SF.st_impl(x, k) return Chi2SF.st_impl(x, k)
else: else:
super(Chi2SF, self).impl(x, k) super().impl(x, k)
def c_support_code(self): def c_support_code(self):
with open(os.path.join(os.path.dirname(__file__), "c_code", "gamma.c")) as f: with open(os.path.join(os.path.dirname(__file__), "c_code", "gamma.c")) as f:
...@@ -570,7 +570,7 @@ class GammaInc(BinaryScalarOp): ...@@ -570,7 +570,7 @@ class GammaInc(BinaryScalarOp):
if imported_scipy_special: if imported_scipy_special:
return GammaInc.st_impl(k, x) return GammaInc.st_impl(k, x)
else: else:
super(GammaInc, self).impl(k, x) super().impl(k, x)
def c_support_code(self): def c_support_code(self):
with open(os.path.join(os.path.dirname(__file__), "c_code", "gamma.c")) as f: with open(os.path.join(os.path.dirname(__file__), "c_code", "gamma.c")) as f:
...@@ -614,7 +614,7 @@ class GammaIncC(BinaryScalarOp): ...@@ -614,7 +614,7 @@ class GammaIncC(BinaryScalarOp):
if imported_scipy_special: if imported_scipy_special:
return GammaIncC.st_impl(k, x) return GammaIncC.st_impl(k, x)
else: else:
super(GammaIncC, self).impl(k, x) super().impl(k, x)
def c_support_code(self): def c_support_code(self):
with open(os.path.join(os.path.dirname(__file__), "c_code", "gamma.c")) as f: with open(os.path.join(os.path.dirname(__file__), "c_code", "gamma.c")) as f:
...@@ -658,7 +658,7 @@ class GammaU(BinaryScalarOp): ...@@ -658,7 +658,7 @@ class GammaU(BinaryScalarOp):
if imported_scipy_special: if imported_scipy_special:
return GammaU.st_impl(k, x) return GammaU.st_impl(k, x)
else: else:
super(GammaU, self).impl(k, x) super().impl(k, x)
def c_support_code(self): def c_support_code(self):
with open(os.path.join(os.path.dirname(__file__), "c_code", "gamma.c")) as f: with open(os.path.join(os.path.dirname(__file__), "c_code", "gamma.c")) as f:
...@@ -702,7 +702,7 @@ class GammaL(BinaryScalarOp): ...@@ -702,7 +702,7 @@ class GammaL(BinaryScalarOp):
if imported_scipy_special: if imported_scipy_special:
return GammaL.st_impl(k, x) return GammaL.st_impl(k, x)
else: else:
super(GammaL, self).impl(k, x) super().impl(k, x)
def c_support_code(self): def c_support_code(self):
with open(os.path.join(os.path.dirname(__file__), "c_code", "gamma.c")) as f: with open(os.path.join(os.path.dirname(__file__), "c_code", "gamma.c")) as f:
...@@ -746,7 +746,7 @@ class Jv(BinaryScalarOp): ...@@ -746,7 +746,7 @@ class Jv(BinaryScalarOp):
if imported_scipy_special: if imported_scipy_special:
return self.st_impl(v, x) return self.st_impl(v, x)
else: else:
super(Jv, self).impl(v, x) super().impl(v, x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
v, x = inputs v, x = inputs
...@@ -775,7 +775,7 @@ class J1(UnaryScalarOp): ...@@ -775,7 +775,7 @@ class J1(UnaryScalarOp):
if imported_scipy_special: if imported_scipy_special:
return self.st_impl(x) return self.st_impl(x)
else: else:
super(J1, self).impl(x) super().impl(x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
(x,) = inputs (x,) = inputs
...@@ -812,7 +812,7 @@ class J0(UnaryScalarOp): ...@@ -812,7 +812,7 @@ class J0(UnaryScalarOp):
if imported_scipy_special: if imported_scipy_special:
return self.st_impl(x) return self.st_impl(x)
else: else:
super(J0, self).impl(x) super().impl(x)
def grad(self, inp, grads): def grad(self, inp, grads):
(x,) = inp (x,) = inp
...@@ -849,7 +849,7 @@ class Iv(BinaryScalarOp): ...@@ -849,7 +849,7 @@ class Iv(BinaryScalarOp):
if imported_scipy_special: if imported_scipy_special:
return self.st_impl(v, x) return self.st_impl(v, x)
else: else:
super(Iv, self).impl(v, x) super().impl(v, x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
v, x = inputs v, x = inputs
...@@ -878,7 +878,7 @@ class I1(UnaryScalarOp): ...@@ -878,7 +878,7 @@ class I1(UnaryScalarOp):
if imported_scipy_special: if imported_scipy_special:
return self.st_impl(x) return self.st_impl(x)
else: else:
super(I1, self).impl(x) super().impl(x)
def grad(self, inputs, grads): def grad(self, inputs, grads):
(x,) = inputs (x,) = inputs
...@@ -904,7 +904,7 @@ class I0(UnaryScalarOp): ...@@ -904,7 +904,7 @@ class I0(UnaryScalarOp):
if imported_scipy_special: if imported_scipy_special:
return self.st_impl(x) return self.st_impl(x)
else: else:
super(I0, self).impl(x) super().impl(x)
def grad(self, inp, grads): def grad(self, inp, grads):
(x,) = inp (x,) = inp
......
...@@ -17,7 +17,6 @@ way (as scan does) to create a shared variable of this kind. ...@@ -17,7 +17,6 @@ way (as scan does) to create a shared variable of this kind.
""" """
import numpy as np import numpy as np
from six import integer_types
from theano.compile import SharedVariable from theano.compile import SharedVariable
...@@ -51,7 +50,7 @@ def shared(value, name=None, strict=False, allow_downcast=None): ...@@ -51,7 +50,7 @@ def shared(value, name=None, strict=False, allow_downcast=None):
We implement this using 0-d tensors for now. We implement this using 0-d tensors for now.
""" """
if not isinstance(value, (np.number, float, integer_types, complex)): if not isinstance(value, (np.number, float, int, complex)):
raise TypeError() raise TypeError()
try: try:
dtype = value.dtype dtype = value.dtype
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论