提交 74b1abaa authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix broken dtype checks in aesara.scalar.basic

上级 7d6dfcb4
...@@ -757,6 +757,8 @@ discrete_types = (bool,) + integer_types ...@@ -757,6 +757,8 @@ discrete_types = (bool,) + integer_types
continuous_types = float_types + complex_types continuous_types = float_types + complex_types
all_types = discrete_types + continuous_types all_types = discrete_types + continuous_types
discrete_dtypes = tuple(t.dtype for t in discrete_types)
class _scalar_py_operators: class _scalar_py_operators:
# So that we can simplify checking code when we have a mixture of Scalar # So that we can simplify checking code when we have a mixture of Scalar
...@@ -1596,7 +1598,7 @@ class Switch(ScalarOp): ...@@ -1596,7 +1598,7 @@ class Switch(ScalarOp):
first_part = switch(cond, gz, 0.0) first_part = switch(cond, gz, 0.0)
second_part = switch(cond, 0.0, gz) second_part = switch(cond, 0.0, gz)
if outputs[0].type.dtype in discrete_types: if outputs[0].type in discrete_types:
first_part = 0.0 first_part = 0.0
second_part = 0.0 second_part = 0.0
...@@ -1991,7 +1993,7 @@ class TrueDiv(BinaryScalarOp): ...@@ -1991,7 +1993,7 @@ class TrueDiv(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
x = np.asarray(x) x = np.asarray(x)
y = np.asarray(y) y = np.asarray(y)
if all(a.dtype in discrete_types for a in (x, y)): if all(a.dtype in discrete_dtypes for a in (x, y)):
return np.sctypeDict[config.floatX](float(x) / y) return np.sctypeDict[config.floatX](float(x) / y)
else: else:
return x / y return x / y
...@@ -2030,7 +2032,7 @@ class TrueDiv(BinaryScalarOp): ...@@ -2030,7 +2032,7 @@ class TrueDiv(BinaryScalarOp):
# This is different from it not being connected # This is different from it not being connected
# to the output; x/y is still a function of x # to the output; x/y is still a function of x
# and y; it's just a step function. # and y; it's just a step function.
if all(a.dtype in discrete_types for a in (x, y)): if all(a.dtype in discrete_dtypes for a in (x, y)):
return [x.zeros_like(), y.zeros_like()] return [x.zeros_like(), y.zeros_like()]
first_part = gz / y first_part = gz / y
...@@ -2255,7 +2257,7 @@ class Mod(BinaryScalarOp): ...@@ -2255,7 +2257,7 @@ class Mod(BinaryScalarOp):
def L_op(self, inputs, outputs, gout): def L_op(self, inputs, outputs, gout):
(x, y) = inputs (x, y) = inputs
(gz,) = gout (gz,) = gout
if outputs[0].type.dtype in discrete_types: if outputs[0].type in discrete_types:
# The gradient does not flow in if the output is discrete # The gradient does not flow in if the output is discrete
return [ return [
x.zeros_like(dtype=config.floatX), x.zeros_like(dtype=config.floatX),
...@@ -2619,7 +2621,7 @@ class Sgn(UnaryScalarOp): ...@@ -2619,7 +2621,7 @@ class Sgn(UnaryScalarOp):
(gz,) = gout (gz,) = gout
rval = x.zeros_like() rval = x.zeros_like()
if rval.type.dtype in discrete_types: if rval.type in discrete_types:
rval = rval.astype(config.floatX) rval = rval.astype(config.floatX)
return [rval] return [rval]
...@@ -2660,7 +2662,7 @@ class Ceil(UnaryScalarOp): ...@@ -2660,7 +2662,7 @@ class Ceil(UnaryScalarOp):
(gz,) = gout (gz,) = gout
rval = x.zeros_like() rval = x.zeros_like()
if rval.type.dtype in discrete_types: if rval.type in discrete_types:
rval = rval.astype(config.floatX) rval = rval.astype(config.floatX)
return [rval] return [rval]
...@@ -2686,7 +2688,7 @@ class Floor(UnaryScalarOp): ...@@ -2686,7 +2688,7 @@ class Floor(UnaryScalarOp):
(gz,) = gout (gz,) = gout
rval = x.zeros_like() rval = x.zeros_like()
if rval.type.dtype in discrete_types: if rval.type in discrete_types:
rval = rval.astype(config.floatX) rval = rval.astype(config.floatX)
return [rval] return [rval]
...@@ -2740,7 +2742,7 @@ class RoundHalfToEven(UnaryScalarOp): ...@@ -2740,7 +2742,7 @@ class RoundHalfToEven(UnaryScalarOp):
(gz,) = gout (gz,) = gout
rval = x.zeros_like() rval = x.zeros_like()
if rval.type.dtype in discrete_types: if rval.type in discrete_types:
rval = rval.astype(config.floatX) rval = rval.astype(config.floatX)
return [rval] return [rval]
...@@ -2826,7 +2828,7 @@ class RoundHalfAwayFromZero(UnaryScalarOp): ...@@ -2826,7 +2828,7 @@ class RoundHalfAwayFromZero(UnaryScalarOp):
(gz,) = gout (gz,) = gout
rval = x.zeros_like() rval = x.zeros_like()
if rval.type.dtype in discrete_types: if rval.type in discrete_types:
rval = rval.astype(config.floatX) rval = rval.astype(config.floatX)
return [rval] return [rval]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论