提交 9d766dc0 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

pep8 scalar/basic.py

上级 5dd297be
...@@ -451,8 +451,11 @@ class _scalar_py_operators: ...@@ -451,8 +451,11 @@ class _scalar_py_operators:
ndim = 0 ndim = 0
#UNARY #UNARY
def __abs__(self): return abs_(self) def __abs__(self):
def __neg__(self): return neg(self) return abs_(self)
def __neg__(self):
return neg(self)
#CASTS #CASTS
#def __int__(self): return AsInt(self).out #def __int__(self): return AsInt(self).out
...@@ -460,36 +463,80 @@ class _scalar_py_operators: ...@@ -460,36 +463,80 @@ class _scalar_py_operators:
#def __complex__(self): return AsComplex(self).out #def __complex__(self): return AsComplex(self).out
#BITWISE #BITWISE
def __invert__(self): return invert(self) def __invert__(self):
def __and__(self,other): return and_(self, other) return invert(self)
def __or__(self,other): return or_(self, other)
def __xor__(self,other): return xor(self, other) def __and__(self, other):
def __rand__(self,other): return and_(other,self) return and_(self, other)
def __ror__(self,other): return or_(other, self)
def __rxor__(self,other): return xor(other, self) def __or__(self, other):
return or_(self, other)
def __xor__(self, other):
return xor(self, other)
def __rand__(self, other):
return and_(other, self)
def __ror__(self, other):
return or_(other, self)
def __rxor__(self, other):
return xor(other, self)
#COMPARISONS #COMPARISONS
def __lt__(self,other): return lt(self, other) def __lt__(self, other):
def __le__(self,other): return le(self, other) return lt(self, other)
def __gt__(self,other): return gt(self, other)
def __ge__(self,other): return ge(self, other) def __le__(self, other):
return le(self, other)
def __gt__(self, other):
return gt(self, other)
def __ge__(self, other):
return ge(self, other)
#ARITHMETIC - NORMAL #ARITHMETIC - NORMAL
def __add__(self,other): return add(self,other) def __add__(self, other):
def __sub__(self,other): return sub(self,other) return add(self, other)
def __mul__(self,other): return mul(self,other)
def __div__(self,other): return div_proxy(self,other) def __sub__(self, other):
def __floordiv__(self, other): return int_div(self, other) return sub(self, other)
def __mod__(self, other): return mod_check(self, other)
def __pow__(self,other): return pow(self,other) def __mul__(self, other):
return mul(self, other)
def __div__(self, other):
return div_proxy(self, other)
def __floordiv__(self, other):
return int_div(self, other)
def __mod__(self, other):
return mod_check(self, other)
def __pow__(self, other):
return pow(self, other)
#ARITHMETIC - RIGHT-OPERAND #ARITHMETIC - RIGHT-OPERAND
def __radd__(self,other): return add(other,self) def __radd__(self, other):
def __rsub__(self,other): return sub(other,self) return add(other, self)
def __rmul__(self,other): return mul(other,self)
def __rdiv__(self,other): return div_proxy(other,self) def __rsub__(self, other):
def __rmod__(self,other): return mod(other,self) return sub(other, self)
def __rpow__(self,other): return pow(other,self)
def __rmul__(self, other):
return mul(other, self)
def __rdiv__(self, other):
return div_proxy(other, self)
def __rmod__(self, other):
return mod(other, self)
def __rpow__(self, other):
return pow(other, self)
def zeros_like(self): def zeros_like(self):
# The second is needed for Elemwise ops to work right # The second is needed for Elemwise ops to work right
...@@ -697,7 +744,8 @@ class ScalarOp(Op): ...@@ -697,7 +744,8 @@ class ScalarOp(Op):
self.name = name self.name = name
if output_types_preference is not None: if output_types_preference is not None:
if not callable(output_types_preference): if not callable(output_types_preference):
raise TypeError("Expected a callable for the 'output_types_preference' argument to %s. (got: %s)" % (self.__class__, output_types_preference)) raise TypeError(
"Expected a callable for the 'output_types_preference' argument to %s. (got: %s)" % (self.__class__, output_types_preference))
self.output_types_preference = output_types_preference self.output_types_preference = output_types_preference
def make_node(self, *inputs): def make_node(self, *inputs):
...@@ -706,7 +754,8 @@ class ScalarOp(Op): ...@@ -706,7 +754,8 @@ class ScalarOp(Op):
raise TypeError("Wrong number of inputs for %s.make_node (got %i(%s), expected %i)" \ raise TypeError("Wrong number of inputs for %s.make_node (got %i(%s), expected %i)" \
% (self, len(inputs), str(inputs), self.nin)) % (self, len(inputs), str(inputs), self.nin))
inputs = [as_scalar(input) for input in inputs] inputs = [as_scalar(input) for input in inputs]
outputs = [t() for t in self.output_types([input.type for input in inputs])] outputs = [t() for t in self.output_types([input.
type for input in inputs])]
if len(outputs) != self.nout: if len(outputs) != self.nout:
raise TypeError("Not the right number of outputs produced for %s(%s). Expected %s, got %s." raise TypeError("Not the right number of outputs produced for %s(%s). Expected %s, got %s."
% (self, ", ".join(str(input) for input in inputs), self.nout, len(outputs))) % (self, ", ".join(str(input) for input in inputs), self.nout, len(outputs)))
...@@ -716,7 +765,8 @@ class ScalarOp(Op): ...@@ -716,7 +765,8 @@ class ScalarOp(Op):
if hasattr(self, 'output_types_preference'): if hasattr(self, 'output_types_preference'):
variables = self.output_types_preference(*types) variables = self.output_types_preference(*types)
if not isinstance(variables, (list, tuple)) or any(not isinstance(x, Type) for x in variables): if not isinstance(variables, (list, tuple)) or any(not isinstance(x, Type) for x in variables):
raise TypeError("output_types_preference should return a list or a tuple of types", self.output_types_preference, variables) raise TypeError(
"output_types_preference should return a list or a tuple of types", self.output_types_preference, variables)
if len(variables) != self.nout: if len(variables) != self.nout:
raise TypeError("Not the right number of outputs types produced for %s(%s) by %s. Expected %s, got %s." raise TypeError("Not the right number of outputs types produced for %s(%s) by %s. Expected %s, got %s."
% (self, ", ".join(str(type) for type in variables), % (self, ", ".join(str(type) for type in variables),
...@@ -1100,7 +1150,7 @@ class Maximum(BinaryScalarOp): ...@@ -1100,7 +1150,7 @@ class Maximum(BinaryScalarOp):
assert gz.type not in complex_types assert gz.type not in complex_types
# max is not defined for complex_types # max is not defined for complex_types
output = self(x,y) output = self(x, y)
if output.type in discrete_types: if output.type in discrete_types:
return [x.zeros_like().astype(theano.config.floatX), return [x.zeros_like().astype(theano.config.floatX),
...@@ -1130,7 +1180,7 @@ class Minimum(BinaryScalarOp): ...@@ -1130,7 +1180,7 @@ class Minimum(BinaryScalarOp):
assert gz.type not in complex_types assert gz.type not in complex_types
# max is not defined for complex_types # max is not defined for complex_types
output = minimum(x,y) output = minimum(x, y)
if output.type in discrete_types: if output.type in discrete_types:
return [x.zeros_like().astype(theano.config.floatX), return [x.zeros_like().astype(theano.config.floatX),
y.zeros_like().astype(theano.config.floatX)] y.zeros_like().astype(theano.config.floatX)]
...@@ -1197,8 +1247,8 @@ class Mul(ScalarOp): ...@@ -1197,8 +1247,8 @@ class Mul(ScalarOp):
output_type = self.output_types([i.type for i in inputs])[0] output_type = self.output_types([i.type for i in inputs])[0]
if output_type in complex_types: if output_type in complex_types:
if not gz.type in complex_types: if not gz.type in complex_types:
raise TypeError('Mul with output_type '+str(output_type)+\ raise TypeError('Mul with output_type ' + str(output_type) +\
' expected gz type to be complex, got gz with type '+\ ' expected gz type to be complex, got gz with type ' +\
str(gz.type)) str(gz.type))
if output_type in discrete_types: if output_type in discrete_types:
...@@ -1237,7 +1287,7 @@ class Sub(BinaryScalarOp): ...@@ -1237,7 +1287,7 @@ class Sub(BinaryScalarOp):
if gz.type in complex_types: if gz.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if (x-y).type in discrete_types: if (x - y).type in discrete_types:
return [x.zeros_like().astype(theano.config.floatX), return [x.zeros_like().astype(theano.config.floatX),
y.zeros_like().astype(theano.config.floatX)] y.zeros_like().astype(theano.config.floatX)]
...@@ -1331,8 +1381,8 @@ class TrueDiv(BinaryScalarOp): ...@@ -1331,8 +1381,8 @@ class TrueDiv(BinaryScalarOp):
# This is different from it not being connected # This is different from it not being connected
# to the output; x/y is still a function of x # to the output; x/y is still a function of x
# and y; it's just a step function. # and y; it's just a step function.
if (x/y).type in discrete_types: if (x / y).type in discrete_types:
return [ x.zeros_like(), y.zeros_like() ] return [x.zeros_like(), y.zeros_like()]
first_part = gz / y first_part = gz / y
...@@ -1516,7 +1566,7 @@ class Pow(BinaryScalarOp): ...@@ -1516,7 +1566,7 @@ class Pow(BinaryScalarOp):
if gz.type in complex_types: if gz.type in complex_types:
raise NotImplementedError() raise NotImplementedError()
if self(x,y).type in discrete_types: if self(x, y).type in discrete_types:
return [x.zeros_like().astype(theano.config.floatX), return [x.zeros_like().astype(theano.config.floatX),
y.zeros_like().astype(theano.config.floatX)] y.zeros_like().astype(theano.config.floatX)]
...@@ -1567,7 +1617,7 @@ class Second(BinaryScalarOp): ...@@ -1567,7 +1617,7 @@ class Second(BinaryScalarOp):
# x is never connected because its elements are never used # x is never connected because its elements are never used
# y is connected because its elements are copied over # y is connected because its elements are copied over
return [[False],[True]] return [[False], [True]]
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
...@@ -1619,9 +1669,9 @@ class Cast(UnaryScalarOp): ...@@ -1619,9 +1669,9 @@ class Cast(UnaryScalarOp):
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
if self.o_type in continuous_types: if self.o_type in continuous_types:
return [ gz ] return [gz]
else: else:
return [ x.zeros_like().astype(theano.config.floatX) ] return [x.zeros_like().astype(theano.config.floatX)]
def c_code_cache_version(self): def c_code_cache_version(self):
s = super(Cast, self).c_code_cache_version() s = super(Cast, self).c_code_cache_version()
...@@ -1777,7 +1827,7 @@ class Trunc(UnaryScalarOp): ...@@ -1777,7 +1827,7 @@ class Trunc(UnaryScalarOp):
return numpy.trunc(x) return numpy.trunc(x)
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
return [ x.zeros_like().astype(theano.config.floatX) ] return [x.zeros_like().astype(theano.config.floatX)]
def c_code(self, node, name, (x,), (z,), sub): def c_code(self, node, name, (x,), (z,), sub):
return "%(z)s = %(x)s >= 0? floor(%(x)s): -floor(-%(x)s);" % locals() return "%(z)s = %(x)s >= 0? floor(%(x)s): -floor(-%(x)s);" % locals()
...@@ -2674,7 +2724,7 @@ class Composite(ScalarOp): ...@@ -2674,7 +2724,7 @@ class Composite(ScalarOp):
onames), onames),
**sub) **sub)
d['nodename'] = nodename d['nodename'] = nodename
if not sub.has_key('id'): if not 'id' in sub:
#The use of a dummy id is safe as the code is in a separate block. #The use of a dummy id is safe as the code is in a separate block.
#It won't generate conflicting variable name. #It won't generate conflicting variable name.
d['id'] = '_DUMMY_ID_' d['id'] = '_DUMMY_ID_'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论