提交 f7774000 authored 作者: James Bergstra's avatar James Bergstra

fixed duplicate-update bug in pfunc

上级 c23320a7
......@@ -157,6 +157,9 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], accept_inplace
for (store_into, update_val) in iter_over_pairs(updates):
if not isinstance(store_into, SharedVariable):
raise TypeError('update target must be a SharedVariable', store_into)
if store_into in new_updates:
raise ValueError('this shared variable already has an update expression',
(store_into, new_updates[store_into]))
update_val = v_clone(store_into.filter_update(update_val))
if update_val.type != store_into.type:
raise TypeError('an update must have the same type as the original shared variable',
......
......@@ -2,7 +2,7 @@ import numpy
import unittest
import copy
import theano
from theano.tensor import Tensor, dmatrix, dvector, lscalar
from theano.tensor import Tensor, dmatrix, dvector, lscalar, dmatrices
from theano import tensor
from theano.compile.sharedvalue import *
......@@ -193,6 +193,11 @@ class Test_pfunc(unittest.TestCase):
inc_by_y()
self.failUnless(x.value == 1)
def test_duplicate_updates(self):
x, y = dmatrices('x', 'y')
z = shared(numpy.ones((2,3)))
self.failUnlessRaises(ValueError, theano.function, [x,y], [z], updates=[(z, z+x+y), (z, z-x)])
def test_givens(self):
x = shared(0)
assign = pfunc([], x, givens = {x: 3})
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论