提交 6b4e56d4 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Refactor conditions

上级 db0bc6ca
...@@ -5,8 +5,6 @@ from theano.gof import Container, Variable, generic, graph, Constant, Value ...@@ -5,8 +5,6 @@ from theano.gof import Container, Variable, generic, graph, Constant, Value
from theano.compile import orig_function, In, Out from theano.compile import orig_function, In, Out
from theano.compile.sharedvalue import SharedVariable, shared from theano.compile.sharedvalue import SharedVariable, shared
import numpy # for backport to 2.4, to get any(). import numpy # for backport to 2.4, to get any().
import theano
class Param(object): class Param(object):
def __init__(self, variable, default=None, name=None, mutable=False, strict=False, def __init__(self, variable, default=None, name=None, mutable=False, strict=False,
...@@ -119,8 +117,8 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -119,8 +117,8 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
assert v is not None assert v is not None
if v.owner: if v.owner:
clone_a(v.owner) clone_a(v.owner)
elif isinstance(v, SharedVariable): elif isinstance(v, SharedVariable) and v not in clone_d:
if v not in shared_inputs and v not in clone_d: if v not in shared_inputs:
shared_inputs.append(v) shared_inputs.append(v)
if hasattr(v, 'default_update'): if hasattr(v, 'default_update'):
...@@ -129,7 +127,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -129,7 +127,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
(isinstance(no_default_updates, list) and\ (isinstance(no_default_updates, list) and\
v not in no_default_updates): v not in no_default_updates):
# Do not use default_update if a "real" update was provided # Do not use default_update if a "real" update was provided
if v not in update_d and v not in clone_d: if v not in update_d:
v_update = v.filter_update(v.default_update) v_update = v.filter_update(v.default_update)
if v_update.type != v.type: if v_update.type != v.type:
raise TypeError('an update must have the same type as the original shared variable', raise TypeError('an update must have the same type as the original shared variable',
...@@ -156,7 +154,6 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -156,7 +154,6 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
except: except:
pass pass
for v_orig, v_repl in givens: for v_orig, v_repl in givens:
if not isinstance(v_orig, Variable): if not isinstance(v_orig, Variable):
raise TypeError('given keys must be Variable', v_orig) raise TypeError('given keys must be Variable', v_orig)
if not isinstance(v_repl, Variable): if not isinstance(v_repl, Variable):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论