提交 020ebbd5 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

comment about number of steps

上级 5f470abb
......@@ -1253,6 +1253,12 @@ class Scan(PureOp):
outs = self(*inputs)
if not isinstance(outs, (list, tuple)):
outs = [outs]
# `grad_step` equals the number of steps the original scan node has
# done (if the original scan is a while loop than this number is the
# length of the output sequence)
# We do not know what kind of outputs the original scan has, so we
# try first to see if it has a nit_sot output, then a sit_sot and
# then a mit_sot
if self.n_nit_sot > 0:
grad_steps = self.outer_nitsot_outs(outs)[0].shape[0]
elif self.n_sit_sot > 0:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论