提交 f737996d authored 作者: Ben Mares's avatar Ben Mares 提交者: Ricardo Vieira

Fix redefinitions

上级 5c4c6b55
......@@ -22,6 +22,7 @@ from typing import (
TypeVar,
Union,
cast,
overload,
)
import numpy as np
......@@ -1301,9 +1302,31 @@ def clone_get_equiv(
return memo
@overload
def general_toposort(
outputs: Iterable[T],
deps: None,
compute_deps_cache: Callable[[T], Optional[Union[OrderedSet, list[T]]]],
deps_cache: Optional[dict[T, list[T]]],
clients: Optional[dict[T, list[T]]],
) -> list[T]:
...
@overload
def general_toposort(
outputs: Iterable[T],
deps: Callable[[T], Union[OrderedSet, list[T]]],
compute_deps_cache: None,
deps_cache: None,
clients: Optional[dict[T, list[T]]],
) -> list[T]:
...
def general_toposort(
outputs: Iterable[T],
deps: Optional[Callable[[T], Union[OrderedSet, list[T]]]],
compute_deps_cache: Optional[
Callable[[T], Optional[Union[OrderedSet, list[T]]]]
] = None,
......@@ -1345,7 +1368,7 @@ def general_toposort(
if deps_cache is None:
deps_cache = {}
def _compute_deps_cache(io):
def _compute_deps_cache_(io):
if io not in deps_cache:
d = deps(io)
......@@ -1363,6 +1386,8 @@ def general_toposort(
else:
return deps_cache[io]
_compute_deps_cache = _compute_deps_cache_
else:
_compute_deps_cache = compute_deps_cache
......@@ -1451,15 +1476,14 @@ def io_toposort(
)
return order
compute_deps = None
compute_deps_cache = None
iset = set(inputs)
deps_cache: dict = {}
if not orderings: # ordering can be None or empty dict
# Specialized function that is faster when no ordering.
# Also include the cache in the function itself for speed up.
deps_cache: dict = {}
def compute_deps_cache(obj):
if obj in deps_cache:
return deps_cache[obj]
......@@ -1478,6 +1502,14 @@ def io_toposort(
deps_cache[obj] = rval
return rval
topo = general_toposort(
outputs,
deps=None,
compute_deps_cache=compute_deps_cache,
deps_cache=deps_cache,
clients=clients,
)
else:
# the inputs are used only here in the function that decides what
# 'predecessors' to explore
......@@ -1494,13 +1526,13 @@ def io_toposort(
assert not orderings.get(obj, None)
return rval
topo = general_toposort(
outputs,
deps=compute_deps,
compute_deps_cache=compute_deps_cache,
deps_cache=deps_cache,
clients=clients,
)
topo = general_toposort(
outputs,
deps=compute_deps,
compute_deps_cache=None,
deps_cache=None,
clients=clients,
)
return [o for o in topo if isinstance(o, Apply)]
......
......@@ -2405,13 +2405,15 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
if node is not current_node:
q.append(node)
chin = None
chin: Optional[Callable] = None
if self.tracks_on_change_inputs:
def chin(node, i, r, new_r, reason):
def chin_(node, i, r, new_r, reason):
if node is not current_node and not isinstance(node, str):
q.append(node)
chin = chin_
u = self.attach_updater(
fgraph, importer, None, chin=chin, name=getattr(self, "name", None)
)
......
......@@ -1403,7 +1403,7 @@ class TestMinMax:
rng = np.random.default_rng(seed=utt.fetch_seed())
TestClip = makeTester(
TestClip1 = makeTester(
name="ClipTester",
op=clip,
expected=lambda x, y, z: np.clip(x, y, z),
......@@ -1470,7 +1470,7 @@ TestBackwardsClip = makeTester(
)
class TestClip:
class TestClip2:
def test_complex_value(self):
for dtype in ["complex64", "complex128"]:
a = vector(dtype=dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论