提交 973528ca authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Fixed mul_calculate to avoid accidentally upcasting data types

上级 b62b7699
...@@ -686,7 +686,7 @@ class Canonizer(gof.LocalOptimizer): ...@@ -686,7 +686,7 @@ class Canonizer(gof.LocalOptimizer):
op = node.op op = node.op
if op not in [self.main, self.inverse, self.reciprocal]: if op not in [self.main, self.inverse, self.reciprocal]:
return False return False
inputs = node.inputs inputs = node.inputs
out = node.outputs[0] out = node.outputs[0]
assert len(node.outputs) == 1 assert len(node.outputs) == 1
...@@ -725,8 +725,14 @@ class Canonizer(gof.LocalOptimizer): ...@@ -725,8 +725,14 @@ class Canonizer(gof.LocalOptimizer):
return getattr(self, 'name', 'Canonizer(%s, %s, %s)' % (self.main, self.inverse, self.reciprocal)) return getattr(self, 'name', 'Canonizer(%s, %s, %s)' % (self.main, self.inverse, self.reciprocal))
def mul_calculate(num, denum, aslist = False): def mul_calculate(num, denum, aslist=False):
v = reduce(N.multiply, num, 1.0) / reduce(N.multiply, denum, 1.0) if not num and not denum:
# Smallest 1 possible.
return [] if asList else N.int8(1)
# Make sure we do not accidently upcast data types.
first = num[0] if num else denum[0]
one = N.asarray(first).dtype.type(1)
v = reduce(N.multiply, num, one) / reduce(N.multiply, denum, one)
if aslist: if aslist:
if N.all(v == 1): if N.all(v == 1):
return [] return []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论