提交 73c6dc12 authored 作者: Frederic's avatar Frederic

some pep8

上级 0496d000
......@@ -24,7 +24,7 @@ import theano
from theano.compat import PY3
from theano import gof
from theano.gof import (Op, utils, Variable, Constant, Type, Apply,
FunctionGraph)
FunctionGraph)
from theano.gof.python25 import partial, all, any
from theano.configparser import config
......@@ -137,7 +137,7 @@ class Scalar(Type):
py_type = self.dtype_specs()[0]
if strict and not isinstance(data, py_type):
raise TypeError("%s expected a %s, got %s of type %s" % (
self, py_type, data, type(data)), data)
self, py_type, data, type(data)), data)
try:
converted_data = py_type(data)
if (allow_downcast or
......@@ -148,10 +148,11 @@ class Scalar(Type):
return py_type(data)
else:
raise TypeError('Value cannot accurately be converted to dtype'
' (%s) and allow_downcast is not True' % self.dtype)
' (%s) and allow_downcast is not True' %
self.dtype)
except Exception, e:
raise TypeError("Could not convert %s (value=%s) to %s" % (
type(data), data, self.dtype), e)
type(data), data, self.dtype), e)
def values_eq_approx(self, a, b, tolerance=1e-4):
return abs(a - b) <= ((abs(a) + abs(b)) * tolerance)
......@@ -201,7 +202,7 @@ class Scalar(Type):
}[self.dtype]
except KeyError:
raise TypeError("Unsupported dtype for %s: %s" % (
self.__class__.__name__, self.dtype))
self.__class__.__name__, self.dtype))
def upcast(self, *others):
return upcast(*[x.dtype for x in [self] + list(others)])
......@@ -352,11 +353,11 @@ class Scalar(Type):
''' % dict(mytype=mytype, othertype=othertype)
operator_eq = ''.join(operator_eq_real(ctype, rtype)
for ctype in cplx_types
for rtype in real_types) \
for ctype in cplx_types
for rtype in real_types) \
+ ''.join(operator_eq_cplx(ctype1, ctype2)
for ctype1 in cplx_types
for ctype2 in cplx_types)
for ctype1 in cplx_types
for ctype2 in cplx_types)
# We are not using C++ generic templating here, because this would
# generate two different functions for adding a complex64 and a
......@@ -374,8 +375,8 @@ class Scalar(Type):
''' % dict(mytype=mytype, othertype=othertype)
operator_plus = ''.join(operator_plus_real(ctype, rtype)
for ctype in cplx_types
for rtype in real_types)
for ctype in cplx_types
for rtype in real_types)
def operator_minus_real(mytype, othertype):
return '''
......@@ -387,8 +388,8 @@ class Scalar(Type):
''' % dict(mytype=mytype, othertype=othertype)
operator_minus = ''.join(operator_minus_real(ctype, rtype)
for ctype in cplx_types
for rtype in real_types)
for ctype in cplx_types
for rtype in real_types)
def operator_mul_real(mytype, othertype):
return '''
......@@ -400,15 +401,15 @@ class Scalar(Type):
''' % dict(mytype=mytype, othertype=othertype)
operator_mul = ''.join(operator_mul_real(ctype, rtype)
for ctype in cplx_types
for rtype in real_types)
for ctype in cplx_types
for rtype in real_types)
return template % dict(nbits=64, half_nbits=32) \
+ template % dict(nbits=128, half_nbits=64) \
+ operator_eq \
+ operator_plus \
+ operator_minus \
+ operator_mul
+ template % dict(nbits=128, half_nbits=64) \
+ operator_eq \
+ operator_plus \
+ operator_minus \
+ operator_mul
else:
return ""
......@@ -437,11 +438,11 @@ class Scalar(Type):
# Register C code for ViewOp on Scalars.
theano.compile.register_view_op_c_code(
Scalar,
"""
%(oname)s = %(iname)s;
""",
1)
Scalar,
"""
%(oname)s = %(iname)s;
""",
1)
int8 = Scalar('int8')
......@@ -777,17 +778,18 @@ class ScalarOp(Op):
if output_types_preference is not None:
if not callable(output_types_preference):
raise TypeError(
"Expected a callable for the 'output_types_preference' argument to %s. (got: %s)" % (self.__class__, output_types_preference))
"Expected a callable for the 'output_types_preference' argument to %s. (got: %s)" %
self.__class__, output_types_preference)
self.output_types_preference = output_types_preference
def make_node(self, *inputs):
if self.nin >= 0:
if len(inputs) != self.nin:
raise TypeError("Wrong number of inputs for %s.make_node (got %i(%s), expected %i)" \
% (self, len(inputs), str(inputs), self.nin))
raise TypeError("Wrong number of inputs for %s.make_node (got %i(%s), expected %i)" %
self, len(inputs), str(inputs), self.nin)
inputs = [as_scalar(input) for input in inputs]
outputs = [t() for t in self.output_types([input.
type for input in inputs])]
outputs = [t() for t in self.output_types([input.type
for input in inputs])]
if len(outputs) != self.nout:
raise TypeError("Not the right number of outputs produced for %s(%s). Expected %s, got %s."
% (self, ", ".join(str(input) for input in inputs), self.nout, len(outputs)))
......@@ -895,6 +897,7 @@ class UnaryScalarOp(ScalarOp):
%(fct)s(n, x, z);
""" % locals()
class BinaryScalarOp(ScalarOp):
# One may define in subclasses the following fields:
# - `identity`: for an associative operation, identity corresponds to
......@@ -929,7 +932,7 @@ class FixedLogicalComparison(UnaryScalarOp):
return [int8]
def grad(self, inputs, output_gradients):
x ,= inputs
x, = inputs
out = self(x)
assert str(out.type.dtype).find('int') != -1
return [x.zeros_like().astype(theano.config.floatX)]
......@@ -1158,8 +1161,9 @@ class BinaryBitOp(BinaryScalarOp):
return upcast_out(*input_types[0])
def grad(self, inputs, output_gradients):
a,b = inputs
return [a.zeros_like().astype(theano.config.floatX), b.zeros_like().astype(theano.config.floatX)]
a, b = inputs
return [a.zeros_like().astype(theano.config.floatX),
b.zeros_like().astype(theano.config.floatX)]
class OR(BinaryBitOp):
......@@ -1226,7 +1230,7 @@ class Maximum(BinaryScalarOp):
raise NotImplementedError()
# Test for both y>x and x>=y to detect NaN
return ('%(z)s = ((%(y)s)>(%(x)s)? (%(y)s): '
'((%(x)s)>=(%(y)s)? (%(x)s): nan("")));' % locals())
'((%(x)s)>=(%(y)s)? (%(x)s): nan("")));' % locals())
def grad(self, (x, y), (gz, )):
assert gz.type not in complex_types
......@@ -1257,7 +1261,7 @@ class Minimum(BinaryScalarOp):
if any([i.type in complex_types for i in node.inputs]):
raise NotImplementedError()
return ('%(z)s = ((%(y)s)<(%(x)s)? (%(y)s): '
'((%(x)s)<=(%(y)s)? (%(x)s): nan("")));' % locals())
'((%(x)s)<=(%(y)s)? (%(x)s): nan("")));' % locals())
def grad(self, (x, y), (gz, )):
assert gz.type not in complex_types
......@@ -1296,7 +1300,7 @@ class Add(ScalarOp):
for ii, inp in enumerate(inputs):
if hasattr(inp, 'zeros_like'):
retval.append(
inp.zeros_like().astype(theano.config.floatX))
inp.zeros_like().astype(theano.config.floatX))
else:
retval.append(grad_undefined(self, ii, inp))
else:
......@@ -1331,9 +1335,10 @@ class Mul(ScalarOp):
output_type = self.output_types([i.type for i in inputs])[0]
if output_type in complex_types:
if not gz.type in complex_types:
raise TypeError('Mul with output_type ' + str(output_type) +\
' expected gz type to be complex, got gz with type ' +\
str(gz.type))
raise TypeError(
'Mul with output_type ' + str(output_type) +
' expected gz type to be complex, got gz with type ' +
str(gz.type))
if output_type in discrete_types:
return [ipt.zeros_like().astype(theano.config.floatX)
......@@ -1353,7 +1358,7 @@ class Mul(ScalarOp):
retval += [yr * real(gz) + yi * imag(gz)]
else:
retval += [mul(*([gz] + utils.difference(inputs,
[input])))]
[input])))]
return retval
......@@ -1409,10 +1414,10 @@ def int_or_true_div(x_discrete, y_discrete):
"`x.__truediv__(y)`.")
elif config.int_division == 'int':
warnings.warn(
"Division of two integer types with x / y is deprecated, "
"please use x // y for an integer division.",
DeprecationWarning,
stacklevel=4)
"Division of two integer types with x / y is deprecated, "
"please use x // y for an integer division.",
DeprecationWarning,
stacklevel=4)
return 'int'
elif config.int_division == 'floatX':
return 'true'
......@@ -1482,8 +1487,8 @@ true_div = TrueDiv(upcast_out, name='true_div')
class IntDiv(BinaryScalarOp):
complex_error = ComplexError(
"Theano does not support integer division (//) on "
"complex numbers, since numpy deprecated it.")
"Theano does not support integer division (//) on "
"complex numbers, since numpy deprecated it.")
def impl(self, x, y):
return x // y
......@@ -1564,8 +1569,8 @@ def mod_check(x, y):
class Mod(BinaryScalarOp):
complex_error = ComplexError(
"Theano does not support the mod operator (%) on "
"complex numbers, since numpy deprecated it.")
"Theano does not support the mod operator (%) on "
"complex numbers, since numpy deprecated it.")
def impl(self, x, y):
if isinstance(x, numpy.complex) or isinstance(y, numpy.complex):
......@@ -1863,7 +1868,7 @@ class Abs(UnaryScalarOp):
outputs = [float64()]
else:
outputs = [t() for t in self.output_types(
[input.type for input in inputs])]
[input.type for input in inputs])]
return Apply(self, inputs, outputs)
def impl(self, x):
......@@ -2739,12 +2744,12 @@ class Composite(ScalarOp):
subd[orphan] = orphan.type.c_literal(orphan.data)
else:
raise ValueError(
"All orphans in the fgraph to Composite must"
" be Constant instances.")
"All orphans in the fgraph to Composite must"
" be Constant instances.")
_c_code = "{\n"
self.nodenames = ["%(nodename)s_" + ('subnode%i' % j)
for j, n in enumerate(self.fgraph.toposort())]
for j, n in enumerate(self.fgraph.toposort())]
i = 0
for j, node in enumerate(self.fgraph.toposort()):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论