提交 ad87e996 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed small bug in Env

上级 6de8c3cc
...@@ -34,6 +34,9 @@ class MyType(Type): ...@@ -34,6 +34,9 @@ class MyType(Type):
def MyResult(name): def MyResult(name):
return Result(MyType(), None, None, name = name) return Result(MyType(), None, None, name = name)
def MyValue(data):
return graph.Value(MyType(), data = data)
class MyOp(Op): class MyOp(Op):
...@@ -304,6 +307,33 @@ class _test_all(unittest.TestCase): ...@@ -304,6 +307,33 @@ class _test_all(unittest.TestCase):
g.replace(tv, sx) g.replace(tv, sx)
assert g.consistent() assert g.consistent()
def test_value_repl(self):
x, y, z = inputs()
sy = sigmoid(y)
e = add_in_place(x, sy)
g = Env([x,y], [e], False)
assert g.consistent()
g.replace(sy, MyValue("abc"))
assert g.consistent()
def test_value_repl_2(self):
x, y, z = inputs()
sy = sigmoid(y)
e = add_in_place(x, sy)
g = Env([x,y], [e], False)
assert g.consistent()
g.replace(sy, transpose_view(MyValue("abc")))
assert g.consistent()
def test_misc_2(self):
x, y, z = inputs()
tv = transpose_view(x)
e = add_in_place(x, tv)
g = Env([x,y], [e], False)
assert not g.consistent()
g.replace(tv, x)
assert not g.consistent()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
...@@ -120,6 +120,8 @@ class Env(utils.object2): ...@@ -120,6 +120,8 @@ class Env(utils.object2):
for r in results: for r in results:
if r.owner is None and not isinstance(r, graph.Value) and r not in self.inputs: if r.owner is None and not isinstance(r, graph.Value) and r not in self.inputs:
raise TypeError("Undeclared input", r) raise TypeError("Undeclared input", r)
if not getattr(r, 'env', None) is self:
self.__setup_r__(r)
self.results.add(r) self.results.add(r)
def __import__(self, node, check = True): def __import__(self, node, check = True):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论