提交 f7774000 authored 作者: James Bergstra's avatar James Bergstra

fixed duplicate-update bug in pfunc

上级 c23320a7
...@@ -157,6 +157,9 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], accept_inplace ...@@ -157,6 +157,9 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], accept_inplace
for (store_into, update_val) in iter_over_pairs(updates): for (store_into, update_val) in iter_over_pairs(updates):
if not isinstance(store_into, SharedVariable): if not isinstance(store_into, SharedVariable):
raise TypeError('update target must be a SharedVariable', store_into) raise TypeError('update target must be a SharedVariable', store_into)
if store_into in new_updates:
raise ValueError('this shared variable already has an update expression',
(store_into, new_updates[store_into]))
update_val = v_clone(store_into.filter_update(update_val)) update_val = v_clone(store_into.filter_update(update_val))
if update_val.type != store_into.type: if update_val.type != store_into.type:
raise TypeError('an update must have the same type as the original shared variable', raise TypeError('an update must have the same type as the original shared variable',
......
...@@ -2,7 +2,7 @@ import numpy ...@@ -2,7 +2,7 @@ import numpy
import unittest import unittest
import copy import copy
import theano import theano
from theano.tensor import Tensor, dmatrix, dvector, lscalar from theano.tensor import Tensor, dmatrix, dvector, lscalar, dmatrices
from theano import tensor from theano import tensor
from theano.compile.sharedvalue import * from theano.compile.sharedvalue import *
...@@ -193,6 +193,11 @@ class Test_pfunc(unittest.TestCase): ...@@ -193,6 +193,11 @@ class Test_pfunc(unittest.TestCase):
inc_by_y() inc_by_y()
self.failUnless(x.value == 1) self.failUnless(x.value == 1)
def test_duplicate_updates(self):
x, y = dmatrices('x', 'y')
z = shared(numpy.ones((2,3)))
self.failUnlessRaises(ValueError, theano.function, [x,y], [z], updates=[(z, z+x+y), (z, z-x)])
def test_givens(self): def test_givens(self):
x = shared(0) x = shared(0)
assign = pfunc([], x, givens = {x: 3}) assign = pfunc([], x, givens = {x: 3})
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论