提交 dc84f81b authored 作者: James Bergstra's avatar James Bergstra

local_mul_to_neg messed up dtype when the -1 was also forcing an upcast

上级 dec15a5f
...@@ -2008,9 +2008,12 @@ def local_sum_alloc(node): ...@@ -2008,9 +2008,12 @@ def local_sum_alloc(node):
@gof.local_optimizer([T.mul]) @gof.local_optimizer([T.mul])
def local_mul_to_neg(node): def local_mul_to_neg(node):
if node.op == T.mul and N.all(local_mul_canonizer.get_constant(node.inputs[0]) == -1.0): if node.op == T.mul and N.all(local_mul_canonizer.get_constant(node.inputs[0]) == -1.0):
return [-local_mul_canonizer.merge_num_denum(node.inputs[1:], [])] other_prod = local_mul_canonizer.merge_num_denum(node.inputs[1:], [])
else: if other_prod.type == node.outputs[0].type:
return False return [-other_prod]
# else the multiplication is also acting as a cast, so we might as well leave it alone.
# I don't think it's better to turn this into a negation in the wrong type, followed by
# an explicit cast.
register_specialize(local_mul_to_neg) register_specialize(local_mul_to_neg)
@register_specialize @register_specialize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论