提交 616b7f60 authored 作者: James Bergstra's avatar James Bergstra

bugfix in mul canonizer regarding int div

上级 9c17814c
......@@ -452,8 +452,15 @@ class Canonizer(gof.LocalOptimizer):
a / (b / c) -> ([a, c], [b])
log(x) -> ([log(x)], [])
x**y -> ([x**y], [])
"""
# This function is recursive.
# The idea is that there is a get_num_denum recursion in which the internal ops are all
# one of (main, inverse, reciprocal, DimShuffle) and the internal data nodes all have
# the dtype of the 'input' argument. The leaf-Variables of the graph covered by the
# recursion may be of any Variable type.
if input.owner is None or input.owner.op not in [self.main, self.inverse, self.reciprocal]:
if input.owner and isinstance(input.owner.op, T.DimShuffle):
# If input is a DimShuffle of some input which does something like this:
......@@ -489,8 +496,13 @@ class Canonizer(gof.LocalOptimizer):
denum = []
parent = input.owner
# We get the (num, denum) pairs for each input
pairs = [self.get_num_denum(input) for input in parent.inputs]
# We get the (num, denum) pairs for each input (input2) *to* input that has the same
# type as this one.
pairs = [self.get_num_denum(input2)
if (getattr(input2.type, 'dtype', None) == input.type.dtype)
else ([input2], [])
for input2 in parent.inputs ]
if parent.op == self.main:
# If we have main(x, y), numx, denumx, numy and denumy
......@@ -683,20 +695,6 @@ class Canonizer(gof.LocalOptimizer):
out = node.outputs[0]
assert len(node.outputs) == 1
# I'm not sure if this is actually needed but the following
# block of code puts into "reorg" whether or not we are going
# to change the structure of the graph. For example if we have
# inverse operating on an inverse, we can make it so that only
# one inverse is used, so we'll reorganize that.
iops = set(input.owner.op for input in inputs if input.owner)
reorg = False
if op == self.main:
reorg = len(iops.intersection([self.main, self.inverse, self.reciprocal])) != 0
elif op == self.inverse:
reorg = len(iops.intersection([self.inverse, self.reciprocal])) != 0
elif op == self.reciprocal:
reorg = len(iops.intersection([self.inverse, self.reciprocal])) != 0
# Here we make the canonical version of the graph around this node
# See the documentation of get_num_denum and simplify
orig_num, orig_denum = self.get_num_denum(node.outputs[0])
......@@ -706,9 +704,8 @@ class Canonizer(gof.LocalOptimizer):
def same(x, y):
return len(x) == len(y) and all(N.all(xe == ye) for xe, ye in zip(x, y))
if not reorg and same(orig_num, num) and same(orig_denum, denum):
if same(orig_num, num) and same(orig_denum, denum):
# We return False if there are no changes
# TODO: what's the purpose of reorg? isn't same() sufficient?
return False
new = self.merge_num_denum(num, denum)
......
......@@ -177,6 +177,14 @@ class test_canonize(unittest.TestCase):
mul_canonizer.optimize(g)
gof.TopoOptimizer(gof.LocalOptGroup(local_fill_cut, local_fill_lift), order = 'out_to_in').optimize(g)
print pprint(g.outputs[0])
def test_mixeddiv():
"""Test that int division is preserved"""
i = iscalar()
d = dscalar()
assert 0 == function([i,d], d*(i/(i+1)))(3, 1.0)
# def test_plusmin(self):
# x, y, z = inputs()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论