提交 f737996d authored 作者: Ben Mares's avatar Ben Mares 提交者: Ricardo Vieira

Fix redefinitions

上级 5c4c6b55
...@@ -22,6 +22,7 @@ from typing import ( ...@@ -22,6 +22,7 @@ from typing import (
TypeVar, TypeVar,
Union, Union,
cast, cast,
overload,
) )
import numpy as np import numpy as np
...@@ -1301,9 +1302,31 @@ def clone_get_equiv( ...@@ -1301,9 +1302,31 @@ def clone_get_equiv(
return memo return memo
@overload
def general_toposort(
outputs: Iterable[T],
deps: None,
compute_deps_cache: Callable[[T], Optional[Union[OrderedSet, list[T]]]],
deps_cache: Optional[dict[T, list[T]]],
clients: Optional[dict[T, list[T]]],
) -> list[T]:
...
@overload
def general_toposort( def general_toposort(
outputs: Iterable[T], outputs: Iterable[T],
deps: Callable[[T], Union[OrderedSet, list[T]]], deps: Callable[[T], Union[OrderedSet, list[T]]],
compute_deps_cache: None,
deps_cache: None,
clients: Optional[dict[T, list[T]]],
) -> list[T]:
...
def general_toposort(
outputs: Iterable[T],
deps: Optional[Callable[[T], Union[OrderedSet, list[T]]]],
compute_deps_cache: Optional[ compute_deps_cache: Optional[
Callable[[T], Optional[Union[OrderedSet, list[T]]]] Callable[[T], Optional[Union[OrderedSet, list[T]]]]
] = None, ] = None,
...@@ -1345,7 +1368,7 @@ def general_toposort( ...@@ -1345,7 +1368,7 @@ def general_toposort(
if deps_cache is None: if deps_cache is None:
deps_cache = {} deps_cache = {}
def _compute_deps_cache(io): def _compute_deps_cache_(io):
if io not in deps_cache: if io not in deps_cache:
d = deps(io) d = deps(io)
...@@ -1363,6 +1386,8 @@ def general_toposort( ...@@ -1363,6 +1386,8 @@ def general_toposort(
else: else:
return deps_cache[io] return deps_cache[io]
_compute_deps_cache = _compute_deps_cache_
else: else:
_compute_deps_cache = compute_deps_cache _compute_deps_cache = compute_deps_cache
...@@ -1451,15 +1476,14 @@ def io_toposort( ...@@ -1451,15 +1476,14 @@ def io_toposort(
) )
return order return order
compute_deps = None
compute_deps_cache = None
iset = set(inputs) iset = set(inputs)
deps_cache: dict = {}
if not orderings: # ordering can be None or empty dict if not orderings: # ordering can be None or empty dict
# Specialized function that is faster when no ordering. # Specialized function that is faster when no ordering.
# Also include the cache in the function itself for speed up. # Also include the cache in the function itself for speed up.
deps_cache: dict = {}
def compute_deps_cache(obj): def compute_deps_cache(obj):
if obj in deps_cache: if obj in deps_cache:
return deps_cache[obj] return deps_cache[obj]
...@@ -1478,6 +1502,14 @@ def io_toposort( ...@@ -1478,6 +1502,14 @@ def io_toposort(
deps_cache[obj] = rval deps_cache[obj] = rval
return rval return rval
topo = general_toposort(
outputs,
deps=None,
compute_deps_cache=compute_deps_cache,
deps_cache=deps_cache,
clients=clients,
)
else: else:
# the inputs are used only here in the function that decides what # the inputs are used only here in the function that decides what
# 'predecessors' to explore # 'predecessors' to explore
...@@ -1494,13 +1526,13 @@ def io_toposort( ...@@ -1494,13 +1526,13 @@ def io_toposort(
assert not orderings.get(obj, None) assert not orderings.get(obj, None)
return rval return rval
topo = general_toposort( topo = general_toposort(
outputs, outputs,
deps=compute_deps, deps=compute_deps,
compute_deps_cache=compute_deps_cache, compute_deps_cache=None,
deps_cache=deps_cache, deps_cache=None,
clients=clients, clients=clients,
) )
return [o for o in topo if isinstance(o, Apply)] return [o for o in topo if isinstance(o, Apply)]
......
...@@ -2405,13 +2405,15 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter): ...@@ -2405,13 +2405,15 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
if node is not current_node: if node is not current_node:
q.append(node) q.append(node)
chin = None chin: Optional[Callable] = None
if self.tracks_on_change_inputs: if self.tracks_on_change_inputs:
def chin(node, i, r, new_r, reason): def chin_(node, i, r, new_r, reason):
if node is not current_node and not isinstance(node, str): if node is not current_node and not isinstance(node, str):
q.append(node) q.append(node)
chin = chin_
u = self.attach_updater( u = self.attach_updater(
fgraph, importer, None, chin=chin, name=getattr(self, "name", None) fgraph, importer, None, chin=chin, name=getattr(self, "name", None)
) )
......
...@@ -1403,7 +1403,7 @@ class TestMinMax: ...@@ -1403,7 +1403,7 @@ class TestMinMax:
rng = np.random.default_rng(seed=utt.fetch_seed()) rng = np.random.default_rng(seed=utt.fetch_seed())
TestClip = makeTester( TestClip1 = makeTester(
name="ClipTester", name="ClipTester",
op=clip, op=clip,
expected=lambda x, y, z: np.clip(x, y, z), expected=lambda x, y, z: np.clip(x, y, z),
...@@ -1470,7 +1470,7 @@ TestBackwardsClip = makeTester( ...@@ -1470,7 +1470,7 @@ TestBackwardsClip = makeTester(
) )
class TestClip: class TestClip2:
def test_complex_value(self): def test_complex_value(self):
for dtype in ["complex64", "complex128"]: for dtype in ["complex64", "complex128"]:
a = vector(dtype=dtype) a = vector(dtype=dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论