提交 177124ba authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Remove OrderedDict from scan/op

上级 c7a99b60
......@@ -46,7 +46,6 @@ relies on the following elements to work properly :
import dataclasses
import logging
import time
from collections import OrderedDict
from collections.abc import Callable, Iterable
from copy import copy
from itertools import chain, product
......@@ -2188,7 +2187,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# corresponding outer inputs that the Scan would use as input for
# any given iteration. For simplicity, we use iteration 0.
inner_ins_shapes = []
out_equivalent = OrderedDict()
out_equivalent = {}
# The two following blocks are commented as it cause in some
# cases extra scans in the graph. See gh-XXX for the
......@@ -2469,7 +2468,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if (x in diff_inputs)
and get_inp_idx(self_inputs.index(x)) in connected_inputs
]
gmp = OrderedDict()
gmp = {}
# Required in case there is a pair of variables X and Y, with X
# used to compute Y, for both of which there is an external
......@@ -2478,7 +2477,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# it will be the sum of the external gradient signal and the
# gradient obtained by propagating Y's external gradient signal
# to X.
known_grads = OrderedDict([(k.copy(), v) for (k, v) in known_grads.items()])
known_grads = {k.copy(): v for (k, v) in known_grads.items()}
grads = grad(
cost=None,
......@@ -2548,7 +2547,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
dC_dXt = safe_new(dC_douts[idx][0])
dC_dXts.append(dC_dXt)
known_grads = OrderedDict()
known_grads = {}
dc_dxts_idx = 0
for i in range(len(diff_outputs)):
if i < idx_nitsot_start or i >= idx_nitsot_end:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论