提交 74b1abaa authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix broken dtype checks in aesara.scalar.basic

上级 7d6dfcb4
......@@ -757,6 +757,8 @@ discrete_types = (bool,) + integer_types
continuous_types = float_types + complex_types
all_types = discrete_types + continuous_types
discrete_dtypes = tuple(t.dtype for t in discrete_types)
class _scalar_py_operators:
# So that we can simplify checking code when we have a mixture of Scalar
......@@ -1596,7 +1598,7 @@ class Switch(ScalarOp):
first_part = switch(cond, gz, 0.0)
second_part = switch(cond, 0.0, gz)
if outputs[0].type.dtype in discrete_types:
if outputs[0].type in discrete_types:
first_part = 0.0
second_part = 0.0
......@@ -1991,7 +1993,7 @@ class TrueDiv(BinaryScalarOp):
def impl(self, x, y):
x = np.asarray(x)
y = np.asarray(y)
if all(a.dtype in discrete_types for a in (x, y)):
if all(a.dtype in discrete_dtypes for a in (x, y)):
return np.sctypeDict[config.floatX](float(x) / y)
else:
return x / y
......@@ -2030,7 +2032,7 @@ class TrueDiv(BinaryScalarOp):
# This is different from it not being connected
# to the output; x/y is still a function of x
# and y; it's just a step function.
if all(a.dtype in discrete_types for a in (x, y)):
if all(a.dtype in discrete_dtypes for a in (x, y)):
return [x.zeros_like(), y.zeros_like()]
first_part = gz / y
......@@ -2255,7 +2257,7 @@ class Mod(BinaryScalarOp):
def L_op(self, inputs, outputs, gout):
(x, y) = inputs
(gz,) = gout
if outputs[0].type.dtype in discrete_types:
if outputs[0].type in discrete_types:
# The gradient does not flow in if the output is discrete
return [
x.zeros_like(dtype=config.floatX),
......@@ -2619,7 +2621,7 @@ class Sgn(UnaryScalarOp):
(gz,) = gout
rval = x.zeros_like()
if rval.type.dtype in discrete_types:
if rval.type in discrete_types:
rval = rval.astype(config.floatX)
return [rval]
......@@ -2660,7 +2662,7 @@ class Ceil(UnaryScalarOp):
(gz,) = gout
rval = x.zeros_like()
if rval.type.dtype in discrete_types:
if rval.type in discrete_types:
rval = rval.astype(config.floatX)
return [rval]
......@@ -2686,7 +2688,7 @@ class Floor(UnaryScalarOp):
(gz,) = gout
rval = x.zeros_like()
if rval.type.dtype in discrete_types:
if rval.type in discrete_types:
rval = rval.astype(config.floatX)
return [rval]
......@@ -2740,7 +2742,7 @@ class RoundHalfToEven(UnaryScalarOp):
(gz,) = gout
rval = x.zeros_like()
if rval.type.dtype in discrete_types:
if rval.type in discrete_types:
rval = rval.astype(config.floatX)
return [rval]
......@@ -2826,7 +2828,7 @@ class RoundHalfAwayFromZero(UnaryScalarOp):
(gz,) = gout
rval = x.zeros_like()
if rval.type.dtype in discrete_types:
if rval.type in discrete_types:
rval = rval.astype(config.floatX)
return [rval]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论