提交 d92eb12b authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixes #128 (DestroyHandler bug)

上级 bf69ab87
......@@ -5,7 +5,7 @@ from type import Type
import graph
from graph import Result, Apply
from op import Op
from opt import PatternOptimizer, OpSubOptimizer
from opt import *
from ext import *
from env import Env, InconsistencyError
......@@ -13,6 +13,9 @@ from toolbox import ReplaceValidate
from copy import copy
PatternOptimizer = lambda p1, p2, ign=True: OpKeyOptimizer(PatternSub(p1, p2), ignore_newtrees=ign)
OpSubOptimizer = lambda op1, op2, fail=keep_going, ign=True: TopoOptimizer(OpSub(op1, op2), ignore_newtrees=ign, failure_callback = fail)
def as_result(x):
assert isinstance(x, Result)
......@@ -34,11 +37,13 @@ def MyResult(name):
class MyOp(Op):
def __init__(self, nin, name, vmap = {}, dmap = {}):
def __init__(self, nin, name, vmap = {}, dmap = {}, nout = 1, tolerate_same = []):
self.nin = nin
self.nout = nout
self.name = name
self.destroy_map = dmap
self.view_map = vmap
self.tolerate_same = tolerate_same
def make_node(self, *inputs):
assert len(inputs) == self.nin
......@@ -46,7 +51,7 @@ class MyOp(Op):
for input in inputs:
if not isinstance(input.type, MyType):
raise Exception("Error 1")
outputs = [MyResult(self.name + "_R")]
outputs = [MyResult(self.name + "_R") for i in xrange(self.nout)]
return Apply(self, inputs, outputs)
def __str__(self):
......@@ -57,6 +62,7 @@ sigmoid = MyOp(1, 'Sigmoid')
transpose_view = MyOp(1, 'TransposeView', vmap = {0: [0]})
add = MyOp(2, 'Add')
add_in_place = MyOp(2, 'AddInPlace', dmap = {0: [0]})
add_in_place_2 = MyOp(2, 'AddInPlace', dmap = {0: [0]}, tolerate_same = [(0, 1)])
dot = MyOp(2, 'Dot')
......@@ -81,8 +87,8 @@ class FailureWatch:
# when passed to OpSubOptimizer or PatternOptimizer, counts the number of failures
def __init__(self):
self.failures = 0
def __call__(self, op1, op2, exception):
assert isinstance(exception, InconsistencyError)
def __call__(self, exc, nav, pairs):
assert isinstance(exc, InconsistencyError)
self.failures += 1
......@@ -257,11 +263,46 @@ class _test_all(unittest.TestCase):
g.replace(e0, new_e0)
assert g.consistent()
# def test_aliased_inputs(self):
# x, y, z = inputs()
# e = add_in_place(x, transpose_view(x))
# g = Env([x], [e], False)
# assert not g.consistent()
def test_aliased_inputs(self):
x, y, z = inputs()
e = add_in_place(x, x)
g = Env([x], [e], False)
assert not g.consistent()
def test_aliased_inputs2(self):
x, y, z = inputs()
e = add_in_place(x, transpose_view(x))
g = Env([x], [e], False)
assert not g.consistent()
def test_aliased_inputs_tolerate(self):
x, y, z = inputs()
e = add_in_place_2(x, x)
g = Env([x], [e], False)
assert g.consistent()
def test_aliased_inputs_tolerate2(self):
x, y, z = inputs()
e = add_in_place_2(x, transpose_view(x))
g = Env([x], [e], False)
assert not g.consistent()
def test_aliased_inputs_replacement(self):
x, y, z = inputs()
tv = transpose_view(x)
tvv = transpose_view(tv)
sx = sigmoid(x)
e = add_in_place(x, tv)
g = Env([x, y], [e], False)
assert not g.consistent()
g.replace(tv, sx)
assert g.consistent()
g.replace(sx, tv)
assert not g.consistent()
g.replace(tv, tvv)
assert not g.consistent()
g.replace(tv, sx)
assert g.consistent()
if __name__ == '__main__':
......
......@@ -38,7 +38,6 @@ class DestroyHandler(toolbox.Bookkeeper):
class DestroyHandlerHelper(toolbox.Bookkeeper):
"""
This feature ensures that an env represents a consistent data flow
......@@ -168,7 +167,7 @@ class DestroyHandlerHelper(toolbox.Bookkeeper):
is returned.
"""
views = self.__views__(r)
rval = [] # set()
rval = list(r.owner.outputs) if r.owner else [] # set()
for view in views:
for node, i in view.clients: #self.env.clients(view):
if node != 'output':
......@@ -183,21 +182,37 @@ class DestroyHandlerHelper(toolbox.Bookkeeper):
rval = set()
if op is None:
return rval
keep_going = False
for input in op.inputs:
dmap = getattr(op.op, 'destroy_map', {})
dinputs = reduce(list.__add__, dmap.values(), [])
d_found = {}
nd_found = {}
for i, input in enumerate(op.inputs):
# Get the basic result the input is a view of.
foundation = self.__path__(input)[0]
path = self.__path__(input)
foundation = path[0]
destroyers = self.destroyers.get(foundation, set())
if destroyers:
keep_going = True
# Is this op destroying the foundation? If yes,
# all users of the foundation must be computed before
# we overwrite its contents.
if op in destroyers:
if op in destroyers and i in dinputs:
d_found[foundation] = i
users = self.__users__(foundation)
rval.update(users)
else:
nd_found[foundation] = i
rval.update(op.inputs) # obviously
rval.difference_update(op.outputs) # this op's outputs will always be in the users
intersection = set(d_found.keys()).intersection(set(nd_found.keys()))
if not intersection:
rval.difference_update(op.outputs) # this op's outputs will always be in the users
else:
allowed = getattr(op.op, 'tolerate_same', [])
for item in intersection:
i, j = d_found[item], nd_found[item]
pair = i, j
if not (op.inputs[i] is op.inputs[j] and (pair in allowed or tuple(reversed(pair)) in allowed)):
break
else:
rval.difference_update(op.outputs)
return rval
def __detect_cycles_helper__(self, r, seq):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论