提交 8a4505c0 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move C-specific content from aesara.graph.op to aesara.link.c.op

上级 4eea9f79
...@@ -32,9 +32,10 @@ from aesara.configdefaults import config ...@@ -32,9 +32,10 @@ from aesara.configdefaults import config
from aesara.graph.basic import Variable, io_toposort from aesara.graph.basic import Variable, io_toposort
from aesara.graph.destroyhandler import DestroyHandler from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import BadOptimization from aesara.graph.features import BadOptimization
from aesara.graph.op import COp, HasInnerGraph, Op from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.utils import InconsistencyError, MethodNotDefined from aesara.graph.utils import InconsistencyError, MethodNotDefined
from aesara.link.basic import Container, LocalLinker from aesara.link.basic import Container, LocalLinker
from aesara.link.c.op import COp
from aesara.link.utils import map_storage, raise_with_op from aesara.link.utils import map_storage, raise_with_op
from aesara.printing import _debugprint from aesara.printing import _debugprint
from aesara.utils import NoDuplicateOptWarningFilter, difference, get_unbound_function from aesara.utils import NoDuplicateOptWarningFilter, difference, get_unbound_function
......
...@@ -11,8 +11,9 @@ import warnings ...@@ -11,8 +11,9 @@ import warnings
from typing import Dict, Tuple from typing import Dict, Tuple
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import COp, Op from aesara.graph.op import Op
from aesara.graph.type import CType from aesara.graph.type import CType
from aesara.link.c.op import COp
def register_view_op_c_code(type, code, version=()): def register_view_op_c_code(type, code, version=()):
......
...@@ -11,12 +11,13 @@ import aesara.tensor as at ...@@ -11,12 +11,13 @@ import aesara.tensor as at
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.gradient import grad_undefined from aesara.gradient import grad_undefined
from aesara.graph.basic import Apply, Variable from aesara.graph.basic import Apply, Variable
from aesara.graph.op import COp, ExternalCOp, Op, _NoPythonOp from aesara.graph.op import Op, _NoPythonOp
from aesara.graph.opt import copy_stack_trace from aesara.graph.opt import copy_stack_trace
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.graph.type import CType from aesara.graph.type import CType
from aesara.graph.utils import MethodNotDefined from aesara.graph.utils import MethodNotDefined
from aesara.link.c.interface import HideC from aesara.link.c.interface import HideC
from aesara.link.c.op import COp, ExternalCOp
from aesara.scalar import bool as bool_t from aesara.scalar import bool as bool_t
from aesara.scalar import int32 as int32_t from aesara.scalar import int32 as int32_t
from aesara.tensor.basic import Alloc, AllocEmpty, Join, Split, infer_broadcastable from aesara.tensor.basic import Alloc, AllocEmpty, Join, Split, infer_broadcastable
......
...@@ -10,9 +10,9 @@ from aesara.gpuarray.basic_ops import ( ...@@ -10,9 +10,9 @@ from aesara.gpuarray.basic_ops import (
) )
from aesara.gpuarray.opt_util import inplace_allocempty from aesara.gpuarray.opt_util import inplace_allocempty
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import _NoPythonCOp
from aesara.graph.opt import LocalOptGroup, in2out from aesara.graph.opt import LocalOptGroup, in2out
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.link.c.op import _NoPythonCOp
from aesara.scalar import bool as bool_t from aesara.scalar import bool as bool_t
from aesara.tensor.basic import as_tensor_variable from aesara.tensor.basic import as_tensor_variable
......
...@@ -11,8 +11,8 @@ from aesara.gpuarray.basic_ops import ( ...@@ -11,8 +11,8 @@ from aesara.gpuarray.basic_ops import (
from aesara.gpuarray.type import gpu_context_type from aesara.gpuarray.type import gpu_context_type
from aesara.gradient import grad_undefined from aesara.gradient import grad_undefined
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import _NoPythonExternalCOp
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.link.c.op import _NoPythonExternalCOp
from aesara.scalar import bool as bool_t from aesara.scalar import bool as bool_t
from aesara.tensor import as_tensor_variable from aesara.tensor import as_tensor_variable
from aesara.tensor.type import discrete_dtypes from aesara.tensor.type import discrete_dtypes
......
...@@ -13,8 +13,8 @@ from aesara.gpuarray.elemwise import GpuDimShuffle ...@@ -13,8 +13,8 @@ from aesara.gpuarray.elemwise import GpuDimShuffle
from aesara.gpuarray.type import GpuArrayType, gpu_context_type from aesara.gpuarray.type import GpuArrayType, gpu_context_type
from aesara.gradient import grad_undefined from aesara.gradient import grad_undefined
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import _NoPythonExternalCOp
from aesara.graph.opt import local_optimizer from aesara.graph.opt import local_optimizer
from aesara.link.c.op import _NoPythonExternalCOp
from aesara.tensor.basic import as_tensor_variable from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.basic_opt import register_canonicalize from aesara.tensor.basic_opt import register_canonicalize
from aesara.tensor.blas import batched_dot from aesara.tensor.blas import batched_dot
......
...@@ -28,10 +28,10 @@ from aesara.gpuarray.basic_ops import ( ...@@ -28,10 +28,10 @@ from aesara.gpuarray.basic_ops import (
from aesara.gpuarray.type import GpuArraySharedVariable, get_context, gpu_context_type from aesara.gpuarray.type import GpuArraySharedVariable, get_context, gpu_context_type
from aesara.gradient import DisconnectedType, grad_not_implemented from aesara.gradient import DisconnectedType, grad_not_implemented
from aesara.graph.basic import Apply, Variable from aesara.graph.basic import Apply, Variable
from aesara.graph.op import ExternalCOp, _NoPythonCOp, _NoPythonExternalCOp
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.graph.type import CDataType, EnumList, Generic from aesara.graph.type import CDataType, EnumList, Generic
from aesara.link.c.cmodule import GCC_compiler from aesara.link.c.cmodule import GCC_compiler
from aesara.link.c.op import ExternalCOp, _NoPythonCOp, _NoPythonExternalCOp
from aesara.raise_op import Assert from aesara.raise_op import Assert
from aesara.scalar import as_scalar from aesara.scalar import as_scalar
from aesara.scalar import bool as bool_t from aesara.scalar import bool as bool_t
......
...@@ -14,8 +14,9 @@ from aesara.gpuarray.basic_ops import ( ...@@ -14,8 +14,9 @@ from aesara.gpuarray.basic_ops import (
) )
from aesara.gpuarray.type import GpuArrayType, gpu_context_type from aesara.gpuarray.type import GpuArrayType, gpu_context_type
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import ExternalCOp, Op from aesara.graph.op import Op
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.link.c.op import ExternalCOp
from aesara.scalar import bool as bool_t from aesara.scalar import bool as bool_t
from aesara.tensor import basic as at from aesara.tensor import basic as at
from aesara.tensor import math as tm from aesara.tensor import math as tm
......
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import COp
from aesara.graph.type import Generic from aesara.graph.type import Generic
from aesara.link.c.op import COp
from .basic_ops import as_gpuarray_variable, gpuarray_helper_inc_dir, infer_context_name from .basic_ops import as_gpuarray_variable, gpuarray_helper_inc_dir, infer_context_name
from .type import GpuArrayType from .type import GpuArrayType
......
...@@ -5,10 +5,11 @@ import numpy as np ...@@ -5,10 +5,11 @@ import numpy as np
import aesara.tensor as at import aesara.tensor as at
from aesara.gradient import grad_not_implemented from aesara.gradient import grad_not_implemented
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import COp, Op from aesara.graph.op import Op
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.graph.type import CType from aesara.graph.type import CType
from aesara.link.c.interface import HideC from aesara.link.c.interface import HideC
from aesara.link.c.op import COp
from aesara.scalar import bool as bool_t from aesara.scalar import bool as bool_t
from aesara.scalar import int32 as int_t from aesara.scalar import int32 as int_t
from aesara.scalar import uint32 as size_t from aesara.scalar import uint32 as size_t
......
差异被折叠。
差异被折叠。
...@@ -7,9 +7,9 @@ import numpy as np ...@@ -7,9 +7,9 @@ import numpy as np
from aesara.gradient import DisconnectedType from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply, Variable from aesara.graph.basic import Apply, Variable
from aesara.graph.op import COp
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.graph.type import Generic from aesara.graph.type import Generic
from aesara.link.c.op import COp
class ExceptionType(Generic): class ExceptionType(Generic):
......
...@@ -7,7 +7,7 @@ import numpy as np ...@@ -7,7 +7,7 @@ import numpy as np
import aesara.tensor as at import aesara.tensor as at
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import COp from aesara.link.c.op import COp
from aesara.scalar import Scalar, as_scalar from aesara.scalar import Scalar, as_scalar
from aesara.tensor.type import discrete_dtypes from aesara.tensor.type import discrete_dtypes
......
...@@ -25,9 +25,9 @@ from aesara.compile import optdb ...@@ -25,9 +25,9 @@ from aesara.compile import optdb
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.gradient import undefined_grad from aesara.gradient import undefined_grad
from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import COp, Op
from aesara.graph.opt import in2out, local_optimizer from aesara.graph.opt import in2out, local_optimizer
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.link.c.op import COp, Op
from aesara.sandbox import multinomial from aesara.sandbox import multinomial
from aesara.scalar import bool as bool_t from aesara.scalar import bool as bool_t
from aesara.scalar import int32 as int_t from aesara.scalar import int32 as int_t
......
...@@ -27,10 +27,10 @@ from aesara.configdefaults import config ...@@ -27,10 +27,10 @@ from aesara.configdefaults import config
from aesara.gradient import DisconnectedType, grad_undefined from aesara.gradient import DisconnectedType, grad_undefined
from aesara.graph.basic import Apply, Constant, Variable, clone, list_of_nodes from aesara.graph.basic import Apply, Constant, Variable, clone, list_of_nodes
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import COp
from aesara.graph.opt import MergeOptimizer from aesara.graph.opt import MergeOptimizer
from aesara.graph.type import CType from aesara.graph.type import CType
from aesara.graph.utils import MetaObject, MethodNotDefined from aesara.graph.utils import MetaObject, MethodNotDefined
from aesara.link.c.op import COp
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.printing import pprint from aesara.printing import pprint
from aesara.utils import ( from aesara.utils import (
......
...@@ -18,7 +18,8 @@ from aesara import scalar as aes ...@@ -18,7 +18,8 @@ from aesara import scalar as aes
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.gradient import DisconnectedType, grad_not_implemented, grad_undefined from aesara.gradient import DisconnectedType, grad_not_implemented, grad_undefined
from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import COp, Op from aesara.graph.op import Op
from aesara.link.c.op import COp
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.sparse.type import SparseType, _is_sparse from aesara.sparse.type import SparseType, _is_sparse
from aesara.sparse.utils import hash_from_sparse from aesara.sparse.utils import hash_from_sparse
......
...@@ -5,8 +5,8 @@ import aesara ...@@ -5,8 +5,8 @@ import aesara
import aesara.scalar as aes import aesara.scalar as aes
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import COp, _NoPythonCOp
from aesara.graph.opt import PatternSub, TopoOptimizer, local_optimizer from aesara.graph.opt import PatternSub, TopoOptimizer, local_optimizer
from aesara.link.c.op import COp, _NoPythonCOp
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.sparse import basic as sparse from aesara.sparse import basic as sparse
from aesara.sparse.basic import ( from aesara.sparse.basic import (
......
...@@ -23,10 +23,11 @@ from aesara import scalar as aes ...@@ -23,10 +23,11 @@ from aesara import scalar as aes
from aesara.gradient import DisconnectedType, grad_not_implemented, grad_undefined from aesara.gradient import DisconnectedType, grad_not_implemented, grad_undefined
from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import COp, Op from aesara.graph.op import Op
from aesara.graph.opt_utils import optimize_graph from aesara.graph.opt_utils import optimize_graph
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.link.c.op import COp
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.printing import min_informative_str, pprint from aesara.printing import min_informative_str, pprint
from aesara.raise_op import CheckAndRaise, assert_op from aesara.raise_op import CheckAndRaise, assert_op
......
...@@ -147,7 +147,7 @@ from aesara.compile.mode import optdb ...@@ -147,7 +147,7 @@ from aesara.compile.mode import optdb
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, view_roots from aesara.graph.basic import Apply, view_roots
from aesara.graph.features import ReplacementDidNotRemoveError, ReplaceValidate from aesara.graph.features import ReplacementDidNotRemoveError, ReplaceValidate
from aesara.graph.op import COp, Op from aesara.graph.op import Op
from aesara.graph.opt import ( from aesara.graph.opt import (
EquilibriumOptimizer, EquilibriumOptimizer,
GlobalOptimizer, GlobalOptimizer,
...@@ -158,6 +158,7 @@ from aesara.graph.opt import ( ...@@ -158,6 +158,7 @@ from aesara.graph.opt import (
from aesara.graph.optdb import SequenceDB from aesara.graph.optdb import SequenceDB
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.graph.utils import InconsistencyError, MethodNotDefined, TestValueError from aesara.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
from aesara.link.c.op import COp
from aesara.printing import FunctionPrinter, debugprint, pprint from aesara.printing import FunctionPrinter, debugprint, pprint
from aesara.scalar import bool as bool_t from aesara.scalar import bool as bool_t
from aesara.tensor import basic as at from aesara.tensor import basic as at
......
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.op import COp
from aesara.graph.opt import in2out from aesara.graph.opt import in2out
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.link.c.op import COp
from aesara.scalar import bool as bool_t from aesara.scalar import bool as bool_t
from aesara.tensor import basic as at from aesara.tensor import basic as at
from aesara.tensor.blas import ( from aesara.tensor.blas import (
......
...@@ -8,10 +8,10 @@ from aesara.configdefaults import config ...@@ -8,10 +8,10 @@ from aesara.configdefaults import config
from aesara.gradient import DisconnectedType from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.null_type import NullType from aesara.graph.null_type import NullType
from aesara.graph.op import COp, ExternalCOp, OpenMPOp
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.graph.utils import MethodNotDefined from aesara.graph.utils import MethodNotDefined
from aesara.link.c.basic import failure_code from aesara.link.c.basic import failure_code
from aesara.link.c.op import COp, ExternalCOp, OpenMPOp
from aesara.misc.frozendict import frozendict from aesara.misc.frozendict import frozendict
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.printing import FunctionPrinter, Printer, pprint from aesara.printing import FunctionPrinter, Printer, pprint
......
...@@ -11,9 +11,10 @@ from aesara.gradient import ( ...@@ -11,9 +11,10 @@ from aesara.gradient import (
grad_undefined, grad_undefined,
) )
from aesara.graph.basic import Apply, Variable, equal_computations from aesara.graph.basic import Apply, Variable, equal_computations
from aesara.graph.op import COp, Op from aesara.graph.op import Op
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.graph.type import EnumList, Generic from aesara.graph.type import EnumList, Generic
from aesara.link.c.op import COp
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.raise_op import Assert from aesara.raise_op import Assert
from aesara.scalar import int32 as int_t from aesara.scalar import int32 as int_t
......
...@@ -7,9 +7,10 @@ from aesara import config, printing ...@@ -7,9 +7,10 @@ from aesara import config, printing
from aesara import scalar as aes from aesara import scalar as aes
from aesara.gradient import DisconnectedType from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply, Variable from aesara.graph.basic import Apply, Variable
from aesara.graph.op import COp, Op from aesara.graph.op import Op
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.graph.type import Generic from aesara.graph.type import Generic
from aesara.link.c.op import COp
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.printing import pprint from aesara.printing import pprint
from aesara.scalar.basic import BinaryScalarOp from aesara.scalar.basic import BinaryScalarOp
......
...@@ -24,8 +24,9 @@ from aesara import scalar as aes ...@@ -24,8 +24,9 @@ from aesara import scalar as aes
from aesara.compile import optdb from aesara.compile import optdb
from aesara.gradient import DisconnectedType, grad_not_implemented from aesara.gradient import DisconnectedType, grad_not_implemented
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import COp, Op from aesara.graph.op import Op
from aesara.graph.opt import copy_stack_trace, local_optimizer, optimizer from aesara.graph.opt import copy_stack_trace, local_optimizer, optimizer
from aesara.link.c.op import COp
from aesara.raise_op import Assert from aesara.raise_op import Assert
from aesara.scalar import UnaryScalarOp from aesara.scalar import UnaryScalarOp
from aesara.tensor import basic as at from aesara.tensor import basic as at
......
...@@ -24,7 +24,7 @@ except ImportError: ...@@ -24,7 +24,7 @@ except ImportError:
import aesara import aesara
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import OpenMPOp from aesara.link.c.op import OpenMPOp
from aesara.tensor import blas from aesara.tensor import blas
from aesara.tensor.basic import ( from aesara.tensor.basic import (
as_tensor_variable, as_tensor_variable,
......
...@@ -5,9 +5,10 @@ from typing import Optional ...@@ -5,9 +5,10 @@ from typing import Optional
import aesara import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import OpenMPOp, _NoPythonOp from aesara.graph.op import _NoPythonOp
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.graph.type import EnumList from aesara.graph.type import EnumList
from aesara.link.c.op import OpenMPOp
from aesara.scalar import int8, int64 from aesara.scalar import int8, int64
from aesara.tensor import blas_headers from aesara.tensor import blas_headers
from aesara.tensor.basic import as_tensor_variable from aesara.tensor.basic import as_tensor_variable
......
...@@ -5,9 +5,10 @@ from typing import Optional ...@@ -5,9 +5,10 @@ from typing import Optional
import aesara import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import OpenMPOp, _NoPythonOp from aesara.graph.op import _NoPythonOp
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.graph.type import EnumList from aesara.graph.type import EnumList
from aesara.link.c.op import OpenMPOp
from aesara.scalar import int64 from aesara.scalar import int64
from aesara.tensor import blas_headers from aesara.tensor import blas_headers
from aesara.tensor.basic import as_tensor_variable from aesara.tensor.basic import as_tensor_variable
......
...@@ -5,9 +5,9 @@ import aesara.tensor as at ...@@ -5,9 +5,9 @@ import aesara.tensor as at
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.gradient import grad_undefined from aesara.gradient import grad_undefined
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import ExternalCOp, OpenMPOp
from aesara.graph.opt import local_optimizer from aesara.graph.opt import local_optimizer
from aesara.link.c.cmodule import GCC_compiler from aesara.link.c.cmodule import GCC_compiler
from aesara.link.c.op import ExternalCOp, OpenMPOp
from aesara.tensor.basic_opt import register_canonicalize from aesara.tensor.basic_opt import register_canonicalize
from aesara.tensor.blas import batched_dot from aesara.tensor.blas import batched_dot
from aesara.tensor.extra_ops import cpu_contiguous from aesara.tensor.extra_ops import cpu_contiguous
......
...@@ -7,8 +7,8 @@ import numpy as np ...@@ -7,8 +7,8 @@ import numpy as np
import aesara import aesara
from aesara.gradient import grad_not_implemented, grad_undefined from aesara.gradient import grad_not_implemented, grad_undefined
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import COp
from aesara.graph.type import EnumList from aesara.graph.type import EnumList
from aesara.link.c.op import COp
from aesara.tensor.basic import arange, as_tensor_variable, concatenate, stack, zeros from aesara.tensor.basic import arange, as_tensor_variable, concatenate, stack, zeros
from aesara.tensor.math import ceil_intdiv from aesara.tensor.math import ceil_intdiv
from aesara.tensor.subtensor import inc_subtensor, set_subtensor from aesara.tensor.subtensor import inc_subtensor, set_subtensor
......
...@@ -7,8 +7,8 @@ import numpy as np ...@@ -7,8 +7,8 @@ import numpy as np
import aesara import aesara
from aesara.gradient import DisconnectedType from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import COp
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.link.c.op import COp
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.scalar import int32 from aesara.scalar import int32
from aesara.tensor import _get_vector_length from aesara.tensor import _get_vector_length
......
...@@ -12,10 +12,10 @@ import aesara.tensor.basic as at ...@@ -12,10 +12,10 @@ import aesara.tensor.basic as at
import aesara.tensor.math as tm import aesara.tensor.math as tm
from aesara.gradient import DisconnectedType from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import OpenMPOp
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.graph.type import EnumList from aesara.graph.type import EnumList
from aesara.graph.utils import MethodNotDefined from aesara.graph.utils import MethodNotDefined
from aesara.link.c.op import OpenMPOp
from aesara.scalar import bool as bool_t from aesara.scalar import bool as bool_t
from aesara.tensor.type import TensorType, int_dtypes from aesara.tensor.type import TensorType, int_dtypes
......
...@@ -11,10 +11,11 @@ from aesara import scalar as aes ...@@ -11,10 +11,11 @@ from aesara import scalar as aes
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.gradient import DisconnectedType from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import COp, Op from aesara.graph.op import Op
from aesara.graph.params_type import ParamsType from aesara.graph.params_type import ParamsType
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.graph.utils import MethodNotDefined from aesara.graph.utils import MethodNotDefined
from aesara.link.c.op import COp
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.printing import Printer, pprint, set_precedence from aesara.printing import Printer, pprint, set_precedence
from aesara.scalar.basic import ScalarConstant from aesara.scalar.basic import ScalarConstant
......
...@@ -4,7 +4,8 @@ import aesara.tensor as at ...@@ -4,7 +4,8 @@ import aesara.tensor as at
from aesara.compile.debugmode import _lessbroken_deepcopy from aesara.compile.debugmode import _lessbroken_deepcopy
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import COp, Op from aesara.graph.op import Op
from aesara.link.c.op import COp
from aesara.tensor.type import scalar from aesara.tensor.type import scalar
from aesara.tensor.type_other import SliceType from aesara.tensor.type_other import SliceType
from aesara.tensor.var import TensorVariable from aesara.tensor.var import TensorVariable
......
...@@ -475,7 +475,7 @@ storage with the right shape and number of dimensions. ...@@ -475,7 +475,7 @@ storage with the right shape and number of dimensions.
import numpy import numpy
import aesara import aesara
from aesara.graph.op import COp from aesara.link.c.op import COp
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
...@@ -745,7 +745,7 @@ The new :class:`Op` is defined inside a Python file with the following code : ...@@ -745,7 +745,7 @@ The new :class:`Op` is defined inside a Python file with the following code :
.. testcode:: .. testcode::
import aesara import aesara
from aesara.graph.op import ExternalCOp from aesara.link.c.op import ExternalCOp
class VectorTimesVector(ExternalCOp): class VectorTimesVector(ExternalCOp):
__props__ = () __props__ = ()
......
...@@ -168,7 +168,7 @@ To allow consistent interface of Ops that support OpenMP, we have some ...@@ -168,7 +168,7 @@ To allow consistent interface of Ops that support OpenMP, we have some
helper code. Doing this also allows to enable/disable OpenMP globally helper code. Doing this also allows to enable/disable OpenMP globally
or per op for fine-grained control. or per op for fine-grained control.
Your Op needs to inherit from ``aesara.graph.op.OpenMPOp``. If it overrides Your Op needs to inherit from ``aesara.link.c.op.OpenMPOp``. If it overrides
the ``__init__()`` method, it must have an ``openmp=None`` parameter the ``__init__()`` method, it must have an ``openmp=None`` parameter
and must call ``super(MyOpClass, self).__init__(openmp=openmp)``. and must call ``super(MyOpClass, self).__init__(openmp=openmp)``.
......
...@@ -139,8 +139,8 @@ the params type. ...@@ -139,8 +139,8 @@ the params type.
.. testcode:: .. testcode::
from aesara.graph.op import COp
from aesara.graph.type import Generic from aesara.graph.type import Generic
from aesara.link.c.op import COp
from aesara.scalar import as_scalar from aesara.scalar import as_scalar
class MulOp(COp): class MulOp(COp):
......
...@@ -151,6 +151,10 @@ check_untyped_defs = False ...@@ -151,6 +151,10 @@ check_untyped_defs = False
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
[mypy-aesara.link.c.op]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.link.utils] [mypy-aesara.link.utils]
ignore_errors = True ignore_errors = True
check_untyped_defs = False check_untyped_defs = False
......
...@@ -17,9 +17,10 @@ from aesara.compile.mode import predefined_modes ...@@ -17,9 +17,10 @@ from aesara.compile.mode import predefined_modes
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable from aesara.graph.basic import Apply, Variable
from aesara.graph.features import BadOptimization from aesara.graph.features import BadOptimization
from aesara.graph.op import COp, Op from aesara.graph.op import Op
from aesara.graph.opt import local_optimizer from aesara.graph.opt import local_optimizer
from aesara.graph.optdb import EquilibriumDB from aesara.graph.optdb import EquilibriumDB
from aesara.link.c.op import COp
from aesara.tensor.math import add, dot, log from aesara.tensor.math import add, dot, log
from aesara.tensor.type import TensorType, dvector, fmatrix, fvector, vector from aesara.tensor.type import TensorType, dvector, fmatrix, fvector, vector
from tests import unittest_tools as utt from tests import unittest_tools as utt
......
...@@ -9,8 +9,9 @@ from aesara import scalar as aes ...@@ -9,8 +9,9 @@ from aesara import scalar as aes
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph import utils from aesara.graph import utils
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import COp, Op from aesara.graph.op import Op
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.link.c.op import COp
from aesara.tensor.math import _allclose, dot from aesara.tensor.math import _allclose, dot
from aesara.tensor.type import fmatrix, iscalar, matrix, vector from aesara.tensor.type import fmatrix, iscalar, matrix, vector
......
...@@ -4,13 +4,12 @@ import pytest ...@@ -4,13 +4,12 @@ import pytest
import aesara import aesara
import aesara.graph.op as op import aesara.graph.op as op
import aesara.tensor as at import aesara.tensor as at
from aesara import scalar as aes
from aesara import shared from aesara import shared
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable from aesara.graph.basic import Apply, Variable
from aesara.graph.op import COp, Op from aesara.graph.op import Op
from aesara.graph.type import Generic, Type from aesara.graph.type import Generic, Type
from aesara.graph.utils import MethodNotDefined, TestValueError from aesara.graph.utils import TestValueError
from aesara.tensor.math import log from aesara.tensor.math import log
from aesara.tensor.type import dmatrix, dscalar, dvector, vector from aesara.tensor.type import dmatrix, dscalar, dvector, vector
...@@ -82,38 +81,6 @@ class NoInputOp(Op): ...@@ -82,38 +81,6 @@ class NoInputOp(Op):
output_storage[0][0] = "test Op no input" output_storage[0][0] = "test Op no input"
class StructOp(COp):
__props__ = ()
def do_constant_folding(self, fgraph, node):
# we are not constant
return False
# The input only serves to distinguish thunks
def make_node(self, i):
return Apply(self, [i], [aes.uint64()])
def c_support_code_struct(self, node, name):
return f"npy_uint64 counter{name};"
def c_init_code_struct(self, node, name, sub):
return f"counter{name} = 0;"
def c_code(self, node, name, input_names, outputs_names, sub):
return """
%(out)s = counter%(name)s;
counter%(name)s++;
""" % dict(
out=outputs_names[0], name=name
)
def c_code_cache_version(self):
return (1,)
def perform(self, *args, **kwargs):
raise NotImplementedError("No Python implementation available.")
class TestOp: class TestOp:
# Sanity tests # Sanity tests
...@@ -141,106 +108,8 @@ class TestOp: ...@@ -141,106 +108,8 @@ class TestOp:
rval = f() rval = f()
assert rval == "test Op no input" assert rval == "test Op no input"
@pytest.mark.skipif(
not config.cxx, reason="G++ not available, so we need to skip this test."
)
def test_op_struct(self):
sop = StructOp()
c = sop(aesara.tensor.constant(0))
mode = None
if config.mode == "FAST_COMPILE":
mode = "FAST_RUN"
f = aesara.function([], c, mode=mode)
rval = f()
assert rval == 0
rval = f()
assert rval == 1
c2 = sop(aesara.tensor.constant(1))
f2 = aesara.function([], [c, c2], mode=mode)
rval = f2()
assert rval == [0, 0]
class TestMakeThunk: class TestMakeThunk:
def test_no_c_code(self):
class IncOnePython(COp):
"""An Op with only a Python (perform) implementation"""
__props__ = ()
def make_node(self, input):
input = aes.as_scalar(input)
output = input.type()
return Apply(self, [input], [output])
def perform(self, node, inputs, outputs):
(input,) = inputs
(output,) = outputs
output[0] = input + 1
i = aes.int32("i")
o = IncOnePython()(i)
# Check that the c_code function is not implemented
with pytest.raises(NotImplementedError):
o.owner.op.c_code(o.owner, "o", ["x"], "z", {"fail": ""})
storage_map = {i: [np.int32(3)], o: [None]}
compute_map = {i: [True], o: [False]}
thunk = o.owner.op.make_thunk(
o.owner, storage_map, compute_map, no_recycling=[]
)
required = thunk()
# Check everything went OK
assert not required # We provided all inputs
assert compute_map[o][0]
assert storage_map[o][0] == 4
def test_no_perform(self):
class IncOneC(COp):
"""An Op with only a C (c_code) implementation"""
__props__ = ()
def make_node(self, input):
input = aes.as_scalar(input)
output = input.type()
return Apply(self, [input], [output])
def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs
(z,) = outputs
return f"{z} = {x} + 1;"
def perform(self, *args, **kwargs):
raise NotImplementedError("No Python implementation available.")
i = aes.int32("i")
o = IncOneC()(i)
# Check that the perform function is not implemented
with pytest.raises((NotImplementedError, MethodNotDefined)):
o.owner.op.perform(o.owner, 0, [None])
storage_map = {i: [np.int32(3)], o: [None]}
compute_map = {i: [True], o: [False]}
thunk = o.owner.op.make_thunk(
o.owner, storage_map, compute_map, no_recycling=[]
)
if config.cxx:
required = thunk()
# Check everything went OK
assert not required # We provided all inputs
assert compute_map[o][0]
assert storage_map[o][0] == 4
else:
with pytest.raises((NotImplementedError, MethodNotDefined)):
thunk()
def test_no_make_node(self): def test_no_make_node(self):
class DoubleOp(Op): class DoubleOp(Op):
"""An Op without make_node""" """An Op without make_node"""
......
...@@ -4,9 +4,9 @@ import pytest ...@@ -4,9 +4,9 @@ import pytest
import aesara import aesara
from aesara import tensor as at from aesara import tensor as at
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import COp, ExternalCOp
from aesara.graph.params_type import Params, ParamsType from aesara.graph.params_type import Params, ParamsType
from aesara.graph.type import EnumList, Generic from aesara.graph.type import EnumList, Generic
from aesara.link.c.op import COp, ExternalCOp
from aesara.scalar import Scalar from aesara.scalar import Scalar
from aesara.tensor.type import TensorType, matrix from aesara.tensor.type import TensorType, matrix
from tests import unittest_tools as utt from tests import unittest_tools as utt
......
...@@ -6,8 +6,8 @@ import pytest ...@@ -6,8 +6,8 @@ import pytest
import aesara import aesara
from aesara import scalar as aes from aesara import scalar as aes
from aesara.graph.basic import Apply, Variable from aesara.graph.basic import Apply, Variable
from aesara.graph.op import COp
from aesara.graph.type import CDataType, CEnumType, EnumList, EnumType, Type from aesara.graph.type import CDataType, CEnumType, EnumList, EnumType, Type
from aesara.link.c.op import COp
from aesara.tensor.type import TensorType, continuous_dtypes from aesara.tensor.type import TensorType, continuous_dtypes
......
...@@ -7,10 +7,10 @@ from aesara.compile.mode import Mode ...@@ -7,10 +7,10 @@ from aesara.compile.mode import Mode
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import COp
from aesara.graph.type import CType from aesara.graph.type import CType
from aesara.link.basic import PerformLinker from aesara.link.basic import PerformLinker
from aesara.link.c.basic import CLinker, DualLinker, OpWiseCLinker from aesara.link.c.basic import CLinker, DualLinker, OpWiseCLinker
from aesara.link.c.op import COp
from aesara.tensor.type import iscalar, matrix, vector from aesara.tensor.type import iscalar, matrix, vector
from tests.link.test_link import make_function from tests.link.test_link import make_function
......
import numpy as np
import pytest
import aesara
from aesara import scalar as aes
from aesara.configdefaults import config
from aesara.graph.basic import Apply
from aesara.graph.utils import MethodNotDefined
from aesara.link.c.op import COp
class StructOp(COp):
__props__ = ()
def do_constant_folding(self, fgraph, node):
# we are not constant
return False
# The input only serves to distinguish thunks
def make_node(self, i):
return Apply(self, [i], [aes.uint64()])
def c_support_code_struct(self, node, name):
return f"npy_uint64 counter{name};"
def c_init_code_struct(self, node, name, sub):
return f"counter{name} = 0;"
def c_code(self, node, name, input_names, outputs_names, sub):
return """
%(out)s = counter%(name)s;
counter%(name)s++;
""" % dict(
out=outputs_names[0], name=name
)
def c_code_cache_version(self):
return (1,)
def perform(self, *args, **kwargs):
raise NotImplementedError("No Python implementation available.")
class TestCOp:
@pytest.mark.skipif(
not config.cxx, reason="G++ not available, so we need to skip this test."
)
def test_op_struct(self):
sop = StructOp()
c = sop(aesara.tensor.constant(0))
mode = None
if config.mode == "FAST_COMPILE":
mode = "FAST_RUN"
f = aesara.function([], c, mode=mode)
rval = f()
assert rval == 0
rval = f()
assert rval == 1
c2 = sop(aesara.tensor.constant(1))
f2 = aesara.function([], [c, c2], mode=mode)
rval = f2()
assert rval == [0, 0]
class TestMakeThunk:
def test_no_c_code(self):
class IncOnePython(COp):
"""An Op with only a Python (perform) implementation"""
__props__ = ()
def make_node(self, input):
input = aes.as_scalar(input)
output = input.type()
return Apply(self, [input], [output])
def perform(self, node, inputs, outputs):
(input,) = inputs
(output,) = outputs
output[0] = input + 1
i = aes.int32("i")
o = IncOnePython()(i)
# Check that the c_code function is not implemented
with pytest.raises(NotImplementedError):
o.owner.op.c_code(o.owner, "o", ["x"], "z", {"fail": ""})
storage_map = {i: [np.int32(3)], o: [None]}
compute_map = {i: [True], o: [False]}
thunk = o.owner.op.make_thunk(
o.owner, storage_map, compute_map, no_recycling=[]
)
required = thunk()
# Check everything went OK
assert not required # We provided all inputs
assert compute_map[o][0]
assert storage_map[o][0] == 4
def test_no_perform(self):
class IncOneC(COp):
"""An Op with only a C (c_code) implementation"""
__props__ = ()
def make_node(self, input):
input = aes.as_scalar(input)
output = input.type()
return Apply(self, [input], [output])
def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs
(z,) = outputs
return f"{z} = {x} + 1;"
def perform(self, *args, **kwargs):
raise NotImplementedError("No Python implementation available.")
i = aes.int32("i")
o = IncOneC()(i)
# Check that the perform function is not implemented
with pytest.raises((NotImplementedError, MethodNotDefined)):
o.owner.op.perform(o.owner, 0, [None])
storage_map = {i: [np.int32(3)], o: [None]}
compute_map = {i: [True], o: [False]}
thunk = o.owner.op.make_thunk(
o.owner, storage_map, compute_map, no_recycling=[]
)
if config.cxx:
required = thunk()
# Check everything went OK
assert not required # We provided all inputs
assert compute_map[o][0]
assert storage_map[o][0] == 4
else:
with pytest.raises((NotImplementedError, MethodNotDefined)):
thunk()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论