提交 73c6dc12 authored 作者: Frederic's avatar Frederic

some pep8

上级 0496d000
...@@ -24,7 +24,7 @@ import theano ...@@ -24,7 +24,7 @@ import theano
from theano.compat import PY3 from theano.compat import PY3
from theano import gof from theano import gof
from theano.gof import (Op, utils, Variable, Constant, Type, Apply, from theano.gof import (Op, utils, Variable, Constant, Type, Apply,
FunctionGraph) FunctionGraph)
from theano.gof.python25 import partial, all, any from theano.gof.python25 import partial, all, any
from theano.configparser import config from theano.configparser import config
...@@ -137,7 +137,7 @@ class Scalar(Type): ...@@ -137,7 +137,7 @@ class Scalar(Type):
py_type = self.dtype_specs()[0] py_type = self.dtype_specs()[0]
if strict and not isinstance(data, py_type): if strict and not isinstance(data, py_type):
raise TypeError("%s expected a %s, got %s of type %s" % ( raise TypeError("%s expected a %s, got %s of type %s" % (
self, py_type, data, type(data)), data) self, py_type, data, type(data)), data)
try: try:
converted_data = py_type(data) converted_data = py_type(data)
if (allow_downcast or if (allow_downcast or
...@@ -148,10 +148,11 @@ class Scalar(Type): ...@@ -148,10 +148,11 @@ class Scalar(Type):
return py_type(data) return py_type(data)
else: else:
raise TypeError('Value cannot accurately be converted to dtype' raise TypeError('Value cannot accurately be converted to dtype'
' (%s) and allow_downcast is not True' % self.dtype) ' (%s) and allow_downcast is not True' %
self.dtype)
except Exception, e: except Exception, e:
raise TypeError("Could not convert %s (value=%s) to %s" % ( raise TypeError("Could not convert %s (value=%s) to %s" % (
type(data), data, self.dtype), e) type(data), data, self.dtype), e)
def values_eq_approx(self, a, b, tolerance=1e-4): def values_eq_approx(self, a, b, tolerance=1e-4):
return abs(a - b) <= ((abs(a) + abs(b)) * tolerance) return abs(a - b) <= ((abs(a) + abs(b)) * tolerance)
...@@ -201,7 +202,7 @@ class Scalar(Type): ...@@ -201,7 +202,7 @@ class Scalar(Type):
}[self.dtype] }[self.dtype]
except KeyError: except KeyError:
raise TypeError("Unsupported dtype for %s: %s" % ( raise TypeError("Unsupported dtype for %s: %s" % (
self.__class__.__name__, self.dtype)) self.__class__.__name__, self.dtype))
def upcast(self, *others): def upcast(self, *others):
return upcast(*[x.dtype for x in [self] + list(others)]) return upcast(*[x.dtype for x in [self] + list(others)])
...@@ -352,11 +353,11 @@ class Scalar(Type): ...@@ -352,11 +353,11 @@ class Scalar(Type):
''' % dict(mytype=mytype, othertype=othertype) ''' % dict(mytype=mytype, othertype=othertype)
operator_eq = ''.join(operator_eq_real(ctype, rtype) operator_eq = ''.join(operator_eq_real(ctype, rtype)
for ctype in cplx_types for ctype in cplx_types
for rtype in real_types) \ for rtype in real_types) \
+ ''.join(operator_eq_cplx(ctype1, ctype2) + ''.join(operator_eq_cplx(ctype1, ctype2)
for ctype1 in cplx_types for ctype1 in cplx_types
for ctype2 in cplx_types) for ctype2 in cplx_types)
# We are not using C++ generic templating here, because this would # We are not using C++ generic templating here, because this would
# generate two different functions for adding a complex64 and a # generate two different functions for adding a complex64 and a
...@@ -374,8 +375,8 @@ class Scalar(Type): ...@@ -374,8 +375,8 @@ class Scalar(Type):
''' % dict(mytype=mytype, othertype=othertype) ''' % dict(mytype=mytype, othertype=othertype)
operator_plus = ''.join(operator_plus_real(ctype, rtype) operator_plus = ''.join(operator_plus_real(ctype, rtype)
for ctype in cplx_types for ctype in cplx_types
for rtype in real_types) for rtype in real_types)
def operator_minus_real(mytype, othertype): def operator_minus_real(mytype, othertype):
return ''' return '''
...@@ -387,8 +388,8 @@ class Scalar(Type): ...@@ -387,8 +388,8 @@ class Scalar(Type):
''' % dict(mytype=mytype, othertype=othertype) ''' % dict(mytype=mytype, othertype=othertype)
operator_minus = ''.join(operator_minus_real(ctype, rtype) operator_minus = ''.join(operator_minus_real(ctype, rtype)
for ctype in cplx_types for ctype in cplx_types
for rtype in real_types) for rtype in real_types)
def operator_mul_real(mytype, othertype): def operator_mul_real(mytype, othertype):
return ''' return '''
...@@ -400,15 +401,15 @@ class Scalar(Type): ...@@ -400,15 +401,15 @@ class Scalar(Type):
''' % dict(mytype=mytype, othertype=othertype) ''' % dict(mytype=mytype, othertype=othertype)
operator_mul = ''.join(operator_mul_real(ctype, rtype) operator_mul = ''.join(operator_mul_real(ctype, rtype)
for ctype in cplx_types for ctype in cplx_types
for rtype in real_types) for rtype in real_types)
return template % dict(nbits=64, half_nbits=32) \ return template % dict(nbits=64, half_nbits=32) \
+ template % dict(nbits=128, half_nbits=64) \ + template % dict(nbits=128, half_nbits=64) \
+ operator_eq \ + operator_eq \
+ operator_plus \ + operator_plus \
+ operator_minus \ + operator_minus \
+ operator_mul + operator_mul
else: else:
return "" return ""
...@@ -437,11 +438,11 @@ class Scalar(Type): ...@@ -437,11 +438,11 @@ class Scalar(Type):
# Register C code for ViewOp on Scalars. # Register C code for ViewOp on Scalars.
theano.compile.register_view_op_c_code( theano.compile.register_view_op_c_code(
Scalar, Scalar,
""" """
%(oname)s = %(iname)s; %(oname)s = %(iname)s;
""", """,
1) 1)
int8 = Scalar('int8') int8 = Scalar('int8')
...@@ -777,17 +778,18 @@ class ScalarOp(Op): ...@@ -777,17 +778,18 @@ class ScalarOp(Op):
if output_types_preference is not None: if output_types_preference is not None:
if not callable(output_types_preference): if not callable(output_types_preference):
raise TypeError( raise TypeError(
"Expected a callable for the 'output_types_preference' argument to %s. (got: %s)" % (self.__class__, output_types_preference)) "Expected a callable for the 'output_types_preference' argument to %s. (got: %s)" %
self.__class__, output_types_preference)
self.output_types_preference = output_types_preference self.output_types_preference = output_types_preference
def make_node(self, *inputs): def make_node(self, *inputs):
if self.nin >= 0: if self.nin >= 0:
if len(inputs) != self.nin: if len(inputs) != self.nin:
raise TypeError("Wrong number of inputs for %s.make_node (got %i(%s), expected %i)" \ raise TypeError("Wrong number of inputs for %s.make_node (got %i(%s), expected %i)" %
% (self, len(inputs), str(inputs), self.nin)) self, len(inputs), str(inputs), self.nin)
inputs = [as_scalar(input) for input in inputs] inputs = [as_scalar(input) for input in inputs]
outputs = [t() for t in self.output_types([input. outputs = [t() for t in self.output_types([input.type
type for input in inputs])] for input in inputs])]
if len(outputs) != self.nout: if len(outputs) != self.nout:
raise TypeError("Not the right number of outputs produced for %s(%s). Expected %s, got %s." raise TypeError("Not the right number of outputs produced for %s(%s). Expected %s, got %s."
% (self, ", ".join(str(input) for input in inputs), self.nout, len(outputs))) % (self, ", ".join(str(input) for input in inputs), self.nout, len(outputs)))
...@@ -895,6 +897,7 @@ class UnaryScalarOp(ScalarOp): ...@@ -895,6 +897,7 @@ class UnaryScalarOp(ScalarOp):
%(fct)s(n, x, z); %(fct)s(n, x, z);
""" % locals() """ % locals()
class BinaryScalarOp(ScalarOp): class BinaryScalarOp(ScalarOp):
# One may define in subclasses the following fields: # One may define in subclasses the following fields:
# - `identity`: for an associative operation, identity corresponds to # - `identity`: for an associative operation, identity corresponds to
...@@ -929,7 +932,7 @@ class FixedLogicalComparison(UnaryScalarOp): ...@@ -929,7 +932,7 @@ class FixedLogicalComparison(UnaryScalarOp):
return [int8] return [int8]
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
x ,= inputs x, = inputs
out = self(x) out = self(x)
assert str(out.type.dtype).find('int') != -1 assert str(out.type.dtype).find('int') != -1
return [x.zeros_like().astype(theano.config.floatX)] return [x.zeros_like().astype(theano.config.floatX)]
...@@ -1158,8 +1161,9 @@ class BinaryBitOp(BinaryScalarOp): ...@@ -1158,8 +1161,9 @@ class BinaryBitOp(BinaryScalarOp):
return upcast_out(*input_types[0]) return upcast_out(*input_types[0])
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
a,b = inputs a, b = inputs
return [a.zeros_like().astype(theano.config.floatX), b.zeros_like().astype(theano.config.floatX)] return [a.zeros_like().astype(theano.config.floatX),
b.zeros_like().astype(theano.config.floatX)]
class OR(BinaryBitOp): class OR(BinaryBitOp):
...@@ -1226,7 +1230,7 @@ class Maximum(BinaryScalarOp): ...@@ -1226,7 +1230,7 @@ class Maximum(BinaryScalarOp):
raise NotImplementedError() raise NotImplementedError()
# Test for both y>x and x>=y to detect NaN # Test for both y>x and x>=y to detect NaN
return ('%(z)s = ((%(y)s)>(%(x)s)? (%(y)s): ' return ('%(z)s = ((%(y)s)>(%(x)s)? (%(y)s): '
'((%(x)s)>=(%(y)s)? (%(x)s): nan("")));' % locals()) '((%(x)s)>=(%(y)s)? (%(x)s): nan("")));' % locals())
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
assert gz.type not in complex_types assert gz.type not in complex_types
...@@ -1257,7 +1261,7 @@ class Minimum(BinaryScalarOp): ...@@ -1257,7 +1261,7 @@ class Minimum(BinaryScalarOp):
if any([i.type in complex_types for i in node.inputs]): if any([i.type in complex_types for i in node.inputs]):
raise NotImplementedError() raise NotImplementedError()
return ('%(z)s = ((%(y)s)<(%(x)s)? (%(y)s): ' return ('%(z)s = ((%(y)s)<(%(x)s)? (%(y)s): '
'((%(x)s)<=(%(y)s)? (%(x)s): nan("")));' % locals()) '((%(x)s)<=(%(y)s)? (%(x)s): nan("")));' % locals())
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
assert gz.type not in complex_types assert gz.type not in complex_types
...@@ -1296,7 +1300,7 @@ class Add(ScalarOp): ...@@ -1296,7 +1300,7 @@ class Add(ScalarOp):
for ii, inp in enumerate(inputs): for ii, inp in enumerate(inputs):
if hasattr(inp, 'zeros_like'): if hasattr(inp, 'zeros_like'):
retval.append( retval.append(
inp.zeros_like().astype(theano.config.floatX)) inp.zeros_like().astype(theano.config.floatX))
else: else:
retval.append(grad_undefined(self, ii, inp)) retval.append(grad_undefined(self, ii, inp))
else: else:
...@@ -1331,9 +1335,10 @@ class Mul(ScalarOp): ...@@ -1331,9 +1335,10 @@ class Mul(ScalarOp):
output_type = self.output_types([i.type for i in inputs])[0] output_type = self.output_types([i.type for i in inputs])[0]
if output_type in complex_types: if output_type in complex_types:
if not gz.type in complex_types: if not gz.type in complex_types:
raise TypeError('Mul with output_type ' + str(output_type) +\ raise TypeError(
' expected gz type to be complex, got gz with type ' +\ 'Mul with output_type ' + str(output_type) +
str(gz.type)) ' expected gz type to be complex, got gz with type ' +
str(gz.type))
if output_type in discrete_types: if output_type in discrete_types:
return [ipt.zeros_like().astype(theano.config.floatX) return [ipt.zeros_like().astype(theano.config.floatX)
...@@ -1353,7 +1358,7 @@ class Mul(ScalarOp): ...@@ -1353,7 +1358,7 @@ class Mul(ScalarOp):
retval += [yr * real(gz) + yi * imag(gz)] retval += [yr * real(gz) + yi * imag(gz)]
else: else:
retval += [mul(*([gz] + utils.difference(inputs, retval += [mul(*([gz] + utils.difference(inputs,
[input])))] [input])))]
return retval return retval
...@@ -1409,10 +1414,10 @@ def int_or_true_div(x_discrete, y_discrete): ...@@ -1409,10 +1414,10 @@ def int_or_true_div(x_discrete, y_discrete):
"`x.__truediv__(y)`.") "`x.__truediv__(y)`.")
elif config.int_division == 'int': elif config.int_division == 'int':
warnings.warn( warnings.warn(
"Division of two integer types with x / y is deprecated, " "Division of two integer types with x / y is deprecated, "
"please use x // y for an integer division.", "please use x // y for an integer division.",
DeprecationWarning, DeprecationWarning,
stacklevel=4) stacklevel=4)
return 'int' return 'int'
elif config.int_division == 'floatX': elif config.int_division == 'floatX':
return 'true' return 'true'
...@@ -1482,8 +1487,8 @@ true_div = TrueDiv(upcast_out, name='true_div') ...@@ -1482,8 +1487,8 @@ true_div = TrueDiv(upcast_out, name='true_div')
class IntDiv(BinaryScalarOp): class IntDiv(BinaryScalarOp):
complex_error = ComplexError( complex_error = ComplexError(
"Theano does not support integer division (//) on " "Theano does not support integer division (//) on "
"complex numbers, since numpy deprecated it.") "complex numbers, since numpy deprecated it.")
def impl(self, x, y): def impl(self, x, y):
return x // y return x // y
...@@ -1564,8 +1569,8 @@ def mod_check(x, y): ...@@ -1564,8 +1569,8 @@ def mod_check(x, y):
class Mod(BinaryScalarOp): class Mod(BinaryScalarOp):
complex_error = ComplexError( complex_error = ComplexError(
"Theano does not support the mod operator (%) on " "Theano does not support the mod operator (%) on "
"complex numbers, since numpy deprecated it.") "complex numbers, since numpy deprecated it.")
def impl(self, x, y): def impl(self, x, y):
if isinstance(x, numpy.complex) or isinstance(y, numpy.complex): if isinstance(x, numpy.complex) or isinstance(y, numpy.complex):
...@@ -1863,7 +1868,7 @@ class Abs(UnaryScalarOp): ...@@ -1863,7 +1868,7 @@ class Abs(UnaryScalarOp):
outputs = [float64()] outputs = [float64()]
else: else:
outputs = [t() for t in self.output_types( outputs = [t() for t in self.output_types(
[input.type for input in inputs])] [input.type for input in inputs])]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def impl(self, x): def impl(self, x):
...@@ -2739,12 +2744,12 @@ class Composite(ScalarOp): ...@@ -2739,12 +2744,12 @@ class Composite(ScalarOp):
subd[orphan] = orphan.type.c_literal(orphan.data) subd[orphan] = orphan.type.c_literal(orphan.data)
else: else:
raise ValueError( raise ValueError(
"All orphans in the fgraph to Composite must" "All orphans in the fgraph to Composite must"
" be Constant instances.") " be Constant instances.")
_c_code = "{\n" _c_code = "{\n"
self.nodenames = ["%(nodename)s_" + ('subnode%i' % j) self.nodenames = ["%(nodename)s_" + ('subnode%i' % j)
for j, n in enumerate(self.fgraph.toposort())] for j, n in enumerate(self.fgraph.toposort())]
i = 0 i = 0
for j, node in enumerate(self.fgraph.toposort()): for j, node in enumerate(self.fgraph.toposort()):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论