提交 4eea9f79 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix typing issues in aesara.compile.builders

上级 f9618a62
"""Define new Ops from existing Ops""" """Define new Ops from existing Ops"""
from collections import OrderedDict from collections import OrderedDict
from functools import partial from functools import partial
from typing import List, Optional from typing import List, Optional, Sequence, cast
import aesara.tensor as at import aesara.tensor as at
from aesara.compile.function.pfunc import rebuild_collect_shared from aesara.compile.function.pfunc import rebuild_collect_shared
...@@ -349,7 +349,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -349,7 +349,7 @@ class OpFromGraph(Op, HasInnerGraph):
shared_vars = [var.type() for var in self.shared_inputs] shared_vars = [var.type() for var in self.shared_inputs]
new = rebuild_collect_shared( new = rebuild_collect_shared(
outputs, cast(Sequence[Variable], outputs),
inputs=inputs + shared_vars, inputs=inputs + shared_vars,
replace=dict(zip(self.shared_inputs, shared_vars)), replace=dict(zip(self.shared_inputs, shared_vars)),
copy_inputs_over=False, copy_inputs_over=False,
...@@ -357,7 +357,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -357,7 +357,7 @@ class OpFromGraph(Op, HasInnerGraph):
( (
local_inputs, local_inputs,
local_outputs, local_outputs,
[clone_d, update_d, update_expr, shared_inputs], (clone_d, update_d, update_expr, shared_inputs),
) = new ) = new
assert len(local_inputs) == len(inputs) + len(self.shared_inputs) assert len(local_inputs) == len(inputs) + len(self.shared_inputs)
assert len(local_outputs) == len(outputs) assert len(local_outputs) == len(outputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论