提交 f1a426d9 authored 作者: James Bergstra's avatar James Bergstra

FIX: scalar.maximum and scalar.minimum casts in grad

上级 56888c31
...@@ -1085,9 +1085,9 @@ class Maximum(BinaryScalarOp): ...@@ -1085,9 +1085,9 @@ class Maximum(BinaryScalarOp):
# max is not defined for complex_types # max is not defined for complex_types
gx, gy = None, None gx, gy = None, None
if x.type in float_types: if x.type in float_types:
gx = eq(maximum(x, y), x) * gz gx = cast(eq(maximum(x, y), x) * gz, x.type.dtype)
if y.type in float_types: if y.type in float_types:
gy = eq(maximum(x, y), y) * gz gy = cast(eq(maximum(x, y), y) * gz, y.type.dtype)
return (gx, gy) return (gx, gy)
maximum = Maximum(upcast_out, name='maximum') maximum = Maximum(upcast_out, name='maximum')
...@@ -1110,9 +1110,9 @@ class Minimum(BinaryScalarOp): ...@@ -1110,9 +1110,9 @@ class Minimum(BinaryScalarOp):
# max is not defined for complex_types # max is not defined for complex_types
gx, gy = None, None gx, gy = None, None
if x.type in float_types: if x.type in float_types:
gx = eq(minimum(x, y), x) * gz gx = cast(eq(minimum(x, y), x) * gz, x.type.dtype)
if y.type in float_types: if y.type in float_types:
gy = eq(minimum(x, y), y) * gz gy = cast(eq(minimum(x, y), y) * gz, y.type.dtype)
return (gx, gy) return (gx, gy)
minimum = Minimum(upcast_out, name='minimum') minimum = Minimum(upcast_out, name='minimum')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论