提交 73c6dc12 authored 作者: Frederic's avatar Frederic

some pep8

上级 0496d000
...@@ -148,7 +148,8 @@ class Scalar(Type): ...@@ -148,7 +148,8 @@ class Scalar(Type):
return py_type(data) return py_type(data)
else: else:
raise TypeError('Value cannot accurately be converted to dtype' raise TypeError('Value cannot accurately be converted to dtype'
' (%s) and allow_downcast is not True' % self.dtype) ' (%s) and allow_downcast is not True' %
self.dtype)
except Exception, e: except Exception, e:
raise TypeError("Could not convert %s (value=%s) to %s" % ( raise TypeError("Could not convert %s (value=%s) to %s" % (
type(data), data, self.dtype), e) type(data), data, self.dtype), e)
...@@ -777,17 +778,18 @@ class ScalarOp(Op): ...@@ -777,17 +778,18 @@ class ScalarOp(Op):
if output_types_preference is not None: if output_types_preference is not None:
if not callable(output_types_preference): if not callable(output_types_preference):
raise TypeError( raise TypeError(
"Expected a callable for the 'output_types_preference' argument to %s. (got: %s)" % (self.__class__, output_types_preference)) "Expected a callable for the 'output_types_preference' argument to %s. (got: %s)" %
self.__class__, output_types_preference)
self.output_types_preference = output_types_preference self.output_types_preference = output_types_preference
def make_node(self, *inputs): def make_node(self, *inputs):
if self.nin >= 0: if self.nin >= 0:
if len(inputs) != self.nin: if len(inputs) != self.nin:
raise TypeError("Wrong number of inputs for %s.make_node (got %i(%s), expected %i)" \ raise TypeError("Wrong number of inputs for %s.make_node (got %i(%s), expected %i)" %
% (self, len(inputs), str(inputs), self.nin)) self, len(inputs), str(inputs), self.nin)
inputs = [as_scalar(input) for input in inputs] inputs = [as_scalar(input) for input in inputs]
outputs = [t() for t in self.output_types([input. outputs = [t() for t in self.output_types([input.type
type for input in inputs])] for input in inputs])]
if len(outputs) != self.nout: if len(outputs) != self.nout:
raise TypeError("Not the right number of outputs produced for %s(%s). Expected %s, got %s." raise TypeError("Not the right number of outputs produced for %s(%s). Expected %s, got %s."
% (self, ", ".join(str(input) for input in inputs), self.nout, len(outputs))) % (self, ", ".join(str(input) for input in inputs), self.nout, len(outputs)))
...@@ -895,6 +897,7 @@ class UnaryScalarOp(ScalarOp): ...@@ -895,6 +897,7 @@ class UnaryScalarOp(ScalarOp):
%(fct)s(n, x, z); %(fct)s(n, x, z);
""" % locals() """ % locals()
class BinaryScalarOp(ScalarOp): class BinaryScalarOp(ScalarOp):
# One may define in subclasses the following fields: # One may define in subclasses the following fields:
# - `identity`: for an associative operation, identity corresponds to # - `identity`: for an associative operation, identity corresponds to
...@@ -929,7 +932,7 @@ class FixedLogicalComparison(UnaryScalarOp): ...@@ -929,7 +932,7 @@ class FixedLogicalComparison(UnaryScalarOp):
return [int8] return [int8]
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
x ,= inputs x, = inputs
out = self(x) out = self(x)
assert str(out.type.dtype).find('int') != -1 assert str(out.type.dtype).find('int') != -1
return [x.zeros_like().astype(theano.config.floatX)] return [x.zeros_like().astype(theano.config.floatX)]
...@@ -1158,8 +1161,9 @@ class BinaryBitOp(BinaryScalarOp): ...@@ -1158,8 +1161,9 @@ class BinaryBitOp(BinaryScalarOp):
return upcast_out(*input_types[0]) return upcast_out(*input_types[0])
def grad(self, inputs, output_gradients): def grad(self, inputs, output_gradients):
a,b = inputs a, b = inputs
return [a.zeros_like().astype(theano.config.floatX), b.zeros_like().astype(theano.config.floatX)] return [a.zeros_like().astype(theano.config.floatX),
b.zeros_like().astype(theano.config.floatX)]
class OR(BinaryBitOp): class OR(BinaryBitOp):
...@@ -1331,8 +1335,9 @@ class Mul(ScalarOp): ...@@ -1331,8 +1335,9 @@ class Mul(ScalarOp):
output_type = self.output_types([i.type for i in inputs])[0] output_type = self.output_types([i.type for i in inputs])[0]
if output_type in complex_types: if output_type in complex_types:
if not gz.type in complex_types: if not gz.type in complex_types:
raise TypeError('Mul with output_type ' + str(output_type) +\ raise TypeError(
' expected gz type to be complex, got gz with type ' +\ 'Mul with output_type ' + str(output_type) +
' expected gz type to be complex, got gz with type ' +
str(gz.type)) str(gz.type))
if output_type in discrete_types: if output_type in discrete_types:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论