提交 7092f551 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Cleanup imports in graph/rewriting/basic.py

上级 2d46d60e
...@@ -13,9 +13,8 @@ from collections import Counter, UserList, defaultdict, deque ...@@ -13,9 +13,8 @@ from collections import Counter, UserList, defaultdict, deque
from collections.abc import Callable, Iterable, Sequence from collections.abc import Callable, Iterable, Sequence
from functools import _compose_mro, partial # type: ignore from functools import _compose_mro, partial # type: ignore
from itertools import chain from itertools import chain
from typing import TYPE_CHECKING, Literal from typing import Literal
import pytensor
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph import destroyhandler as dh from pytensor.graph import destroyhandler as dh
from pytensor.graph.basic import ( from pytensor.graph.basic import (
...@@ -30,15 +29,12 @@ from pytensor.graph.basic import ( ...@@ -30,15 +29,12 @@ from pytensor.graph.basic import (
from pytensor.graph.features import AlreadyThere, Feature from pytensor.graph.features import AlreadyThere, Feature
from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.rewriting.unify import Var, convert_strs_to_vars
from pytensor.graph.utils import AssocList, InconsistencyError from pytensor.graph.utils import AssocList, InconsistencyError
from pytensor.misc.ordered_set import OrderedSet from pytensor.misc.ordered_set import OrderedSet
from pytensor.utils import flatten from pytensor.utils import flatten
if TYPE_CHECKING:
from pytensor.graph.rewriting.unify import Var
_logger = logging.getLogger("pytensor.graph.rewriting.basic") _logger = logging.getLogger("pytensor.graph.rewriting.basic")
RemoveKeyType = Literal["remove"] RemoveKeyType = Literal["remove"]
...@@ -1406,8 +1402,6 @@ class PatternNodeRewriter(NodeRewriter): ...@@ -1406,8 +1402,6 @@ class PatternNodeRewriter(NodeRewriter):
frequent `Op`, which will prevent the rewrite from being tried as often. frequent `Op`, which will prevent the rewrite from being tried as often.
""" """
from pytensor.graph.rewriting.unify import convert_strs_to_vars
var_map: dict[str, Var] = {} var_map: dict[str, Var] = {}
self.in_pattern = convert_strs_to_vars(in_pattern, var_map=var_map) self.in_pattern = convert_strs_to_vars(in_pattern, var_map=var_map)
self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map) self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map)
...@@ -1449,9 +1443,6 @@ class PatternNodeRewriter(NodeRewriter): ...@@ -1449,9 +1443,6 @@ class PatternNodeRewriter(NodeRewriter):
if ret is not False and ret is not None: if ret is not False and ret is not None:
return dict(zip(real_node.outputs, ret, strict=True)) return dict(zip(real_node.outputs, ret, strict=True))
if node.op != self.op:
return False
if len(node.outputs) != 1: if len(node.outputs) != 1:
# PatternNodeRewriter doesn't support replacing multi-output nodes # PatternNodeRewriter doesn't support replacing multi-output nodes
return False return False
...@@ -1480,11 +1471,13 @@ class PatternNodeRewriter(NodeRewriter): ...@@ -1480,11 +1471,13 @@ class PatternNodeRewriter(NodeRewriter):
[old_out] = node.outputs [old_out] = node.outputs
if not old_out.type.is_super(ret.type): if not old_out.type.is_super(ret.type):
from pytensor.tensor.type import TensorType
# Type doesn't match # Type doesn't match
if not ( if not (
self.allow_cast self.allow_cast
and isinstance(old_out.type, pytensor.tensor.TensorType) and isinstance(old_out.type, TensorType)
and isinstance(ret.type, pytensor.tensor.TensorType) and isinstance(ret.type, TensorType)
): ):
return False return False
...@@ -2736,10 +2729,12 @@ def check_stack_trace(f_or_fgraph, ops_to_check="last", bug_print="raise"): ...@@ -2736,10 +2729,12 @@ def check_stack_trace(f_or_fgraph, ops_to_check="last", bug_print="raise"):
otherwise. otherwise.
""" """
if isinstance(f_or_fgraph, pytensor.compile.function.types.Function): from pytensor.compile.function.types import Function
fgraph = f_or_fgraph.maker.fgraph
elif isinstance(f_or_fgraph, pytensor.graph.fg.FunctionGraph): if isinstance(f_or_fgraph, FunctionGraph):
fgraph = f_or_fgraph fgraph = f_or_fgraph
elif isinstance(f_or_fgraph, Function):
fgraph = f_or_fgraph.maker.fgraph
else: else:
raise ValueError("The type of f_or_fgraph is not supported") raise ValueError("The type of f_or_fgraph is not supported")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论