提交 c378c628 authored 作者: James Bergstra's avatar James Bergstra

fixes bug in mul canonizer

上级 0fe44b63
......@@ -638,6 +638,7 @@ class Canonizer(gof.LocalOptimizer):
# Lists representing the numerator and denumerator
num, denum = list(orig_num), list(orig_denum)
out_type = self.merge_num_denum(orig_num, orig_denum).type
# Lists representing the *constant* elements of num and denum
numct, denumct = [], []
......@@ -660,14 +661,14 @@ class Canonizer(gof.LocalOptimizer):
# This will calculate either:
# [inverse(main(*numct), main(*denumct))]
# [] - if inverse(main(*numct), main(*denumct)) is the neutral element
ct = self.calculate(numct, denumct, aslist = True)
ct = self.calculate(numct, denumct, aslist = True, out_type=out_type)
else:
# This happens if we don't allow the reciprocal and the
# numerator is empty. That means we will need to represent
# reciprocal(x) like inverse(neutral_element, x) so
# we can't allow ct == []
# TODO: why is this branch needed when merge_num_denum does it for us?
ct = [self.calculate(numct, denumct, aslist = False)]
ct = [self.calculate(numct, denumct, aslist = False, out_type=out_type)]
# TODO: why are we not wrapping ct in a gof.Constant right now?
if orig_num and len(numct) == 1 and len(denumct) == 0 and ct and N.all(ct == self.get_constant(orig_num[0])):
......@@ -725,13 +726,17 @@ class Canonizer(gof.LocalOptimizer):
return getattr(self, 'name', 'Canonizer(%s, %s, %s)' % (self.main, self.inverse, self.reciprocal))
def mul_calculate(num, denum, aslist=False):
def mul_calculate(num, denum, aslist=False, out_type=None):
if not num and not denum:
# Smallest 1 possible.
return [] if aslist else N.int8(1)
# Make sure we do not accidently upcast data types.
first = num[0] if num else denum[0]
one = N.asarray(first).dtype.type(1)
if out_type is None:
# TODO: remove this error-causing heuristic
first = num[0] if num else denum[0]
one = N.asarray(first).dtype.type(1)
else:
one = N.asarray(1, dtype=out_type.dtype)
v = reduce(N.multiply, num, one) / reduce(N.multiply, denum, one)
if aslist:
if N.all(v == 1):
......@@ -898,8 +903,10 @@ mul_canonizer = in2out(gof.LocalOptGroup(local_mul_canonizer, local_fill_cut, lo
def add_calculate(num, denum, aslist = False):
v = reduce(N.add, num, 0.0) - reduce(N.add, denum, 0.0)
def add_calculate(num, denum, aslist = False, out_type=None):
#TODO: make sure that this function and mul_calculate are similar
zero = 0.0 if out_type is None else N.asarray(0, dtype=out_type.dtype)
v = reduce(N.add, num, zero) - reduce(N.add, denum, zero)
if aslist:
if N.all(v == 0):
return []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论