提交 34dcbcf1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move InconsistencyError and MissingInputError to aesara.graph.utils

上级 9956162f
......@@ -31,9 +31,8 @@ from aesara.configdefaults import config
from aesara.graph.basic import Variable, io_toposort
from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import BadOptimization
from aesara.graph.fg import InconsistencyError
from aesara.graph.op import COp, HasInnerGraph, Op
from aesara.graph.utils import MethodNotDefined
from aesara.graph.utils import InconsistencyError, MethodNotDefined
from aesara.link.basic import Container, LocalLinker
from aesara.link.utils import map_storage, raise_with_op
from aesara.printing import _debugprint
......
......@@ -27,9 +27,9 @@ from aesara.graph.basic import (
)
from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import PreserveVariableAttributes
from aesara.graph.fg import FunctionGraph, InconsistencyError
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import HasInnerGraph
from aesara.graph.utils import get_variable_trace_string
from aesara.graph.utils import InconsistencyError, get_variable_trace_string
from aesara.link.basic import Container
from aesara.link.utils import raise_with_op
......
......@@ -10,7 +10,7 @@ import aesara
from aesara.configdefaults import config
from aesara.graph.basic import Constant
from aesara.graph.features import AlreadyThere, Bookkeeper
from aesara.graph.fg import InconsistencyError
from aesara.graph.utils import InconsistencyError
from aesara.misc.ordered_set import OrderedSet
......
......@@ -10,6 +10,7 @@ import numpy as np
import aesara
from aesara.configdefaults import config
from aesara.graph.basic import Variable, io_toposort
from aesara.graph.utils import InconsistencyError
class AlreadyThere(Exception):
......@@ -641,9 +642,7 @@ class ReplaceValidate(History, Validator):
def validate(self, fgraph):
if self.fail_validate:
self.fail_validate = False
raise aesara.graph.fg.InconsistencyError(
"Trying to reintroduce a removed node"
)
raise InconsistencyError("Trying to reintroduce a removed node")
class NodeFinder(Bookkeeper):
......@@ -789,7 +788,7 @@ class NoOutputFromInplace(Feature):
op = node.op
out_idx = node.outputs.index(out)
if out_idx in op.destroy_map:
raise aesara.graph.fg.InconsistencyError(
raise InconsistencyError(
"A function graph Feature has requested that outputs of the graph "
"be prevented from being the result of in-place "
f"operations. This has prevented the output {out} from "
......
......@@ -9,35 +9,10 @@ from aesara.graph.basic import Apply, Constant, Variable, applys_between
from aesara.graph.basic import as_string as graph_as_string
from aesara.graph.basic import clone_get_equiv, graph_inputs, io_toposort, vars_between
from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate
from aesara.graph.utils import MetaObject, TestValueError, get_variable_trace_string
from aesara.graph.utils import MetaObject, MissingInputError, TestValueError
from aesara.misc.ordered_set import OrderedSet
class InconsistencyError(Exception):
"""
This exception should be thrown by listeners to FunctionGraph when the
graph's state is invalid.
"""
class MissingInputError(Exception):
"""
A symbolic input needed to compute the outputs is missing.
"""
def __init__(self, *args, **kwargs):
if kwargs:
# The call to list is needed for Python 3
assert list(kwargs.keys()) == ["variable"]
error_msg = get_variable_trace_string(kwargs["variable"])
if error_msg:
args = args + (error_msg,)
s = "\n".join(args) # Needed to have the new line print correctly
super().__init__(s)
class FunctionGraph(MetaObject):
"""
A `FunctionGraph` represents a subgraph bound by a set of input variables and
......
......@@ -31,9 +31,9 @@ from aesara.graph.basic import (
vars_between,
)
from aesara.graph.features import Feature, NodeFinder
from aesara.graph.fg import FunctionGraph, InconsistencyError
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.utils import AssocList
from aesara.graph.utils import AssocList, InconsistencyError
from aesara.misc.ordered_set import OrderedSet
from aesara.utils import flatten
......
......@@ -145,6 +145,31 @@ def get_variable_trace_string(v):
return sio.getvalue()
class InconsistencyError(Exception):
"""
This exception should be thrown by listeners to FunctionGraph when the
graph's state is invalid.
"""
class MissingInputError(Exception):
"""
A symbolic input needed to compute the outputs is missing.
"""
def __init__(self, *args, **kwargs):
if kwargs:
# The call to list is needed for Python 3
assert list(kwargs.keys()) == ["variable"]
error_msg = get_variable_trace_string(kwargs["variable"])
if error_msg:
args = args + (error_msg,)
s = "\n".join(args) # Needed to have the new line print correctly
super().__init__(s)
class TestValueError(Exception):
"""Base exception class for all test value errors."""
......
......@@ -8,9 +8,8 @@ from aesara.compile import SharedVariable
from aesara.compile.function.pfunc import construct_pfunc_ins_and_outs
from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable, clone_replace, graph_inputs
from aesara.graph.fg import MissingInputError
from aesara.graph.op import get_test_value
from aesara.graph.utils import TestValueError
from aesara.graph.utils import MissingInputError, TestValueError
from aesara.scan import utils
from aesara.scan.op import Scan, ScanInfo
from aesara.scan.utils import safe_new, traverse
......
......@@ -71,8 +71,8 @@ from aesara.graph.basic import (
io_connection_pattern,
)
from aesara.graph.features import NoOutputFromInplace
from aesara.graph.fg import MissingInputError
from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.utils import MissingInputError
from aesara.link.c.basic import CLinker
from aesara.link.c.exceptions import MissingGXX
from aesara.link.utils import raise_with_op
......
......@@ -26,10 +26,11 @@ from aesara.graph.basic import (
)
from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import ReplaceValidate
from aesara.graph.fg import FunctionGraph, InconsistencyError
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import compute_test_value
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
from aesara.graph.optdb import EquilibriumDB, SequenceDB
from aesara.graph.utils import InconsistencyError
from aesara.scan.op import Scan, ScanInfo
from aesara.scan.utils import (
ScanArgs,
......
......@@ -23,7 +23,7 @@ from aesara.graph.basic import (
equal_computations,
io_toposort,
)
from aesara.graph.fg import FunctionGraph, InconsistencyError
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value
from aesara.graph.opt import (
GlobalOptimizer,
......@@ -35,6 +35,7 @@ from aesara.graph.opt import (
)
from aesara.graph.optdb import SequenceDB
from aesara.graph.utils import (
InconsistencyError,
MethodNotDefined,
TestValueError,
get_variable_trace_string,
......
......@@ -147,7 +147,6 @@ from aesara.compile.mode import optdb
from aesara.configdefaults import config
from aesara.graph.basic import Apply, view_roots
from aesara.graph.features import ReplacementDidNotRemoveError, ReplaceValidate
from aesara.graph.fg import InconsistencyError
from aesara.graph.op import COp, Op
from aesara.graph.opt import (
EquilibriumOptimizer,
......@@ -158,7 +157,7 @@ from aesara.graph.opt import (
)
from aesara.graph.optdb import SequenceDB
from aesara.graph.params_type import ParamsType
from aesara.graph.utils import MethodNotDefined, TestValueError
from aesara.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
from aesara.printing import FunctionPrinter, debugprint, pprint
from aesara.scalar import bool as bool_t
from aesara.tensor import basic as at
......
......@@ -8,7 +8,7 @@ from aesara.compile.function.pfunc import rebuild_collect_shared
from aesara.compile.io import In
from aesara.compile.sharedvalue import shared
from aesara.configdefaults import config
from aesara.graph.fg import MissingInputError
from aesara.graph.utils import MissingInputError
from aesara.misc.safe_asarray import _asarray
from aesara.tensor.math import sum as at_sum
from aesara.tensor.type import (
......
......@@ -17,8 +17,8 @@ from aesara.configdefaults import config
from aesara.gpuarray import gpuarray_shared_constructor
from aesara.gpuarray.blas import GpuGemm
from aesara.graph.basic import Constant
from aesara.graph.fg import MissingInputError
from aesara.graph.opt import OpKeyOptimizer, PatternSub
from aesara.graph.utils import MissingInputError
from aesara.tensor.math import dot
from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tanh
......
......@@ -6,7 +6,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, Variable, clone
from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import ReplaceValidate
from aesara.graph.fg import FunctionGraph, InconsistencyError
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.opt import (
NavigatorOptimizer,
......@@ -16,6 +16,7 @@ from aesara.graph.opt import (
TopoOptimizer,
)
from aesara.graph.type import Type
from aesara.graph.utils import InconsistencyError
from tests.unittest_tools import assertFailure_fast
......
......@@ -4,7 +4,8 @@ import numpy as np
import pytest
from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph, MissingInputError
from aesara.graph.fg import FunctionGraph
from aesara.graph.utils import MissingInputError
from tests.graph.utils import MyConstant, MyVariable, MyVariable2, op1, op2, op3
......
......@@ -29,8 +29,8 @@ from aesara.compile.sharedvalue import shared
from aesara.configdefaults import config
from aesara.gradient import NullTypeGradError, Rop, disconnected_grad, grad, hessian
from aesara.graph.basic import Apply, ancestors
from aesara.graph.fg import MissingInputError
from aesara.graph.op import Op
from aesara.graph.utils import MissingInputError
from aesara.misc.safe_asarray import _asarray
from aesara.raise_op import assert_op
from aesara.scan.basic import scan
......
......@@ -30,6 +30,7 @@ from aesara.configdefaults import config
from aesara.gradient import grad
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import in2out
from aesara.graph.utils import InconsistencyError
from aesara.misc.safe_asarray import _asarray
from aesara.tensor import inplace
from aesara.tensor.basic import as_tensor_variable
......@@ -39,7 +40,6 @@ from aesara.tensor.blas import (
Gemm,
Gemv,
Ger,
InconsistencyError,
_as_scalar,
_dot22,
_dot22scalar,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论