提交 432ef124 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Fixed a bug in pfunc ( bug = when you replaced a shared variable with givens,…

Fixed a bug in pfunc ( bug = when you replaced a shared variable with givens, the shared variable did not got replaced, and the update rule was executed ) that resulted in fixing the failing tests of scan. I also did a bit of cleaining in scan tests and code, and fix an unobserved bug in inplace computation of scan plus made sure scan knows (once the optimization is written) to only store the last k steps of an output
上级 c87a33e6
...@@ -60,7 +60,7 @@ FancyModule = Module ...@@ -60,7 +60,7 @@ FancyModule = Module
from printing import \ from printing import \
pprint, pp pprint, pp
from scan import scan from scan import scan,map
import tensor import tensor
import scalar import scalar
......
...@@ -5,6 +5,8 @@ from theano.gof import Container, Variable, generic, graph, Constant, Value ...@@ -5,6 +5,8 @@ from theano.gof import Container, Variable, generic, graph, Constant, Value
from theano.compile import orig_function, In, Out from theano.compile import orig_function, In, Out
from theano.compile.sharedvalue import SharedVariable, shared from theano.compile.sharedvalue import SharedVariable, shared
import numpy # for backport to 2.4, to get any(). import numpy # for backport to 2.4, to get any().
import theano
class Param(object): class Param(object):
def __init__(self, variable, default=None, name=None, mutable=False, strict=False, def __init__(self, variable, default=None, name=None, mutable=False, strict=False,
...@@ -118,7 +120,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -118,7 +120,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
if v.owner: if v.owner:
clone_a(v.owner) clone_a(v.owner)
elif isinstance(v, SharedVariable): elif isinstance(v, SharedVariable):
if v not in shared_inputs: if v not in shared_inputs and v not in clone_d:
shared_inputs.append(v) shared_inputs.append(v)
if hasattr(v, 'default_update'): if hasattr(v, 'default_update'):
...@@ -127,14 +129,13 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -127,14 +129,13 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
(isinstance(no_default_updates, list) and\ (isinstance(no_default_updates, list) and\
v not in no_default_updates): v not in no_default_updates):
# Do not use default_update if a "real" update was provided # Do not use default_update if a "real" update was provided
if v not in update_d: if v not in update_d and v not in clone_d:
v_update = v.filter_update(v.default_update) v_update = v.filter_update(v.default_update)
if v_update.type != v.type: if v_update.type != v.type:
raise TypeError('an update must have the same type as the original shared variable', raise TypeError('an update must have the same type as the original shared variable',
(v, v.type, v_update, v_update.type)) (v, v.type, v_update, v_update.type))
update_d[v] = v_update update_d[v] = v_update
update_expr.append((v, v_update)) update_expr.append((v, v_update))
return clone_d.setdefault(v, v) return clone_d.setdefault(v, v)
def clone_a(a): def clone_a(a):
...@@ -155,6 +156,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -155,6 +156,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
except: except:
pass pass
for v_orig, v_repl in givens: for v_orig, v_repl in givens:
if not isinstance(v_orig, Variable): if not isinstance(v_orig, Variable):
raise TypeError('given keys must be Variable', v_orig) raise TypeError('given keys must be Variable', v_orig)
if not isinstance(v_repl, Variable): if not isinstance(v_repl, Variable):
...@@ -195,6 +197,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -195,6 +197,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
update_d[store_into] = update_val update_d[store_into] = update_val
update_expr.append((store_into, update_val)) update_expr.append((store_into, update_val))
# Elements of "outputs" are here cloned to "cloned_outputs" # Elements of "outputs" are here cloned to "cloned_outputs"
if isinstance(outputs, list): if isinstance(outputs, list):
cloned_outputs = [] cloned_outputs = []
...@@ -228,6 +231,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -228,6 +231,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
# If the variable to be updated is a shared variable not already # If the variable to be updated is a shared variable not already
# in shared_inputs, add it. # in shared_inputs, add it.
# Note: we extend update_expr while iterating over it. # Note: we extend update_expr while iterating over it.
i = 0 i = 0
while i<len(update_expr): while i<len(update_expr):
v, v_update = update_expr[i] v, v_update = update_expr[i]
......
...@@ -442,6 +442,25 @@ class Test_pfunc(unittest.TestCase): ...@@ -442,6 +442,25 @@ class Test_pfunc(unittest.TestCase):
# a is needed as input if y.default_update is used # a is needed as input if y.default_update is used
self.failUnlessRaises(TypeError, pfunc, [], x) self.failUnlessRaises(TypeError, pfunc, [], x)
def test_givens_replaces_shared_variable(self):
a = shared(1.,'a')
a.default_update = a+3.
b = tensor.scalar('b')
c = a + 10
f = pfunc([b],c, givens = {a:b})
assert len(f.maker.env.inputs) == 1
assert len(f.maker.env.outputs) == 1
def test_givens_replaces_shared_variable2(self):
a = shared(1.,'a')
a.default_update = a+3
c = a+ 10
f = pfunc([],c, givens = { a: a+10} )
assert f() == 21
assert f() == 34
if __name__ == '__main__': if __name__ == '__main__':
theano.config.mode = 'FAST_COMPILE' theano.config.mode = 'FAST_COMPILE'
Test_pfunc().test_default_scalar_container() Test_pfunc().test_default_scalar_container()
......
差异被折叠。
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论