提交 59f40d38 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

add condition to the list of outputs

上级 19877f2e
...@@ -409,8 +409,13 @@ def scan(fn, ...@@ -409,8 +409,13 @@ def scan(fn,
# 5.1 Construct list of shared variables with updates (those that # 5.1 Construct list of shared variables with updates (those that
# can be treated as states (i.e. of TensorType) and those that can not # can be treated as states (i.e. of TensorType) and those that can not
# (like Random States) # (like Random States)
if cond is not None:
_cond = [cond]
else:
_cond = []
rvals = rebuild_collect_shared( rvals = rebuild_collect_shared(
states_and_outputs + [cond], states_and_outputs + _cond,
updates=updates, updates=updates,
rebuild_strict=True, rebuild_strict=True,
copy_inputs_over=True, copy_inputs_over=True,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论