fixed unreachable input bug with flag to graph.results_and_orphans

上级 affde532
...@@ -59,6 +59,18 @@ class _test_orphans(unittest.TestCase): ...@@ -59,6 +59,18 @@ class _test_orphans(unittest.TestCase):
op2 = MyOp(op.outputs[0], r5) op2 = MyOp(op.outputs[0], r5)
assert orphans([r1, r2], op2.outputs) == set([r5]) assert orphans([r1, r2], op2.outputs) == set([r5])
def test_1(self):
r1, r2, r5 = MyResult(1), MyResult(2), MyResult(5)
op = MyOp(r1, r2)
op2 = MyOp(op.outputs[0], r5)
try:
ro = results_and_orphans([r1, r2, op2.outputs[0]], op.outputs, True)
self.fail()
except Exception, e:
if e[0] is results_and_orphans.E_unreached:
return
raise
class _test_as_string(unittest.TestCase): class _test_as_string(unittest.TestCase):
......
...@@ -120,9 +120,7 @@ class _test_PerformLinker(unittest.TestCase): ...@@ -120,9 +120,7 @@ class _test_PerformLinker(unittest.TestCase):
a,d = add(x,y), div(x,y) a,d = add(x,y), div(x,y)
e = mul(a,d) e = mul(a,d)
fn = perform_linker(env([x, a], [e])).make_function() fn = perform_linker(env([x, a], [e])).make_function()
#perform linker should have recognized that one input is a function of self.failUnless(fn(1.0,9.0) == 4.5)
#the other one, which makes no sense
self.fail('this graph should not have been compiled')
def test_skiphole(self): def test_skiphole(self):
x,y,z = inputs() x,y,z = inputs()
...@@ -132,6 +130,13 @@ class _test_PerformLinker(unittest.TestCase): ...@@ -132,6 +130,13 @@ class _test_PerformLinker(unittest.TestCase):
fn = perform_linker(env([x, y,r], [e])).make_function() fn = perform_linker(env([x, y,r], [e])).make_function()
self.failUnless(fn(1.0,2.0,4.5) == 7.5) self.failUnless(fn(1.0,2.0,4.5) == 7.5)
def test_disconnected_input_output(self):
x,y,z = inputs()
a = add(x,y)
fn = perform_linker(env([z], [a])).make_function()
self.failUnless(fn(1.0) == 3.0)
self.failUnless(fn(2.0) == 3.0)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -378,6 +378,7 @@ class Env(graph.Graph): ...@@ -378,6 +378,7 @@ class Env(graph.Graph):
listener.on_import(op) listener.on_import(op)
except AbstractFunctionError: except AbstractFunctionError:
pass pass
__import__.E_output = 'op output in Env.inputs'
def __prune_r__(self, results): def __prune_r__(self, results):
for result in set(results): for result in set(results):
......
...@@ -39,7 +39,7 @@ def inputs(o): ...@@ -39,7 +39,7 @@ def inputs(o):
return results return results
def results_and_orphans(i, o): def results_and_orphans(i, o, warn_unreachable_input=False):
""" """
i -> list of input Results i -> list of input Results
o -> list of output Results o -> list of output Results
...@@ -53,9 +53,11 @@ def results_and_orphans(i, o): ...@@ -53,9 +53,11 @@ def results_and_orphans(i, o):
results = set(o) results = set(o)
results.update(i) results.update(i)
incomplete_paths = [] incomplete_paths = []
reached = set()
def helper(r, path): def helper(r, path):
if r in i: if r in i:
reached.add(r)
results.update(path) results.update(path)
elif r.owner is None: elif r.owner is None:
incomplete_paths.append(path) incomplete_paths.append(path)
...@@ -74,7 +76,12 @@ def results_and_orphans(i, o): ...@@ -74,7 +76,12 @@ def results_and_orphans(i, o):
orphans.add(r) orphans.add(r)
break break
if warn_unreachable_input and len(i) != len(reached):
raise Exception(results_and_orphans.E_unreached)
return results, orphans return results, orphans
results_and_orphans.E_unreached = 'there were unreachable inputs'
def ops(i, o): def ops(i, o):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论