fixed bug in PatternOptimizer and added/documented tests

上级 9d425649
......@@ -64,14 +64,13 @@ def inputs():
return x, y, z
def env(inputs, outputs, validate = True):
# inputs = [input.r for input in inputs]
# outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = [EquivTool], consistency_check = validate)
class _test_PatternOptimizer(unittest.TestCase):
def test_0(self):
def test_replace_output(self):
# replacing the whole graph
x, y, z = inputs()
e = op1(op2(x, y), z)
g = env([x, y, z], [e])
......@@ -79,15 +78,34 @@ class _test_PatternOptimizer(unittest.TestCase):
(Op4, '3', '2')).optimize(g)
assert str(g) == "[Op4(z, y)]"
def test_1(self):
def test_nested_out_pattern(self):
x, y, z = inputs()
e = op1(op2(x, y), z)
e = op1(x, y)
g = env([x, y, z], [e])
PatternOptimizer((Op1, '1', '2'),
(Op4, (Op1, '1'), (Op2, '2'), (Op3, '1', '2'))).optimize(g)
assert str(g) == "[Op4(Op1(x), Op2(y), Op3(x, y))]"
def test_unification_1(self):
x, y, z = inputs()
e = op1(op2(x, x), z) # the arguments to op2 are the same
g = env([x, y, z], [e])
PatternOptimizer((Op1, (Op2, '1', '1'), '2'),
PatternOptimizer((Op1, (Op2, '1', '1'), '2'), # they are the same in the pattern
(Op4, '2', '1')).optimize(g)
assert str(g) != "[Op4(z, y)]"
# So the replacement should occur
assert str(g) == "[Op4(z, x)]"
def test_2(self):
def test_unification_2(self):
x, y, z = inputs()
e = op1(op2(x, y), z) # the arguments to op2 are different
g = env([x, y, z], [e])
PatternOptimizer((Op1, (Op2, '1', '1'), '2'), # they are the same in the pattern
(Op4, '2', '1')).optimize(g)
# The replacement should NOT occur
assert str(g) == "[Op1(Op2(x, y), z)]"
def test_replace_subgraph(self):
# replacing inside the graph
x, y, z = inputs()
e = op1(op2(x, y), z)
g = env([x, y, z], [e])
......@@ -95,7 +113,18 @@ class _test_PatternOptimizer(unittest.TestCase):
(Op1, '2', '1')).optimize(g)
assert str(g) == "[Op1(Op1(y, x), z)]"
def test_3(self):
def test_no_recurse(self):
# if the out pattern is an acceptable in pattern,
# it should do the replacement and stop
x, y, z = inputs()
e = op1(op2(x, y), z)
g = env([x, y, z], [e])
PatternOptimizer((Op2, '1', '2'),
(Op2, '2', '1')).optimize(g)
assert str(g) == "[Op1(Op2(y, x), z)]"
def test_multiple(self):
# it should replace all occurrences of the pattern
x, y, z = inputs()
e = op1(op2(x, y), op2(x, y), op2(y, z))
g = env([x, y, z], [e])
......@@ -103,7 +132,9 @@ class _test_PatternOptimizer(unittest.TestCase):
(Op4, '1')).optimize(g)
assert str(g) == "[Op1(Op4(x), Op4(x), Op4(y))]"
def test_4(self):
def test_nested_even(self):
# regardless of the order in which we optimize, this
# should work
x, y, z = inputs()
e = op1(op1(op1(op1(x))))
g = env([x, y, z], [e])
......@@ -111,7 +142,7 @@ class _test_PatternOptimizer(unittest.TestCase):
'1').optimize(g)
assert str(g) == "[x]"
def test_5(self):
def test_nested_odd(self):
x, y, z = inputs()
e = op1(op1(op1(op1(op1(x)))))
g = env([x, y, z], [e])
......@@ -119,7 +150,27 @@ class _test_PatternOptimizer(unittest.TestCase):
'1').optimize(g)
assert str(g) == "[Op1(x)]"
def test_6(self):
def test_expand(self):
x, y, z = inputs()
e = op1(op1(op1(x)))
g = env([x, y, z], [e])
PatternOptimizer((Op1, '1'),
(Op2, (Op1, '1'))).optimize(g)
assert str(g) == "[Op2(Op1(Op2(Op1(Op2(Op1(x))))))]"
def test_ambiguous(self):
# this test is known to fail most of the time
# the reason is that PatternOptimizer doesn't go through
# the ops in topological order. The order is random and
# it does not visit ops that it creates.
x, y, z = inputs()
e = op1(op1(op1(op1(op1(x)))))
g = env([x, y, z], [e])
PatternOptimizer((Op1, (Op1, '1')),
(Op1, '1')).optimize(g)
assert str(g) == "[Op1(x)]"
def test_constant_unification(self):
x, y, z = inputs()
x.constant = True
x.value = 2
......@@ -134,14 +185,14 @@ class _test_PatternOptimizer(unittest.TestCase):
class _test_OpSubOptimizer(unittest.TestCase):
def test_0(self):
def test_straightforward(self):
x, y, z = inputs()
e = op1(op1(op1(op1(op1(x)))))
g = env([x, y, z], [e])
OpSubOptimizer(Op1, Op2).optimize(g)
assert str(g) == "[Op2(Op2(Op2(Op2(Op2(x)))))]"
def test_1(self):
def test_straightforward_2(self):
x, y, z = inputs()
e = op1(op2(x), op3(y), op4(z))
g = env([x, y, z], [e])
......@@ -151,14 +202,14 @@ class _test_OpSubOptimizer(unittest.TestCase):
class _test_MergeOptimizer(unittest.TestCase):
def test_0(self):
def test_straightforward(self):
x, y, z = inputs()
e = op1(op2(x, y), op2(x, y), op2(x, z))
g = env([x, y, z], [e])
MergeOptimizer().optimize(g)
assert str(g) == "[Op1(*1 -> Op2(x, y), *1, Op2(x, z))]"
def test_1(self):
def test_constant_merging(self):
x, y, z = inputs()
y.data = 2
y.constant = True
......@@ -167,8 +218,44 @@ class _test_MergeOptimizer(unittest.TestCase):
e = op1(op2(x, y), op2(x, y), op2(x, z))
g = env([x, y, z], [e])
MergeOptimizer().optimize(g)
assert str(g) == "[Op1(*1 -> Op2(x, y), *1, *1)]" \
or str(g) == "[Op1(*1 -> Op2(x, z), *1, *1)]"
strg = str(g)
assert strg == "[Op1(*1 -> Op2(x, y), *1, *1)]" \
or strg == "[Op1(*1 -> Op2(x, z), *1, *1)]"
def test_deep_merge(self):
x, y, z = inputs()
e = op1(op3(op2(x, y), z), op4(op3(op2(x, y), z)))
g = env([x, y, z], [e])
MergeOptimizer().optimize(g)
assert str(g) == "[Op1(*1 -> Op3(Op2(x, y), z), Op4(*1))]"
def test_no_merge(self):
x, y, z = inputs()
e = op1(op3(op2(x, y)), op3(op2(y, x)))
g = env([x, y, z], [e])
MergeOptimizer().optimize(g)
assert str(g) == "[Op1(Op3(Op2(x, y)), Op3(Op2(y, x)))]"
def test_merge_outputs(self):
x, y, z = inputs()
e1 = op3(op2(x, y))
e2 = op3(op2(x, y))
g = env([x, y, z], [e1, e2])
MergeOptimizer().optimize(g)
assert str(g) == "[*1 -> Op3(Op2(x, y)), *1]"
def test_multiple_merges(self):
x, y, z = inputs()
e1 = op1(x, y)
e2 = op2(op3(x), y, z)
e = op1(e1, op4(e2, e1), op1(e2))
g = env([x, y, z], [e])
MergeOptimizer().optimize(g)
strg = str(g)
# note: graph.as_string can only produce the following two possibilities, but if
# the implementation was to change there are 6 other acceptable answers.
assert strg == "[Op1(*1 -> Op1(x, y), Op4(*2 -> Op2(Op3(x), y, z), *1), Op1(*2))]" \
or strg == "[Op1(*2 -> Op1(x, y), Op4(*1 -> Op2(Op3(x), y, z), *2), Op1(*1))]"
class _test_ConstantFinder(unittest.TestCase):
......
......@@ -261,7 +261,8 @@ class PatternOptimizer(OpSpecificOptimizer):
def build(pattern, u):
if isinstance(pattern, (list, tuple)):
return pattern[0](*[build(p, u) for p in pattern[1:]])
args = [build(p, u) for p in pattern[1:]]
return pattern[0](*args).out
elif isinstance(pattern, str):
return u[unify.Var(pattern)]
else:
......@@ -272,9 +273,8 @@ class PatternOptimizer(OpSpecificOptimizer):
try:
# note: only replaces the default 'out' port if it exists
p = self.out_pattern
new = 'unassigned'
new = build(p, u)
if not isinstance(p, str):
new = new.out
env.replace(op.out, new)
except Exception, e:
if self.failure_callback is not None:
......@@ -349,12 +349,11 @@ class MergeOptimizer(Optimizer):
cid[op] = op_cid
inv_cid[op_cid] = op
for i, output in enumerate(op.outputs):
ref = (i, op_cid)
ref = id(output) # (i, op_cid)
cid[output] = ref
inv_cid[ref] = output
else:
for output, other_output in zip(op.outputs, dup.outputs):
#print "replacing: %s %s" % (repr(output.owner), repr(other_output.owner))
env.replace(output, other_output)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论