提交 202eed7f authored 作者: nouiz's avatar nouiz

Merge pull request #862 from lamblin/fix_cvm_updates

Fix a corner case in CVM updates and add tests. This fixes wrong results that could end up with the wrong shape.
...@@ -618,6 +618,42 @@ class Test_pfunc(unittest.TestCase): ...@@ -618,6 +618,42 @@ class Test_pfunc(unittest.TestCase):
self.assertRaises(theano.compile.UnusedInputError, self.assertRaises(theano.compile.UnusedInputError,
theano.function, [x, x, x], x) theano.function, [x, x, x], x)
def test_update_same(self):
    """Check shapes when a shared variable is its own update expression.

    Regression test: there was a bug in CVM, triggered when a shared
    variable was its own update expression.
    """
    scalar = shared(1., 'a')
    matrix = shared(numpy.ones((2, 3)), 'b')
    expected_shapes = ((scalar, ()), (matrix, (2, 3)))

    # The order of the variables is not determined, so each shared
    # variable takes a turn at being its own update. Both functions are
    # compiled before either one is called.
    compiled = [
        theano.function(
            [], [], updates={scalar: scalar, matrix: (2 * matrix)}),
        theano.function(
            [], [], updates={scalar: (scalar * 2), matrix: matrix}),
    ]

    for fn in compiled:
        fn()
        # Neither update may change the shape of the stored values.
        for var, shape in expected_shapes:
            assert var.get_value(borrow=True).shape == shape, \
                var.get_value()
def test_update_equiv(self):
    """Check shapes when an update simplifies to the variable itself.

    Like test_update_same, but the update expression is simplified until
    it is found to be equal to the original variable.
    """
    scalar = shared(1., 'a')
    matrix = shared(numpy.ones((2, 3)), 'b')

    def check_shapes():
        # Shapes must be unchanged after running an update function.
        scalar_value = scalar.get_value(borrow=True)
        assert scalar_value.shape == (), scalar.get_value()
        matrix_value = matrix.get_value(borrow=True)
        assert matrix_value.shape == (2, 3), matrix.get_value()

    # The order of the variables is not determined, so each shared
    # variable takes a turn at having the trivial update.
    matrix_updates = {scalar: scalar, matrix: (2 * matrix - matrix)}
    update_matrix = theano.function([], [], updates=matrix_updates)
    scalar_updates = {scalar: (scalar * 2 - scalar), matrix: matrix}
    update_scalar = theano.function([], [], updates=scalar_updates)

    update_matrix()
    check_shapes()
    update_scalar()
    check_shapes()
class Test_aliasing_rules(unittest.TestCase): class Test_aliasing_rules(unittest.TestCase):
""" """
......
...@@ -716,8 +716,7 @@ class VM_Linker(link.LocalLinker): ...@@ -716,8 +716,7 @@ class VM_Linker(link.LocalLinker):
update_storage = [] update_storage = []
update_in_from_out = {} update_in_from_out = {}
for (ivar, ovar) in updated_vars.items(): for (ivar, ovar) in updated_vars.items():
if ivar != ovar: update_in_from_out[vars_idx[ovar]] = vars_idx[ivar]
update_in_from_out[vars_idx[ovar]] = vars_idx[ivar]
for oidx in output_vars: for oidx in output_vars:
if oidx in update_in_from_out: if oidx in update_in_from_out:
update_storage.append(update_in_from_out[oidx]) update_storage.append(update_in_from_out[oidx])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论