提交 afae8eb0 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove unused objects and functions from aesara.compile.function.types

上级 304169b8
...@@ -6,14 +6,12 @@ from aesara.compile.function.types import ( ...@@ -6,14 +6,12 @@ from aesara.compile.function.types import (
Supervisor, Supervisor,
UnusedInputError, UnusedInputError,
alias_root, alias_root,
check_equal,
convert_function_input, convert_function_input,
fgraph_updated_vars, fgraph_updated_vars,
get_info_on_inputs, get_info_on_inputs,
infer_reuse_pattern, infer_reuse_pattern,
insert_deepcopy, insert_deepcopy,
orig_function, orig_function,
register_checker,
std_fgraph, std_fgraph,
view_tree_set, view_tree_set,
) )
......
...@@ -9,7 +9,6 @@ import logging ...@@ -9,7 +9,6 @@ import logging
import time import time
import warnings import warnings
from itertools import chain from itertools import chain
from typing import List
import numpy as np import numpy as np
...@@ -36,8 +35,6 @@ from aesara.link.utils import raise_with_op ...@@ -36,8 +35,6 @@ from aesara.link.utils import raise_with_op
_logger = logging.getLogger("aesara.compile.function.types") _logger = logging.getLogger("aesara.compile.function.types")
__docformat__ = "restructuredtext en"
class UnusedInputError(Exception): class UnusedInputError(Exception):
""" """
...@@ -220,10 +217,6 @@ class AliasedMemoryError(Exception): ...@@ -220,10 +217,6 @@ class AliasedMemoryError(Exception):
""" """
###
# Function
###
# unique id object used as a placeholder for duplicate entries # unique id object used as a placeholder for duplicate entries
DUPLICATE = ["DUPLICATE"] DUPLICATE = ["DUPLICATE"]
...@@ -1173,9 +1166,6 @@ def _constructor_Function(maker, input_storage, inputs_data, trust_input=False): ...@@ -1173,9 +1166,6 @@ def _constructor_Function(maker, input_storage, inputs_data, trust_input=False):
copyreg.pickle(Function, _pickle_Function) copyreg.pickle(Function, _pickle_Function)
###
# FunctionMaker
###
def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
""" """
Insert deepcopy in the fgraph to break aliasing of outputs Insert deepcopy in the fgraph to break aliasing of outputs
...@@ -1276,9 +1266,6 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1276,9 +1266,6 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
break break
NODEFAULT = ["NODEFAULT"]
class FunctionMaker: class FunctionMaker:
""" """
`FunctionMaker` is the class to `create` `Function` instances. `FunctionMaker` is the class to `create` `Function` instances.
...@@ -1682,33 +1669,6 @@ class FunctionMaker: ...@@ -1682,33 +1669,6 @@ class FunctionMaker:
return fn return fn
def _constructor_FunctionMaker(kwargs):
# Needed for old pickle
# Old pickle have at least the problem that output_keys where not saved.
if config.unpickle_function:
if config.reoptimize_unpickled_function:
del kwargs["fgraph"]
return FunctionMaker(**kwargs)
else:
return None
__checkers: List = []
def check_equal(x, y):
for checker in __checkers:
try:
return checker(x, y)
except Exception:
continue
return x == y
def register_checker(checker):
__checkers.insert(0, checker)
def orig_function( def orig_function(
inputs, inputs,
outputs, outputs,
......
...@@ -68,28 +68,6 @@ _logger = logging.getLogger("aesara.tensor.basic") ...@@ -68,28 +68,6 @@ _logger = logging.getLogger("aesara.tensor.basic")
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
def check_equal_numpy(x, y):
"""
Return True iff x and y are equal.
Checks the dtype and shape if x and y are numpy.ndarray instances.
"""
if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
return x.dtype == y.dtype and x.shape == y.shape and np.all(abs(x - y) < 1e-10)
elif isinstance(x, (np.random.Generator, np.random.RandomState)) and isinstance(
y, (np.random.Generator, np.random.RandomState)
):
return builtins.all(
np.all(a == b) for a, b in zip(x.__getstate__(), y.__getstate__())
)
else:
return x == y
compile.register_checker(check_equal_numpy)
def __oplist_tag(thing, tag): def __oplist_tag(thing, tag):
tags = getattr(thing, "__oplist_tags", []) tags = getattr(thing, "__oplist_tags", [])
tags.append(tag) tags.append(tag)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论