提交 0f5a06d7 authored 作者: Frederic Bastien's avatar Frederic Bastien

Use Scan.L_op instead of Scan.grad() to help speed up the second derivative

上级 7c22fa2e
...@@ -1931,8 +1931,7 @@ class Scan(PureOp): ...@@ -1931,8 +1931,7 @@ class Scan(PureOp):
return mappings return mappings
# GRAD FUNCTION # GRAD FUNCTION
def grad(self, inputs, dC_douts): def L_op(self, inputs, outs, dC_douts):
outs = self(*inputs)
if not isinstance(outs, (list, tuple)): if not isinstance(outs, (list, tuple)):
outs = [outs] outs = [outs]
# `grad_step` equals the number of steps the original scan node has # `grad_step` equals the number of steps the original scan node has
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论