提交 d92eb12b authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixes #128 (DestroyHandler bug)

上级 bf69ab87
...@@ -5,7 +5,7 @@ from type import Type ...@@ -5,7 +5,7 @@ from type import Type
import graph import graph
from graph import Result, Apply from graph import Result, Apply
from op import Op from op import Op
from opt import PatternOptimizer, OpSubOptimizer from opt import *
from ext import * from ext import *
from env import Env, InconsistencyError from env import Env, InconsistencyError
...@@ -13,6 +13,9 @@ from toolbox import ReplaceValidate ...@@ -13,6 +13,9 @@ from toolbox import ReplaceValidate
from copy import copy from copy import copy
PatternOptimizer = lambda p1, p2, ign=True: OpKeyOptimizer(PatternSub(p1, p2), ignore_newtrees=ign)
OpSubOptimizer = lambda op1, op2, fail=keep_going, ign=True: TopoOptimizer(OpSub(op1, op2), ignore_newtrees=ign, failure_callback = fail)
def as_result(x): def as_result(x):
assert isinstance(x, Result) assert isinstance(x, Result)
...@@ -34,11 +37,13 @@ def MyResult(name): ...@@ -34,11 +37,13 @@ def MyResult(name):
class MyOp(Op): class MyOp(Op):
def __init__(self, nin, name, vmap = {}, dmap = {}): def __init__(self, nin, name, vmap = {}, dmap = {}, nout = 1, tolerate_same = []):
self.nin = nin self.nin = nin
self.nout = nout
self.name = name self.name = name
self.destroy_map = dmap self.destroy_map = dmap
self.view_map = vmap self.view_map = vmap
self.tolerate_same = tolerate_same
def make_node(self, *inputs): def make_node(self, *inputs):
assert len(inputs) == self.nin assert len(inputs) == self.nin
...@@ -46,7 +51,7 @@ class MyOp(Op): ...@@ -46,7 +51,7 @@ class MyOp(Op):
for input in inputs: for input in inputs:
if not isinstance(input.type, MyType): if not isinstance(input.type, MyType):
raise Exception("Error 1") raise Exception("Error 1")
outputs = [MyResult(self.name + "_R")] outputs = [MyResult(self.name + "_R") for i in xrange(self.nout)]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def __str__(self): def __str__(self):
...@@ -57,6 +62,7 @@ sigmoid = MyOp(1, 'Sigmoid') ...@@ -57,6 +62,7 @@ sigmoid = MyOp(1, 'Sigmoid')
transpose_view = MyOp(1, 'TransposeView', vmap = {0: [0]}) transpose_view = MyOp(1, 'TransposeView', vmap = {0: [0]})
add = MyOp(2, 'Add') add = MyOp(2, 'Add')
add_in_place = MyOp(2, 'AddInPlace', dmap = {0: [0]}) add_in_place = MyOp(2, 'AddInPlace', dmap = {0: [0]})
add_in_place_2 = MyOp(2, 'AddInPlace', dmap = {0: [0]}, tolerate_same = [(0, 1)])
dot = MyOp(2, 'Dot') dot = MyOp(2, 'Dot')
...@@ -81,8 +87,8 @@ class FailureWatch: ...@@ -81,8 +87,8 @@ class FailureWatch:
# when passed to OpSubOptimizer or PatternOptimizer, counts the number of failures # when passed to OpSubOptimizer or PatternOptimizer, counts the number of failures
def __init__(self): def __init__(self):
self.failures = 0 self.failures = 0
def __call__(self, op1, op2, exception): def __call__(self, exc, nav, pairs):
assert isinstance(exception, InconsistencyError) assert isinstance(exc, InconsistencyError)
self.failures += 1 self.failures += 1
...@@ -257,11 +263,46 @@ class _test_all(unittest.TestCase): ...@@ -257,11 +263,46 @@ class _test_all(unittest.TestCase):
g.replace(e0, new_e0) g.replace(e0, new_e0)
assert g.consistent() assert g.consistent()
# def test_aliased_inputs(self): def test_aliased_inputs(self):
# x, y, z = inputs() x, y, z = inputs()
# e = add_in_place(x, transpose_view(x)) e = add_in_place(x, x)
# g = Env([x], [e], False) g = Env([x], [e], False)
# assert not g.consistent() assert not g.consistent()
def test_aliased_inputs2(self):
x, y, z = inputs()
e = add_in_place(x, transpose_view(x))
g = Env([x], [e], False)
assert not g.consistent()
def test_aliased_inputs_tolerate(self):
x, y, z = inputs()
e = add_in_place_2(x, x)
g = Env([x], [e], False)
assert g.consistent()
def test_aliased_inputs_tolerate2(self):
x, y, z = inputs()
e = add_in_place_2(x, transpose_view(x))
g = Env([x], [e], False)
assert not g.consistent()
def test_aliased_inputs_replacement(self):
x, y, z = inputs()
tv = transpose_view(x)
tvv = transpose_view(tv)
sx = sigmoid(x)
e = add_in_place(x, tv)
g = Env([x, y], [e], False)
assert not g.consistent()
g.replace(tv, sx)
assert g.consistent()
g.replace(sx, tv)
assert not g.consistent()
g.replace(tv, tvv)
assert not g.consistent()
g.replace(tv, sx)
assert g.consistent()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -38,7 +38,6 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -38,7 +38,6 @@ class DestroyHandler(toolbox.Bookkeeper):
class DestroyHandlerHelper(toolbox.Bookkeeper): class DestroyHandlerHelper(toolbox.Bookkeeper):
""" """
This feature ensures that an env represents a consistent data flow This feature ensures that an env represents a consistent data flow
...@@ -168,7 +167,7 @@ class DestroyHandlerHelper(toolbox.Bookkeeper): ...@@ -168,7 +167,7 @@ class DestroyHandlerHelper(toolbox.Bookkeeper):
is returned. is returned.
""" """
views = self.__views__(r) views = self.__views__(r)
rval = [] # set() rval = list(r.owner.outputs) if r.owner else [] # set()
for view in views: for view in views:
for node, i in view.clients: #self.env.clients(view): for node, i in view.clients: #self.env.clients(view):
if node != 'output': if node != 'output':
...@@ -183,21 +182,37 @@ class DestroyHandlerHelper(toolbox.Bookkeeper): ...@@ -183,21 +182,37 @@ class DestroyHandlerHelper(toolbox.Bookkeeper):
rval = set() rval = set()
if op is None: if op is None:
return rval return rval
keep_going = False dmap = getattr(op.op, 'destroy_map', {})
for input in op.inputs: dinputs = reduce(list.__add__, dmap.values(), [])
d_found = {}
nd_found = {}
for i, input in enumerate(op.inputs):
# Get the basic result the input is a view of. # Get the basic result the input is a view of.
foundation = self.__path__(input)[0] path = self.__path__(input)
foundation = path[0]
destroyers = self.destroyers.get(foundation, set()) destroyers = self.destroyers.get(foundation, set())
if destroyers:
keep_going = True
# Is this op destroying the foundation? If yes, # Is this op destroying the foundation? If yes,
# all users of the foundation must be computed before # all users of the foundation must be computed before
# we overwrite its contents. # we overwrite its contents.
if op in destroyers: if op in destroyers and i in dinputs:
d_found[foundation] = i
users = self.__users__(foundation) users = self.__users__(foundation)
rval.update(users) rval.update(users)
else:
nd_found[foundation] = i
rval.update(op.inputs) # obviously rval.update(op.inputs) # obviously
intersection = set(d_found.keys()).intersection(set(nd_found.keys()))
if not intersection:
rval.difference_update(op.outputs) # this op's outputs will always be in the users rval.difference_update(op.outputs) # this op's outputs will always be in the users
else:
allowed = getattr(op.op, 'tolerate_same', [])
for item in intersection:
i, j = d_found[item], nd_found[item]
pair = i, j
if not (op.inputs[i] is op.inputs[j] and (pair in allowed or tuple(reversed(pair)) in allowed)):
break
else:
rval.difference_update(op.outputs)
return rval return rval
def __detect_cycles_helper__(self, r, seq): def __detect_cycles_helper__(self, r, seq):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论