提交 b3a291fa authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make sure the scalar ops handle bool correctly and don't produce irregular booleans.

上级 33bd55ef
......@@ -689,6 +689,13 @@ def upcast_out(*types):
return get_scalar_type(dtype),
def upcast_out_nobool(*types):
type = upcast_out(*types)
if type[0] == bool:
return int8,
return type
def upgrade_to_float(*types):
"""
Upgrade any int types to float32 or float64 to avoid losing precision.
......@@ -710,6 +717,11 @@ def upgrade_to_float(*types):
def same_out(type):
return type,
def same_out_nobool(type):
if type == bool:
return int8,
return type,
def upcast_out_no_complex(*types):
if any([type in complex_types for type in types]):
......@@ -1474,10 +1486,13 @@ class Add(ScalarOp):
def c_code(self, node, name, inputs, outputs, sub):
(z,) = outputs
op = " + "
if z.type == bool:
op = " || "
if not inputs:
return z + " = 0;"
else:
return z + " = " + " + ".join(inputs) + ";"
return z + " = " + op.join(inputs) + ";"
def grad(self, inputs, gout):
(gz,) = gout
......@@ -1513,10 +1528,13 @@ class Mul(ScalarOp):
def c_code(self, node, name, inputs, outputs, sub):
(z,) = outputs
op = " * "
if z.type == bool:
op = " && "
if not inputs:
return z + " = 1;"
else:
return z + " = " + " * ".join(inputs) + ";"
return z + " = " + op.join(inputs) + ";"
def grad(self, inputs, gout):
(gz,) = gout
......@@ -1566,6 +1584,9 @@ class Sub(BinaryScalarOp):
def c_code(self, node, name, inputs, outputs, sub):
(x, y) = inputs
(z,) = outputs
if z.type == bool:
# xor
return "%(z)s = (%(x)s || %(y)s) && !(%(x)s && %(y)s);" % locals()
return "%(z)s = %(x)s - %(y)s;" % locals()
def grad(self, inputs, gout):
......@@ -1948,7 +1969,7 @@ class Pow(BinaryScalarOp):
raise theano.gof.utils.MethodNotDefined()
pow = Pow(upcast_out, name='pow')
pow = Pow(upcast_out_nobool, name='pow')
class Clip(ScalarOp):
......@@ -2073,6 +2094,8 @@ class Cast(UnaryScalarOp):
def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs
(z,) = outputs
if z.type == bool:
return "%s = (%s) ? 1 : 0" % (z, x)
return "%s = (%s)%s;" % (z, node.outputs[0].type.dtype_specs()[1], x)
def grad(self, inputs, gout):
......@@ -2186,6 +2209,12 @@ abs_ = Abs(same_out)
class Sgn(UnaryScalarOp):
nfunc_spec = ('sign', 1, 1)
@staticmethod
def output_types_preference(x):
if x == bool:
raise TypeError(x)
return same_out_nocomplex(x)
def impl(self, x):
# casting to output type is handled by filter
return numpy.sign(x)
......@@ -2218,7 +2247,7 @@ class Sgn(UnaryScalarOp):
return (4,) + s
else: # if parent is unversioned, we are too
return s
sgn = Sgn(same_out_nocomplex, name='sgn')
sgn = Sgn(name='sgn')
class Ceil(UnaryScalarOp):
......@@ -2241,7 +2270,7 @@ class Ceil(UnaryScalarOp):
(x,) = inputs
(z,) = outputs
return "%(z)s = ceil(%(x)s);" % locals()
ceil = Ceil(same_out_nocomplex, name='ceil')
ceil = Ceil(upgrade_to_float_no_complex, name='ceil')
class Floor(UnaryScalarOp):
......@@ -2264,7 +2293,7 @@ class Floor(UnaryScalarOp):
(x,) = inputs
(z,) = outputs
return "%(z)s = floor(%(x)s);" % locals()
floor = Floor(same_out_nocomplex, name='floor')
floor = Floor(upgrade_to_float_no_complex, name='floor')
class Trunc(UnaryScalarOp):
......@@ -2282,7 +2311,7 @@ class Trunc(UnaryScalarOp):
(x,) = inputs
(z,) = outputs
return "%(z)s = %(x)s >= 0? floor(%(x)s): -floor(-%(x)s);" % locals()
trunc = Trunc(same_out_nocomplex, name='trunc')
trunc = Trunc(upgrade_to_float_no_complex, name='trunc')
class RoundHalfToEven(UnaryScalarOp):
......@@ -2422,13 +2451,6 @@ class Neg(UnaryScalarOp):
nfunc_spec = ('negative', 1, 1)
def impl(self, x):
# We have to make sure x is not a numpy.bool_, because
# `-numpy.bool_(True)` is `False` (we want 0), and
# `-numpy.bool_(False)` is `True` (we want 1).
# This happens for Composite, as the intermediate results are not
# casted in the dtype of the intermediate variable in general.
if isinstance(x, numpy.bool_):
x = numpy.int8(x)
return -x
def grad(self, inputs, gout):
......@@ -2445,6 +2467,8 @@ class Neg(UnaryScalarOp):
def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs
(z,) = outputs
if z.type == bool:
return "%(z)s = !%(x)s;" % locals()
return "%(z)s = -%(x)s;" % locals()
neg = Neg(same_out, name='neg')
......@@ -3436,7 +3460,7 @@ class Conj(UnaryScalarOp):
def impl(self, x):
return numpy.conj(x)
conj = Conj(same_out, name='conj')
conj = Conj(same_out_nobool, name='conj')
class ComplexFromPolar(BinaryScalarOp):
......
......@@ -1065,7 +1065,7 @@ second dimension
# We loop over the "aliased" outputs, i.e., those that are
# inplace (overwrite the contents of one of the inputs) and
# make the output pointers point to theur corresponding input
# make the output pointers point to their corresponding input
# pointers.
for output, oname in izip(aliased_outputs, aliased_onames):
olv_index = inputs.index(dmap[output][0])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论