提交 e8c2782f authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Remove unnecessary type checks in `local_exp_log*` rewrites

上级 4ab08dea
......@@ -285,7 +285,7 @@ def local_exp_log(fgraph, node):
if isinstance(prev_op, aes.Exp) and isinstance(node_op, aes.Log):
new_out = x.owner.inputs[0]
old_out = node.outputs[0]
# Exp may have casted integer input to float
# Exp may have cast integer input to float
if new_out.dtype != old_out.dtype:
new_out = cast(new_out, old_out.dtype)
return [new_out]
......@@ -293,11 +293,7 @@ def local_exp_log(fgraph, node):
# Case for exp(softplus(x)) aka exp(log1pexp)
if isinstance(prev_op, aes_math.Softplus) and isinstance(node_op, aes.Exp):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = add(1, exp(x))
if not old_out.type.is_super(new_out.type):
return
return [new_out]
return [add(1, exp(x))]
@register_specialize
......@@ -319,8 +315,6 @@ def local_exp_log_nan_switch(fgraph, node):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(ge(x, 0), x, np.asarray(np.nan, old_out.dtype))
if not old_out.type.is_super(new_out.type):
return
return [new_out]
# Case for exp(log1p(x))
......@@ -328,8 +322,6 @@ def local_exp_log_nan_switch(fgraph, node):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(ge(x, -1), add(1, x), np.asarray(np.nan, old_out.dtype))
if not old_out.type.is_super(new_out.type):
return
return [new_out]
# Case for exp(log1mexp(x))
......@@ -337,8 +329,6 @@ def local_exp_log_nan_switch(fgraph, node):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(le(x, 0), sub(1, exp(x)), np.asarray(np.nan, old_out.dtype))
if not old_out.type.is_super(new_out.type):
return
return [new_out]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论