提交 8a4505c0 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move C-specific content from aesara.graph.op to aesara.link.c.op

上级 4eea9f79
......@@ -32,9 +32,10 @@ from aesara.configdefaults import config
from aesara.graph.basic import Variable, io_toposort
from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import BadOptimization
from aesara.graph.op import COp, HasInnerGraph, Op
from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.utils import InconsistencyError, MethodNotDefined
from aesara.link.basic import Container, LocalLinker
from aesara.link.c.op import COp
from aesara.link.utils import map_storage, raise_with_op
from aesara.printing import _debugprint
from aesara.utils import NoDuplicateOptWarningFilter, difference, get_unbound_function
......
......@@ -11,8 +11,9 @@ import warnings
from typing import Dict, Tuple
from aesara.graph.basic import Apply
from aesara.graph.op import COp, Op
from aesara.graph.op import Op
from aesara.graph.type import CType
from aesara.link.c.op import COp
def register_view_op_c_code(type, code, version=()):
......
......@@ -11,12 +11,13 @@ import aesara.tensor as at
from aesara.configdefaults import config
from aesara.gradient import grad_undefined
from aesara.graph.basic import Apply, Variable
from aesara.graph.op import COp, ExternalCOp, Op, _NoPythonOp
from aesara.graph.op import Op, _NoPythonOp
from aesara.graph.opt import copy_stack_trace
from aesara.graph.params_type import ParamsType
from aesara.graph.type import CType
from aesara.graph.utils import MethodNotDefined
from aesara.link.c.interface import HideC
from aesara.link.c.op import COp, ExternalCOp
from aesara.scalar import bool as bool_t
from aesara.scalar import int32 as int32_t
from aesara.tensor.basic import Alloc, AllocEmpty, Join, Split, infer_broadcastable
......
......@@ -10,9 +10,9 @@ from aesara.gpuarray.basic_ops import (
)
from aesara.gpuarray.opt_util import inplace_allocempty
from aesara.graph.basic import Apply
from aesara.graph.op import _NoPythonCOp
from aesara.graph.opt import LocalOptGroup, in2out
from aesara.graph.params_type import ParamsType
from aesara.link.c.op import _NoPythonCOp
from aesara.scalar import bool as bool_t
from aesara.tensor.basic import as_tensor_variable
......
......@@ -11,8 +11,8 @@ from aesara.gpuarray.basic_ops import (
from aesara.gpuarray.type import gpu_context_type
from aesara.gradient import grad_undefined
from aesara.graph.basic import Apply
from aesara.graph.op import _NoPythonExternalCOp
from aesara.graph.params_type import ParamsType
from aesara.link.c.op import _NoPythonExternalCOp
from aesara.scalar import bool as bool_t
from aesara.tensor import as_tensor_variable
from aesara.tensor.type import discrete_dtypes
......
......@@ -13,8 +13,8 @@ from aesara.gpuarray.elemwise import GpuDimShuffle
from aesara.gpuarray.type import GpuArrayType, gpu_context_type
from aesara.gradient import grad_undefined
from aesara.graph.basic import Apply
from aesara.graph.op import _NoPythonExternalCOp
from aesara.graph.opt import local_optimizer
from aesara.link.c.op import _NoPythonExternalCOp
from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.basic_opt import register_canonicalize
from aesara.tensor.blas import batched_dot
......
......@@ -28,10 +28,10 @@ from aesara.gpuarray.basic_ops import (
from aesara.gpuarray.type import GpuArraySharedVariable, get_context, gpu_context_type
from aesara.gradient import DisconnectedType, grad_not_implemented
from aesara.graph.basic import Apply, Variable
from aesara.graph.op import ExternalCOp, _NoPythonCOp, _NoPythonExternalCOp
from aesara.graph.params_type import ParamsType
from aesara.graph.type import CDataType, EnumList, Generic
from aesara.link.c.cmodule import GCC_compiler
from aesara.link.c.op import ExternalCOp, _NoPythonCOp, _NoPythonExternalCOp
from aesara.raise_op import Assert
from aesara.scalar import as_scalar
from aesara.scalar import bool as bool_t
......
......@@ -14,8 +14,9 @@ from aesara.gpuarray.basic_ops import (
)
from aesara.gpuarray.type import GpuArrayType, gpu_context_type
from aesara.graph.basic import Apply
from aesara.graph.op import ExternalCOp, Op
from aesara.graph.op import Op
from aesara.graph.params_type import ParamsType
from aesara.link.c.op import ExternalCOp
from aesara.scalar import bool as bool_t
from aesara.tensor import basic as at
from aesara.tensor import math as tm
......
from aesara.graph.basic import Apply
from aesara.graph.op import COp
from aesara.graph.type import Generic
from aesara.link.c.op import COp
from .basic_ops import as_gpuarray_variable, gpuarray_helper_inc_dir, infer_context_name
from .type import GpuArrayType
......
......@@ -5,10 +5,11 @@ import numpy as np
import aesara.tensor as at
from aesara.gradient import grad_not_implemented
from aesara.graph.basic import Apply
from aesara.graph.op import COp, Op
from aesara.graph.op import Op
from aesara.graph.params_type import ParamsType
from aesara.graph.type import CType
from aesara.link.c.interface import HideC
from aesara.link.c.op import COp
from aesara.scalar import bool as bool_t
from aesara.scalar import int32 as int_t
from aesara.scalar import uint32 as size_t
......
差异被折叠。
差异被折叠。
......@@ -7,9 +7,9 @@ import numpy as np
from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply, Variable
from aesara.graph.op import COp
from aesara.graph.params_type import ParamsType
from aesara.graph.type import Generic
from aesara.link.c.op import COp
class ExceptionType(Generic):
......
......@@ -7,7 +7,7 @@ import numpy as np
import aesara.tensor as at
from aesara.configdefaults import config
from aesara.graph.basic import Apply
from aesara.graph.op import COp
from aesara.link.c.op import COp
from aesara.scalar import Scalar, as_scalar
from aesara.tensor.type import discrete_dtypes
......
......@@ -25,9 +25,9 @@ from aesara.compile import optdb
from aesara.configdefaults import config
from aesara.gradient import undefined_grad
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import COp, Op
from aesara.graph.opt import in2out, local_optimizer
from aesara.graph.params_type import ParamsType
from aesara.link.c.op import COp, Op
from aesara.sandbox import multinomial
from aesara.scalar import bool as bool_t
from aesara.scalar import int32 as int_t
......
......@@ -27,10 +27,10 @@ from aesara.configdefaults import config
from aesara.gradient import DisconnectedType, grad_undefined
from aesara.graph.basic import Apply, Constant, Variable, clone, list_of_nodes
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import COp
from aesara.graph.opt import MergeOptimizer
from aesara.graph.type import CType
from aesara.graph.utils import MetaObject, MethodNotDefined
from aesara.link.c.op import COp
from aesara.misc.safe_asarray import _asarray
from aesara.printing import pprint
from aesara.utils import (
......
......@@ -18,7 +18,8 @@ from aesara import scalar as aes
from aesara.configdefaults import config
from aesara.gradient import DisconnectedType, grad_not_implemented, grad_undefined
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import COp, Op
from aesara.graph.op import Op
from aesara.link.c.op import COp
from aesara.misc.safe_asarray import _asarray
from aesara.sparse.type import SparseType, _is_sparse
from aesara.sparse.utils import hash_from_sparse
......
......@@ -5,8 +5,8 @@ import aesara
import aesara.scalar as aes
from aesara.configdefaults import config
from aesara.graph.basic import Apply
from aesara.graph.op import COp, _NoPythonCOp
from aesara.graph.opt import PatternSub, TopoOptimizer, local_optimizer
from aesara.link.c.op import COp, _NoPythonCOp
from aesara.misc.safe_asarray import _asarray
from aesara.sparse import basic as sparse
from aesara.sparse.basic import (
......
......@@ -23,10 +23,11 @@ from aesara import scalar as aes
from aesara.gradient import DisconnectedType, grad_not_implemented, grad_undefined
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import COp, Op
from aesara.graph.op import Op
from aesara.graph.opt_utils import optimize_graph
from aesara.graph.params_type import ParamsType
from aesara.graph.type import Type
from aesara.link.c.op import COp
from aesara.misc.safe_asarray import _asarray
from aesara.printing import min_informative_str, pprint
from aesara.raise_op import CheckAndRaise, assert_op
......
......@@ -147,7 +147,7 @@ from aesara.compile.mode import optdb
from aesara.configdefaults import config
from aesara.graph.basic import Apply, view_roots
from aesara.graph.features import ReplacementDidNotRemoveError, ReplaceValidate
from aesara.graph.op import COp, Op
from aesara.graph.op import Op
from aesara.graph.opt import (
EquilibriumOptimizer,
GlobalOptimizer,
......@@ -158,6 +158,7 @@ from aesara.graph.opt import (
from aesara.graph.optdb import SequenceDB
from aesara.graph.params_type import ParamsType
from aesara.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
from aesara.link.c.op import COp
from aesara.printing import FunctionPrinter, debugprint, pprint
from aesara.scalar import bool as bool_t
from aesara.tensor import basic as at
......
from aesara.configdefaults import config
from aesara.graph.op import COp
from aesara.graph.opt import in2out
from aesara.graph.params_type import ParamsType
from aesara.link.c.op import COp
from aesara.scalar import bool as bool_t
from aesara.tensor import basic as at
from aesara.tensor.blas import (
......
......@@ -8,10 +8,10 @@ from aesara.configdefaults import config
from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply
from aesara.graph.null_type import NullType
from aesara.graph.op import COp, ExternalCOp, OpenMPOp
from aesara.graph.params_type import ParamsType
from aesara.graph.utils import MethodNotDefined
from aesara.link.c.basic import failure_code
from aesara.link.c.op import COp, ExternalCOp, OpenMPOp
from aesara.misc.frozendict import frozendict
from aesara.misc.safe_asarray import _asarray
from aesara.printing import FunctionPrinter, Printer, pprint
......
......@@ -11,9 +11,10 @@ from aesara.gradient import (
grad_undefined,
)
from aesara.graph.basic import Apply, Variable, equal_computations
from aesara.graph.op import COp, Op
from aesara.graph.op import Op
from aesara.graph.params_type import ParamsType
from aesara.graph.type import EnumList, Generic
from aesara.link.c.op import COp
from aesara.misc.safe_asarray import _asarray
from aesara.raise_op import Assert
from aesara.scalar import int32 as int_t
......
......@@ -7,9 +7,10 @@ from aesara import config, printing
from aesara import scalar as aes
from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply, Variable
from aesara.graph.op import COp, Op
from aesara.graph.op import Op
from aesara.graph.params_type import ParamsType
from aesara.graph.type import Generic
from aesara.link.c.op import COp
from aesara.misc.safe_asarray import _asarray
from aesara.printing import pprint
from aesara.scalar.basic import BinaryScalarOp
......
......@@ -24,8 +24,9 @@ from aesara import scalar as aes
from aesara.compile import optdb
from aesara.gradient import DisconnectedType, grad_not_implemented
from aesara.graph.basic import Apply
from aesara.graph.op import COp, Op
from aesara.graph.op import Op
from aesara.graph.opt import copy_stack_trace, local_optimizer, optimizer
from aesara.link.c.op import COp
from aesara.raise_op import Assert
from aesara.scalar import UnaryScalarOp
from aesara.tensor import basic as at
......
......@@ -24,7 +24,7 @@ except ImportError:
import aesara
from aesara.graph.basic import Apply
from aesara.graph.op import OpenMPOp
from aesara.link.c.op import OpenMPOp
from aesara.tensor import blas
from aesara.tensor.basic import (
as_tensor_variable,
......
......@@ -5,9 +5,10 @@ from typing import Optional
import aesara
from aesara.configdefaults import config
from aesara.graph.basic import Apply
from aesara.graph.op import OpenMPOp, _NoPythonOp
from aesara.graph.op import _NoPythonOp
from aesara.graph.params_type import ParamsType
from aesara.graph.type import EnumList
from aesara.link.c.op import OpenMPOp
from aesara.scalar import int8, int64
from aesara.tensor import blas_headers
from aesara.tensor.basic import as_tensor_variable
......
......@@ -5,9 +5,10 @@ from typing import Optional
import aesara
from aesara.configdefaults import config
from aesara.graph.basic import Apply
from aesara.graph.op import OpenMPOp, _NoPythonOp
from aesara.graph.op import _NoPythonOp
from aesara.graph.params_type import ParamsType
from aesara.graph.type import EnumList
from aesara.link.c.op import OpenMPOp
from aesara.scalar import int64
from aesara.tensor import blas_headers
from aesara.tensor.basic import as_tensor_variable
......
......@@ -5,9 +5,9 @@ import aesara.tensor as at
from aesara.configdefaults import config
from aesara.gradient import grad_undefined
from aesara.graph.basic import Apply
from aesara.graph.op import ExternalCOp, OpenMPOp
from aesara.graph.opt import local_optimizer
from aesara.link.c.cmodule import GCC_compiler
from aesara.link.c.op import ExternalCOp, OpenMPOp
from aesara.tensor.basic_opt import register_canonicalize
from aesara.tensor.blas import batched_dot
from aesara.tensor.extra_ops import cpu_contiguous
......
......@@ -7,8 +7,8 @@ import numpy as np
import aesara
from aesara.gradient import grad_not_implemented, grad_undefined
from aesara.graph.basic import Apply
from aesara.graph.op import COp
from aesara.graph.type import EnumList
from aesara.link.c.op import COp
from aesara.tensor.basic import arange, as_tensor_variable, concatenate, stack, zeros
from aesara.tensor.math import ceil_intdiv
from aesara.tensor.subtensor import inc_subtensor, set_subtensor
......
......@@ -7,8 +7,8 @@ import numpy as np
import aesara
from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import COp
from aesara.graph.params_type import ParamsType
from aesara.link.c.op import COp
from aesara.misc.safe_asarray import _asarray
from aesara.scalar import int32
from aesara.tensor import _get_vector_length
......
......@@ -12,10 +12,10 @@ import aesara.tensor.basic as at
import aesara.tensor.math as tm
from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import OpenMPOp
from aesara.graph.params_type import ParamsType
from aesara.graph.type import EnumList
from aesara.graph.utils import MethodNotDefined
from aesara.link.c.op import OpenMPOp
from aesara.scalar import bool as bool_t
from aesara.tensor.type import TensorType, int_dtypes
......
......@@ -11,10 +11,11 @@ from aesara import scalar as aes
from aesara.configdefaults import config
from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import COp, Op
from aesara.graph.op import Op
from aesara.graph.params_type import ParamsType
from aesara.graph.type import Type
from aesara.graph.utils import MethodNotDefined
from aesara.link.c.op import COp
from aesara.misc.safe_asarray import _asarray
from aesara.printing import Printer, pprint, set_precedence
from aesara.scalar.basic import ScalarConstant
......
......@@ -4,7 +4,8 @@ import aesara.tensor as at
from aesara.compile.debugmode import _lessbroken_deepcopy
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import COp, Op
from aesara.graph.op import Op
from aesara.link.c.op import COp
from aesara.tensor.type import scalar
from aesara.tensor.type_other import SliceType
from aesara.tensor.var import TensorVariable
......
......@@ -475,7 +475,7 @@ storage with the right shape and number of dimensions.
import numpy
import aesara
from aesara.graph.op import COp
from aesara.link.c.op import COp
from aesara.graph.basic import Apply
......@@ -745,7 +745,7 @@ The new :class:`Op` is defined inside a Python file with the following code :
.. testcode::
import aesara
from aesara.graph.op import ExternalCOp
from aesara.link.c.op import ExternalCOp
class VectorTimesVector(ExternalCOp):
__props__ = ()
......
......@@ -168,7 +168,7 @@ To allow consistent interface of Ops that support OpenMP, we have some
helper code. Doing this also allows to enable/disable OpenMP globally
or per op for fine-grained control.
Your Op needs to inherit from ``aesara.graph.op.OpenMPOp``. If it overrides
Your Op needs to inherit from ``aesara.link.c.op.OpenMPOp``. If it overrides
the ``__init__()`` method, it must have an ``openmp=None`` parameter
and must call ``super(MyOpClass, self).__init__(openmp=openmp)``.
......
......@@ -139,8 +139,8 @@ the params type.
.. testcode::
from aesara.graph.op import COp
from aesara.graph.type import Generic
from aesara.link.c.op import COp
from aesara.scalar import as_scalar
class MulOp(COp):
......
......@@ -151,6 +151,10 @@ check_untyped_defs = False
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.link.c.op]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.link.utils]
ignore_errors = True
check_untyped_defs = False
......
......@@ -17,9 +17,10 @@ from aesara.compile.mode import predefined_modes
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable
from aesara.graph.features import BadOptimization
from aesara.graph.op import COp, Op
from aesara.graph.op import Op
from aesara.graph.opt import local_optimizer
from aesara.graph.optdb import EquilibriumDB
from aesara.link.c.op import COp
from aesara.tensor.math import add, dot, log
from aesara.tensor.type import TensorType, dvector, fmatrix, fvector, vector
from tests import unittest_tools as utt
......
......@@ -9,8 +9,9 @@ from aesara import scalar as aes
from aesara.configdefaults import config
from aesara.graph import utils
from aesara.graph.basic import Apply
from aesara.graph.op import COp, Op
from aesara.graph.op import Op
from aesara.graph.type import Type
from aesara.link.c.op import COp
from aesara.tensor.math import _allclose, dot
from aesara.tensor.type import fmatrix, iscalar, matrix, vector
......
......@@ -4,13 +4,12 @@ import pytest
import aesara
import aesara.graph.op as op
import aesara.tensor as at
from aesara import scalar as aes
from aesara import shared
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable
from aesara.graph.op import COp, Op
from aesara.graph.op import Op
from aesara.graph.type import Generic, Type
from aesara.graph.utils import MethodNotDefined, TestValueError
from aesara.graph.utils import TestValueError
from aesara.tensor.math import log
from aesara.tensor.type import dmatrix, dscalar, dvector, vector
......@@ -82,38 +81,6 @@ class NoInputOp(Op):
output_storage[0][0] = "test Op no input"
class StructOp(COp):
    """A `COp` whose C struct holds a counter that persists across calls."""

    __props__ = ()

    def do_constant_folding(self, fgraph, node):
        # Never constant-fold: every evaluation mutates the struct counter.
        return False

    def make_node(self, i):
        # The input only serves to distinguish thunks
        return Apply(self, [i], [aes.uint64()])

    def c_support_code_struct(self, node, name):
        return f"npy_uint64 counter{name};"

    def c_init_code_struct(self, node, name, sub):
        return f"counter{name} = 0;"

    def c_code(self, node, name, input_names, outputs_names, sub):
        out = outputs_names[0]
        return f"""
{out} = counter{name};
counter{name}++;
"""

    def c_code_cache_version(self):
        return (1,)

    def perform(self, *args, **kwargs):
        raise NotImplementedError("No Python implementation available.")
class TestOp:
# Sanity tests
......@@ -141,106 +108,8 @@ class TestOp:
rval = f()
assert rval == "test Op no input"
@pytest.mark.skipif(
    not config.cxx, reason="G++ not available, so we need to skip this test."
)
def test_op_struct(self):
    # The struct counter must persist across calls of one thunk and be
    # independent between distinct thunks.
    sop = StructOp()
    out = sop(aesara.tensor.constant(0))
    mode = "FAST_RUN" if config.mode == "FAST_COMPILE" else None
    fn = aesara.function([], out, mode=mode)
    assert fn() == 0
    assert fn() == 1
    out2 = sop(aesara.tensor.constant(1))
    fn2 = aesara.function([], [out, out2], mode=mode)
    assert fn2() == [0, 0]
class TestMakeThunk:
def test_no_c_code(self):
    # A perform-only Op must still produce a working (Python) thunk,
    # while its `c_code` stays unimplemented.
    class IncOnePython(COp):
        """An Op with only a Python (perform) implementation"""

        __props__ = ()

        def make_node(self, input):
            input = aes.as_scalar(input)
            return Apply(self, [input], [input.type()])

        def perform(self, node, inputs, outputs):
            (inp,) = inputs
            (out,) = outputs
            out[0] = inp + 1

    i = aes.int32("i")
    o = IncOnePython()(i)

    # Check that the c_code function is not implemented
    with pytest.raises(NotImplementedError):
        o.owner.op.c_code(o.owner, "o", ["x"], "z", {"fail": ""})

    storage_map = {i: [np.int32(3)], o: [None]}
    compute_map = {i: [True], o: [False]}
    thunk = o.owner.op.make_thunk(
        o.owner, storage_map, compute_map, no_recycling=[]
    )

    required = thunk()
    assert not required  # We provided all inputs
    assert compute_map[o][0]
    assert storage_map[o][0] == 4
def test_no_perform(self):
    # A C-only Op must compile and run when a C compiler is available,
    # while its Python `perform` stays unimplemented.
    class IncOneC(COp):
        """An Op with only a C (c_code) implementation"""

        __props__ = ()

        def make_node(self, input):
            input = aes.as_scalar(input)
            return Apply(self, [input], [input.type()])

        def c_code(self, node, name, inputs, outputs, sub):
            (x,) = inputs
            (z,) = outputs
            return f"{z} = {x} + 1;"

        def perform(self, *args, **kwargs):
            raise NotImplementedError("No Python implementation available.")

    i = aes.int32("i")
    o = IncOneC()(i)

    # Check that the perform function is not implemented
    with pytest.raises((NotImplementedError, MethodNotDefined)):
        o.owner.op.perform(o.owner, 0, [None])

    storage_map = {i: [np.int32(3)], o: [None]}
    compute_map = {i: [True], o: [False]}
    thunk = o.owner.op.make_thunk(
        o.owner, storage_map, compute_map, no_recycling=[]
    )

    if config.cxx:
        required = thunk()
        assert not required  # We provided all inputs
        assert compute_map[o][0]
        assert storage_map[o][0] == 4
    else:
        # Without a C compiler the thunk cannot fall back to Python.
        with pytest.raises((NotImplementedError, MethodNotDefined)):
            thunk()
def test_no_make_node(self):
class DoubleOp(Op):
"""An Op without make_node"""
......
......@@ -4,9 +4,9 @@ import pytest
import aesara
from aesara import tensor as at
from aesara.graph.basic import Apply
from aesara.graph.op import COp, ExternalCOp
from aesara.graph.params_type import Params, ParamsType
from aesara.graph.type import EnumList, Generic
from aesara.link.c.op import COp, ExternalCOp
from aesara.scalar import Scalar
from aesara.tensor.type import TensorType, matrix
from tests import unittest_tools as utt
......
......@@ -6,8 +6,8 @@ import pytest
import aesara
from aesara import scalar as aes
from aesara.graph.basic import Apply, Variable
from aesara.graph.op import COp
from aesara.graph.type import CDataType, CEnumType, EnumList, EnumType, Type
from aesara.link.c.op import COp
from aesara.tensor.type import TensorType, continuous_dtypes
......
......@@ -7,10 +7,10 @@ from aesara.compile.mode import Mode
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import COp
from aesara.graph.type import CType
from aesara.link.basic import PerformLinker
from aesara.link.c.basic import CLinker, DualLinker, OpWiseCLinker
from aesara.link.c.op import COp
from aesara.tensor.type import iscalar, matrix, vector
from tests.link.test_link import make_function
......
import numpy as np
import pytest
import aesara
from aesara import scalar as aes
from aesara.configdefaults import config
from aesara.graph.basic import Apply
from aesara.graph.utils import MethodNotDefined
from aesara.link.c.op import COp
class StructOp(COp):
    """A `COp` whose generated C struct carries a call counter.

    The counter lives in the thunk's C struct, so it is shared across
    calls of the same compiled thunk but not between thunks.
    """

    __props__ = ()

    def do_constant_folding(self, fgraph, node):
        # we are not constant
        return False

    def make_node(self, i):
        # The input only serves to distinguish thunks
        return Apply(self, [i], [aes.uint64()])

    def c_support_code_struct(self, node, name):
        return f"npy_uint64 counter{name};"

    def c_init_code_struct(self, node, name, sub):
        return f"counter{name} = 0;"

    def c_code(self, node, name, input_names, outputs_names, sub):
        out = outputs_names[0]
        return f"""
{out} = counter{name};
counter{name}++;
"""

    def c_code_cache_version(self):
        return (1,)

    def perform(self, *args, **kwargs):
        raise NotImplementedError("No Python implementation available.")
class TestCOp:
    @pytest.mark.skipif(
        not config.cxx, reason="G++ not available, so we need to skip this test."
    )
    def test_op_struct(self):
        # The struct counter must persist across calls of one thunk and
        # start fresh for a separately compiled thunk.
        sop = StructOp()
        out = sop(aesara.tensor.constant(0))
        mode = "FAST_RUN" if config.mode == "FAST_COMPILE" else None
        fn = aesara.function([], out, mode=mode)
        assert fn() == 0
        assert fn() == 1
        out2 = sop(aesara.tensor.constant(1))
        fn2 = aesara.function([], [out, out2], mode=mode)
        assert fn2() == [0, 0]
class TestMakeThunk:
    """Checks for `COp.make_thunk` with partial (Python-only / C-only) Ops."""

    def test_no_c_code(self):
        # A perform-only Op must still yield a working (Python) thunk,
        # while its `c_code` stays unimplemented.
        class IncOnePython(COp):
            """An Op with only a Python (perform) implementation"""

            __props__ = ()

            def make_node(self, input):
                input = aes.as_scalar(input)
                return Apply(self, [input], [input.type()])

            def perform(self, node, inputs, outputs):
                (inp,) = inputs
                (out,) = outputs
                out[0] = inp + 1

        i = aes.int32("i")
        o = IncOnePython()(i)

        # Check that the c_code function is not implemented
        with pytest.raises(NotImplementedError):
            o.owner.op.c_code(o.owner, "o", ["x"], "z", {"fail": ""})

        storage_map = {i: [np.int32(3)], o: [None]}
        compute_map = {i: [True], o: [False]}
        thunk = o.owner.op.make_thunk(
            o.owner, storage_map, compute_map, no_recycling=[]
        )

        required = thunk()
        assert not required  # We provided all inputs
        assert compute_map[o][0]
        assert storage_map[o][0] == 4

    def test_no_perform(self):
        # A C-only Op must compile and run when a C compiler is available,
        # while its Python `perform` stays unimplemented.
        class IncOneC(COp):
            """An Op with only a C (c_code) implementation"""

            __props__ = ()

            def make_node(self, input):
                input = aes.as_scalar(input)
                return Apply(self, [input], [input.type()])

            def c_code(self, node, name, inputs, outputs, sub):
                (x,) = inputs
                (z,) = outputs
                return f"{z} = {x} + 1;"

            def perform(self, *args, **kwargs):
                raise NotImplementedError("No Python implementation available.")

        i = aes.int32("i")
        o = IncOneC()(i)

        # Check that the perform function is not implemented
        with pytest.raises((NotImplementedError, MethodNotDefined)):
            o.owner.op.perform(o.owner, 0, [None])

        storage_map = {i: [np.int32(3)], o: [None]}
        compute_map = {i: [True], o: [False]}
        thunk = o.owner.op.make_thunk(
            o.owner, storage_map, compute_map, no_recycling=[]
        )

        if config.cxx:
            required = thunk()
            assert not required  # We provided all inputs
            assert compute_map[o][0]
            assert storage_map[o][0] == 4
        else:
            # Without a C compiler the thunk cannot fall back to Python.
            with pytest.raises((NotImplementedError, MethodNotDefined)):
                thunk()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论