提交 6235e400 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Missing imports

上级 5e10c1f1
...@@ -49,7 +49,8 @@ import numpy ...@@ -49,7 +49,8 @@ import numpy
from theano.compile import SharedVariable, function from theano.compile import SharedVariable, function
from theano import compile from theano import compile
from theano import gof from theano import gof
from theano.tensor import opt from theano.tensor import opt, TensorVariable
from theano.tensor.sharedvar import TensorSharedVariable
from theano import tensor from theano import tensor
from theano import config from theano import config
from theano.updates import Updates from theano.updates import Updates
...@@ -435,7 +436,7 @@ def scan(fn, ...@@ -435,7 +436,7 @@ def scan(fn,
pos = len(lengths) pos = len(lengths)
for sv in shared_inputs: for sv in shared_inputs:
if sv in update_d: if sv in update_d:
if isinstance(sv, TensorType): if isinstance(sv, (TensorVariable, TensorSharedVariable)):
# We can treat it as a sit sot # We can treat it as a sit sot
nw_state = scan_utils.expand( nw_state = scan_utils.expand(
tensor.unbroadcast(tensor.shape_padleft(sv, 0), T)) tensor.unbroadcast(tensor.shape_padleft(sv, 0), T))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论