提交 de9ad202 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Parameterize Apply type by Op

上级 a1d2fffd
...@@ -11,6 +11,7 @@ from typing import ( ...@@ -11,6 +11,7 @@ from typing import (
Deque, Deque,
Dict, Dict,
Generator, Generator,
Generic,
Hashable, Hashable,
Iterable, Iterable,
Iterator, Iterator,
...@@ -45,6 +46,8 @@ if TYPE_CHECKING: ...@@ -45,6 +46,8 @@ if TYPE_CHECKING:
from aesara.graph.type import Type from aesara.graph.type import Type
OpType = TypeVar("OpType", bound="Op")
T = TypeVar("T", bound="Node") T = TypeVar("T", bound="Node")
NoParams = object() NoParams = object()
NodeAndChildren = Tuple[T, Optional[Iterable[T]]] NodeAndChildren = Tuple[T, Optional[Iterable[T]]]
...@@ -71,7 +74,7 @@ class Node(MetaObject): ...@@ -71,7 +74,7 @@ class Node(MetaObject):
raise NotImplementedError() raise NotImplementedError()
class Apply(Node): class Apply(Node, Generic[OpType]):
"""A `Node` representing the application of an operation to inputs. """A `Node` representing the application of an operation to inputs.
Basically, an `Apply` instance is an object that represents the Basically, an `Apply` instance is an object that represents the
...@@ -107,13 +110,13 @@ class Apply(Node): ...@@ -107,13 +110,13 @@ class Apply(Node):
""" """
def __init__( def __init__(
self, op: "Op", inputs: Sequence["Variable"], outputs: Sequence["Variable"] self, op: OpType, inputs: Sequence["Variable"], outputs: Sequence["Variable"]
): ):
if not isinstance(inputs, (list, tuple)): if not isinstance(inputs, Sequence):
raise TypeError("The inputs of an Apply must be a list or tuple") raise TypeError("The inputs of an Apply must be a sequence type")
if not isinstance(outputs, (list, tuple)): if not isinstance(outputs, Sequence):
raise TypeError("The output of an Apply must be a list or tuple") raise TypeError("The output of an Apply must be a sequence type")
self.op = op self.op = op
self.inputs: List[Variable] = [] self.inputs: List[Variable] = []
...@@ -197,7 +200,7 @@ class Apply(Node): ...@@ -197,7 +200,7 @@ class Apply(Node):
def __repr__(self): def __repr__(self):
return str(self) return str(self)
def clone(self, clone_inner_graph: bool = False) -> "Apply": def clone(self, clone_inner_graph: bool = False) -> "Apply[OpType]":
r"""Clone this `Apply` instance. r"""Clone this `Apply` instance.
Parameters Parameters
...@@ -218,8 +221,8 @@ class Apply(Node): ...@@ -218,8 +221,8 @@ class Apply(Node):
new_op = self.op new_op = self.op
if isinstance(new_op, HasInnerGraph) and clone_inner_graph: if isinstance(new_op, HasInnerGraph) and clone_inner_graph: # type: ignore
new_op = new_op.clone() new_op = new_op.clone() # type: ignore
cp = self.__class__( cp = self.__class__(
new_op, self.inputs, [output.clone() for output in self.outputs] new_op, self.inputs, [output.clone() for output in self.outputs]
...@@ -229,7 +232,7 @@ class Apply(Node): ...@@ -229,7 +232,7 @@ class Apply(Node):
def clone_with_new_inputs( def clone_with_new_inputs(
self, inputs: Sequence["Variable"], strict=True, clone_inner_graph=False self, inputs: Sequence["Variable"], strict=True, clone_inner_graph=False
) -> "Apply": ) -> "Apply[OpType]":
r"""Duplicate this `Apply` instance in a new graph. r"""Duplicate this `Apply` instance in a new graph.
Parameters Parameters
...@@ -273,8 +276,8 @@ class Apply(Node): ...@@ -273,8 +276,8 @@ class Apply(Node):
if remake_node: if remake_node:
new_op = self.op new_op = self.op
if isinstance(new_op, HasInnerGraph) and clone_inner_graph: if isinstance(new_op, HasInnerGraph) and clone_inner_graph: # type: ignore
new_op = new_op.clone() new_op = new_op.clone() # type: ignore
new_node = new_op.make_node(*new_inputs) new_node = new_op.make_node(*new_inputs)
new_node.tag = copy(self.tag).__update__(new_node.tag) new_node.tag = copy(self.tag).__update__(new_node.tag)
......
...@@ -509,7 +509,7 @@ class Op(MetaObject): ...@@ -509,7 +509,7 @@ class Op(MetaObject):
node_output_storage = [storage_map[r] for r in node.outputs] node_output_storage = [storage_map[r] for r in node.outputs]
if debug and hasattr(self, "debug_perform"): if debug and hasattr(self, "debug_perform"):
p = node.op.debug_perform # type: ignore p = node.op.debug_perform
else: else:
p = node.op.perform p = node.op.perform
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论