提交 9ac18dc1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add missing properties to copied Function objects

上级 104dc037
......@@ -785,6 +785,8 @@ class Function:
f_cpy.finder[swap[in_ori.variable]] = container
in_cpy.variable = swap[in_ori.variable]
f_cpy.trust_input = self.trust_input
f_cpy.unpack_single = self.unpack_single
f_cpy.name = name
f_cpy.maker.fgraph.name = name
return f_cpy
......
......@@ -299,7 +299,7 @@ class TestFunction:
t()
def test_copy(self):
a = scalar() # the a is for 'anonymous' (un-named).
a = scalar()
x, s = scalars("xs")
f = function(
......@@ -312,26 +312,34 @@ class TestFunction:
)
g = copy.copy(f)
# if they both return, assume that they return equivalent things.
assert f.unpack_single == g.unpack_single
assert f.trust_input == g.trust_input
assert g.container[x].storage is not f.container[x].storage
assert g.container[a].storage is not f.container[a].storage
assert g.container[s].storage is not f.container[s].storage
assert g.value[a] is f.value[a] # should not have been copied
assert (
g.value[s] is not f.value[s]
) # should have been copied because it is mutable.
assert not (g.value[s] != f.value[s]).any() # its contents should be identical
# Should not have been copied
assert g.value[a] is f.value[a]
assert f(2, 1) == g(
2
) # they should be in sync, default value should be copied.
assert f(2, 1) == g(
2
) # they should be in sync, default value should be copied.
f(1, 2) # put them out of sync
assert f(1, 2) != g(1, 2) # they should not be equal anymore.
# Should have been copied because it is mutable
assert g.value[s] is not f.value[s]
# Their contents should be equal, though
assert np.array_equal(g.value[s], f.value[s])
# They should be in sync, default value should be copied
assert np.array_equal(f(2, 1), g(2))
# They should be in sync, default value should be copied
assert np.array_equal(f(2, 1), g(2))
# Put them out of sync
f(1, 2)
# They should not be equal anymore
assert not np.array_equal(f(1, 2), g(1, 2))
def test_copy_share_memory(self):
x = fscalar("x")
......@@ -478,9 +486,9 @@ class TestFunction:
ori = function([x], out, mode=mode, updates={z: z * 2})
cpy = ori.copy(delete_updates=True)
assert cpy(1)[0] == 4
assert cpy(1)[0] == 4
assert cpy(1)[0] == 4
assert cpy(1) == 4
assert cpy(1) == 4
assert cpy(1) == 4
# Test if unused implicit and explicit inputs from delete_updates
# are ignored as intended.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论