提交 99012e14 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Flake8 + review fixes.

上级 a476757a
...@@ -1335,7 +1335,7 @@ def _float_ones_like(x): ...@@ -1335,7 +1335,7 @@ def _float_ones_like(x):
floating point dtype """ floating point dtype """
dtype = x.type.dtype dtype = x.type.dtype
if 'float' not in dtype: if dtype not in tensor.float_dtypes:
dtype = theano.config.floatX dtype = theano.config.floatX
return tensor.ones_like(x, dtype=dtype) return tensor.ones_like(x, dtype=dtype)
......
...@@ -288,7 +288,7 @@ class Scalar(Type): ...@@ -288,7 +288,7 @@ class Scalar(Type):
if 'complex' in self.dtype: if 'complex' in self.dtype:
raise NotImplementedError("No literal for complex values.") raise NotImplementedError("No literal for complex values.")
if self.dtype == 'bool': if self.dtype == 'bool':
return '1' if b else '0' return '1' if data else '0'
return str(data) return str(data)
def c_declare(self, name, sub, check_input=True): def c_declare(self, name, sub, check_input=True):
...@@ -683,11 +683,6 @@ complexs64 = _multi(complex64) ...@@ -683,11 +683,6 @@ complexs64 = _multi(complex64)
complexs128 = _multi(complex128) complexs128 = _multi(complex128)
# Using a class instead of a function makes it possible to deep-copy it.
# Note that currently only a few functions use this mechanism, because
# it is enough to make the test-suite pass. However, it may prove
# necessary to use this same mechanism in other places as well in the
# future.
def upcast_out(*types): def upcast_out(*types):
dtype = Scalar.upcast(*types) dtype = Scalar.upcast(*types)
return get_scalar_type(dtype), return get_scalar_type(dtype),
...@@ -698,7 +693,8 @@ def upgrade_to_float(*types): ...@@ -698,7 +693,8 @@ def upgrade_to_float(*types):
Upgrade any int types to float32 or float64 to avoid losing precision. Upgrade any int types to float32 or float64 to avoid losing precision.
""" """
conv = {int8: float32, conv = {bool: float32,
int8: float32,
int16: float32, int16: float32,
int32: float64, int32: float64,
int64: float64, int64: float64,
...@@ -1265,7 +1261,7 @@ switch = Switch() ...@@ -1265,7 +1261,7 @@ switch = Switch()
class UnaryBitOp(UnaryScalarOp): class UnaryBitOp(UnaryScalarOp):
def output_types(self, *input_types): def output_types(self, *input_types):
for i in input_types[0]: for i in input_types[0]:
if i not in ((bool,) + discrete_types): if i not in discrete_types:
raise TypeError('input to a BitOp must have type (u)int8, ' raise TypeError('input to a BitOp must have type (u)int8, '
'(u)int16, (u)int32 or (u)int64 or bool not %s' % i) '(u)int16, (u)int32 or (u)int64 or bool not %s' % i)
return upcast_out(*input_types[0]) return upcast_out(*input_types[0])
......
...@@ -3182,6 +3182,7 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, ...@@ -3182,6 +3182,7 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False,
for i in axis: for i in axis:
s = true_div(s, shp[i]) s = true_div(s, shp[i])
# This can happen when axis is an empty list/tuple
if s.dtype != shp.dtype and s.dtype in discrete_dtypes: if s.dtype != shp.dtype and s.dtype in discrete_dtypes:
s = cast(s, shp.dtype) s = cast(s, shp.dtype)
......
...@@ -1716,7 +1716,7 @@ class All(CAReduce): ...@@ -1716,7 +1716,7 @@ class All(CAReduce):
def make_node(self, input): def make_node(self, input):
input = as_tensor_variable(input) input = as_tensor_variable(input)
if input.dtype is not "bool": if input.dtype != "bool":
input = theano.tensor.neq(input, 0) input = theano.tensor.neq(input, 0)
ret = super(All, self).make_node(input) ret = super(All, self).make_node(input)
return ret return ret
...@@ -1746,7 +1746,7 @@ class Any(CAReduce): ...@@ -1746,7 +1746,7 @@ class Any(CAReduce):
def make_node(self, input): def make_node(self, input):
input = as_tensor_variable(input) input = as_tensor_variable(input)
if input.dtype is not "bool": if input.dtype != "bool":
input = theano.tensor.neq(input, 0) input = theano.tensor.neq(input, 0)
ret = super(Any, self).make_node(input) ret = super(Any, self).make_node(input)
return ret return ret
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论