提交 9ac18dc1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add missing properties to copied Function objects

上级 104dc037
...@@ -785,6 +785,8 @@ class Function: ...@@ -785,6 +785,8 @@ class Function:
f_cpy.finder[swap[in_ori.variable]] = container f_cpy.finder[swap[in_ori.variable]] = container
in_cpy.variable = swap[in_ori.variable] in_cpy.variable = swap[in_ori.variable]
f_cpy.trust_input = self.trust_input
f_cpy.unpack_single = self.unpack_single
f_cpy.name = name f_cpy.name = name
f_cpy.maker.fgraph.name = name f_cpy.maker.fgraph.name = name
return f_cpy return f_cpy
......
...@@ -299,7 +299,7 @@ class TestFunction: ...@@ -299,7 +299,7 @@ class TestFunction:
t() t()
def test_copy(self): def test_copy(self):
a = scalar() # the a is for 'anonymous' (un-named). a = scalar()
x, s = scalars("xs") x, s = scalars("xs")
f = function( f = function(
...@@ -312,26 +312,34 @@ class TestFunction: ...@@ -312,26 +312,34 @@ class TestFunction:
) )
g = copy.copy(f) g = copy.copy(f)
# if they both return, assume that they return equivalent things.
assert f.unpack_single == g.unpack_single
assert f.trust_input == g.trust_input
assert g.container[x].storage is not f.container[x].storage assert g.container[x].storage is not f.container[x].storage
assert g.container[a].storage is not f.container[a].storage assert g.container[a].storage is not f.container[a].storage
assert g.container[s].storage is not f.container[s].storage assert g.container[s].storage is not f.container[s].storage
assert g.value[a] is f.value[a] # should not have been copied # Should not have been copied
assert ( assert g.value[a] is f.value[a]
g.value[s] is not f.value[s]
) # should have been copied because it is mutable.
assert not (g.value[s] != f.value[s]).any() # its contents should be identical
assert f(2, 1) == g( # Should have been copied because it is mutable
2 assert g.value[s] is not f.value[s]
) # they should be in sync, default value should be copied.
assert f(2, 1) == g( # Their contents should be equal, though
2 assert np.array_equal(g.value[s], f.value[s])
) # they should be in sync, default value should be copied.
f(1, 2) # put them out of sync # They should be in sync, default value should be copied
assert f(1, 2) != g(1, 2) # they should not be equal anymore. assert np.array_equal(f(2, 1), g(2))
# They should be in sync, default value should be copied
assert np.array_equal(f(2, 1), g(2))
# Put them out of sync
f(1, 2)
# They should not be equal anymore
assert not np.array_equal(f(1, 2), g(1, 2))
def test_copy_share_memory(self): def test_copy_share_memory(self):
x = fscalar("x") x = fscalar("x")
...@@ -478,9 +486,9 @@ class TestFunction: ...@@ -478,9 +486,9 @@ class TestFunction:
ori = function([x], out, mode=mode, updates={z: z * 2}) ori = function([x], out, mode=mode, updates={z: z * 2})
cpy = ori.copy(delete_updates=True) cpy = ori.copy(delete_updates=True)
assert cpy(1)[0] == 4 assert cpy(1) == 4
assert cpy(1)[0] == 4 assert cpy(1) == 4
assert cpy(1)[0] == 4 assert cpy(1) == 4
# Test if unused implicit and explicit inputs from delete_updates # Test if unused implicit and explicit inputs from delete_updates
# are ignored as intended. # are ignored as intended.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论