提交 de9ad202 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Parameterize Apply type by Op

上级 a1d2fffd
......@@ -11,6 +11,7 @@ from typing import (
Deque,
Dict,
Generator,
Generic,
Hashable,
Iterable,
Iterator,
......@@ -45,6 +46,8 @@ if TYPE_CHECKING:
from aesara.graph.type import Type
OpType = TypeVar("OpType", bound="Op")
T = TypeVar("T", bound="Node")
NoParams = object()
NodeAndChildren = Tuple[T, Optional[Iterable[T]]]
......@@ -71,7 +74,7 @@ class Node(MetaObject):
raise NotImplementedError()
class Apply(Node):
class Apply(Node, Generic[OpType]):
"""A `Node` representing the application of an operation to inputs.
Basically, an `Apply` instance is an object that represents the
......@@ -107,13 +110,13 @@ class Apply(Node):
"""
def __init__(
self, op: "Op", inputs: Sequence["Variable"], outputs: Sequence["Variable"]
self, op: OpType, inputs: Sequence["Variable"], outputs: Sequence["Variable"]
):
if not isinstance(inputs, (list, tuple)):
raise TypeError("The inputs of an Apply must be a list or tuple")
if not isinstance(inputs, Sequence):
raise TypeError("The inputs of an Apply must be a sequence type")
if not isinstance(outputs, (list, tuple)):
raise TypeError("The output of an Apply must be a list or tuple")
if not isinstance(outputs, Sequence):
raise TypeError("The output of an Apply must be a sequence type")
self.op = op
self.inputs: List[Variable] = []
......@@ -197,7 +200,7 @@ class Apply(Node):
def __repr__(self):
return str(self)
def clone(self, clone_inner_graph: bool = False) -> "Apply":
def clone(self, clone_inner_graph: bool = False) -> "Apply[OpType]":
r"""Clone this `Apply` instance.
Parameters
......@@ -218,8 +221,8 @@ class Apply(Node):
new_op = self.op
if isinstance(new_op, HasInnerGraph) and clone_inner_graph:
new_op = new_op.clone()
if isinstance(new_op, HasInnerGraph) and clone_inner_graph: # type: ignore
new_op = new_op.clone() # type: ignore
cp = self.__class__(
new_op, self.inputs, [output.clone() for output in self.outputs]
......@@ -229,7 +232,7 @@ class Apply(Node):
def clone_with_new_inputs(
self, inputs: Sequence["Variable"], strict=True, clone_inner_graph=False
) -> "Apply":
) -> "Apply[OpType]":
r"""Duplicate this `Apply` instance in a new graph.
Parameters
......@@ -273,8 +276,8 @@ class Apply(Node):
if remake_node:
new_op = self.op
if isinstance(new_op, HasInnerGraph) and clone_inner_graph:
new_op = new_op.clone()
if isinstance(new_op, HasInnerGraph) and clone_inner_graph: # type: ignore
new_op = new_op.clone() # type: ignore
new_node = new_op.make_node(*new_inputs)
new_node.tag = copy(self.tag).__update__(new_node.tag)
......
......@@ -509,7 +509,7 @@ class Op(MetaObject):
node_output_storage = [storage_map[r] for r in node.outputs]
if debug and hasattr(self, "debug_perform"):
p = node.op.debug_perform # type: ignore
p = node.op.debug_perform
else:
p = node.op.perform
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论