提交 7092f551 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Cleanup imports in graph/rewriting/basic.py

上级 2d46d60e
......@@ -13,9 +13,8 @@ from collections import Counter, UserList, defaultdict, deque
from collections.abc import Callable, Iterable, Sequence
from functools import _compose_mro, partial # type: ignore
from itertools import chain
from typing import TYPE_CHECKING, Literal
from typing import Literal
import pytensor
from pytensor.configdefaults import config
from pytensor.graph import destroyhandler as dh
from pytensor.graph.basic import (
......@@ -30,15 +29,12 @@ from pytensor.graph.basic import (
from pytensor.graph.features import AlreadyThere, Feature
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op
from pytensor.graph.rewriting.unify import Var, convert_strs_to_vars
from pytensor.graph.utils import AssocList, InconsistencyError
from pytensor.misc.ordered_set import OrderedSet
from pytensor.utils import flatten
if TYPE_CHECKING:
from pytensor.graph.rewriting.unify import Var
_logger = logging.getLogger("pytensor.graph.rewriting.basic")
RemoveKeyType = Literal["remove"]
......@@ -1406,8 +1402,6 @@ class PatternNodeRewriter(NodeRewriter):
frequent `Op`, which will prevent the rewrite from being tried as often.
"""
from pytensor.graph.rewriting.unify import convert_strs_to_vars
var_map: dict[str, Var] = {}
self.in_pattern = convert_strs_to_vars(in_pattern, var_map=var_map)
self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map)
......@@ -1449,9 +1443,6 @@ class PatternNodeRewriter(NodeRewriter):
if ret is not False and ret is not None:
return dict(zip(real_node.outputs, ret, strict=True))
if node.op != self.op:
return False
if len(node.outputs) != 1:
# PatternNodeRewriter doesn't support replacing multi-output nodes
return False
......@@ -1480,11 +1471,13 @@ class PatternNodeRewriter(NodeRewriter):
[old_out] = node.outputs
if not old_out.type.is_super(ret.type):
from pytensor.tensor.type import TensorType
# Type doesn't match
if not (
self.allow_cast
and isinstance(old_out.type, pytensor.tensor.TensorType)
and isinstance(ret.type, pytensor.tensor.TensorType)
and isinstance(old_out.type, TensorType)
and isinstance(ret.type, TensorType)
):
return False
......@@ -2736,10 +2729,12 @@ def check_stack_trace(f_or_fgraph, ops_to_check="last", bug_print="raise"):
otherwise.
"""
if isinstance(f_or_fgraph, pytensor.compile.function.types.Function):
fgraph = f_or_fgraph.maker.fgraph
elif isinstance(f_or_fgraph, pytensor.graph.fg.FunctionGraph):
from pytensor.compile.function.types import Function
if isinstance(f_or_fgraph, FunctionGraph):
fgraph = f_or_fgraph
elif isinstance(f_or_fgraph, Function):
fgraph = f_or_fgraph.maker.fgraph
else:
raise ValueError("The type of f_or_fgraph is not supported")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论