提交 dc92c85b authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Remove OrderedDict from tests/test_gradient

上级 79ce5106
from collections import OrderedDict
import numpy as np import numpy as np
import pytest import pytest
...@@ -637,7 +635,7 @@ def test_known_grads(): ...@@ -637,7 +635,7 @@ def test_known_grads():
for layer in layers: for layer in layers:
first = grad(cost, layer, disconnected_inputs="ignore") first = grad(cost, layer, disconnected_inputs="ignore")
known = OrderedDict(zip(layer, first)) known = dict(zip(layer, first))
full = grad( full = grad(
cost=None, known_grads=known, wrt=inputs, disconnected_inputs="ignore" cost=None, known_grads=known, wrt=inputs, disconnected_inputs="ignore"
) )
...@@ -755,7 +753,7 @@ def test_subgraph_grad(): ...@@ -755,7 +753,7 @@ def test_subgraph_grad():
param_grad, next_grad = subgraph_grad( param_grad, next_grad = subgraph_grad(
wrt=params[i], end=grad_ends[i], start=next_grad, cost=costs[i] wrt=params[i], end=grad_ends[i], start=next_grad, cost=costs[i]
) )
next_grad = OrderedDict(zip(grad_ends[i], next_grad)) next_grad = dict(zip(grad_ends[i], next_grad))
param_grads.extend(param_grad) param_grads.extend(param_grad)
pgrads = pytensor.function(inputs, param_grads) pgrads = pytensor.function(inputs, param_grads)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论