提交 6b4e56d4 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Refactor conditions

上级 db0bc6ca
......@@ -5,8 +5,6 @@ from theano.gof import Container, Variable, generic, graph, Constant, Value
from theano.compile import orig_function, In, Out
from theano.compile.sharedvalue import SharedVariable, shared
import numpy # for backport to 2.4, to get any().
import theano
class Param(object):
def __init__(self, variable, default=None, name=None, mutable=False, strict=False,
......@@ -119,8 +117,8 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
assert v is not None
if v.owner:
clone_a(v.owner)
elif isinstance(v, SharedVariable):
if v not in shared_inputs and v not in clone_d:
elif isinstance(v, SharedVariable) and v not in clone_d:
if v not in shared_inputs:
shared_inputs.append(v)
if hasattr(v, 'default_update'):
......@@ -129,7 +127,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
(isinstance(no_default_updates, list) and\
v not in no_default_updates):
# Do not use default_update if a "real" update was provided
if v not in update_d and v not in clone_d:
if v not in update_d:
v_update = v.filter_update(v.default_update)
if v_update.type != v.type:
raise TypeError('an update must have the same type as the original shared variable',
......@@ -156,7 +154,6 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
except:
pass
for v_orig, v_repl in givens:
if not isinstance(v_orig, Variable):
raise TypeError('given keys must be Variable', v_orig)
if not isinstance(v_repl, Variable):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论