提交 6d3c7568 authored 作者: Purna Chandra Mansingh's avatar Purna Chandra Mansingh 提交者: Ricardo Vieira

Rename module pytensor.tensor.var to pytensor.tensor.variable

上级 288a3f34
...@@ -16,7 +16,7 @@ repos: ...@@ -16,7 +16,7 @@ repos:
pytensor/graph/op\.py| pytensor/graph/op\.py|
pytensor/compile/nanguardmode\.py| pytensor/compile/nanguardmode\.py|
pytensor/graph/rewriting/basic\.py| pytensor/graph/rewriting/basic\.py|
pytensor/tensor/var\.py| pytensor/tensor/variable\.py|
)$ )$
- id: check-merge-conflict - id: check-merge-conflict
- repo: https://github.com/asottile/pyupgrade - repo: https://github.com/asottile/pyupgrade
......
...@@ -30,7 +30,7 @@ if TYPE_CHECKING: ...@@ -30,7 +30,7 @@ if TYPE_CHECKING:
OutputStorageType, OutputStorageType,
StorageMapType, StorageMapType,
) )
from pytensor.tensor.var import TensorVariable from pytensor.tensor.variable import TensorVariable
ThunkAndContainersType = Tuple["BasicThunkType", List["Container"], List["Container"]] ThunkAndContainersType = Tuple["BasicThunkType", List["Container"], List["Container"]]
......
...@@ -33,7 +33,7 @@ slice length. ...@@ -33,7 +33,7 @@ slice length.
def subtensor_assert_indices_jax_compatible(node, idx_list): def subtensor_assert_indices_jax_compatible(node, idx_list):
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
from pytensor.tensor.var import TensorVariable from pytensor.tensor.variable import TensorVariable
ilist = indices_from_subtensor(node.inputs[1:], idx_list) ilist = indices_from_subtensor(node.inputs[1:], idx_list)
for idx in ilist: for idx in ilist:
......
...@@ -82,7 +82,7 @@ from pytensor.tensor.basic import as_tensor_variable ...@@ -82,7 +82,7 @@ from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.math import minimum from pytensor.tensor.math import minimum
from pytensor.tensor.shape import Shape_i from pytensor.tensor.shape import Shape_i
from pytensor.tensor.type import TensorType, integer_dtypes from pytensor.tensor.type import TensorType, integer_dtypes
from pytensor.tensor.var import TensorVariable from pytensor.tensor.variable import TensorVariable
# Logging function for sending warning or info # Logging function for sending warning or info
......
...@@ -69,7 +69,7 @@ from pytensor.tensor.subtensor import ( ...@@ -69,7 +69,7 @@ from pytensor.tensor.subtensor import (
get_slice_elements, get_slice_elements,
set_subtensor, set_subtensor,
) )
from pytensor.tensor.var import TensorConstant, get_unique_constant_value from pytensor.tensor.variable import TensorConstant, get_unique_constant_value
list_opt_slice = [ list_opt_slice = [
......
...@@ -21,7 +21,7 @@ from pytensor.graph.type import HasDataType ...@@ -21,7 +21,7 @@ from pytensor.graph.type import HasDataType
from pytensor.graph.utils import TestValueError from pytensor.graph.utils import TestValueError
from pytensor.tensor.basic import AllocEmpty, cast from pytensor.tensor.basic import AllocEmpty, cast
from pytensor.tensor.subtensor import set_subtensor from pytensor.tensor.subtensor import set_subtensor
from pytensor.tensor.var import TensorConstant from pytensor.tensor.variable import TensorConstant
if TYPE_CHECKING: if TYPE_CHECKING:
......
...@@ -51,7 +51,11 @@ from pytensor.tensor.type import TensorType ...@@ -51,7 +51,11 @@ from pytensor.tensor.type import TensorType
from pytensor.tensor.type import continuous_dtypes as tensor_continuous_dtypes from pytensor.tensor.type import continuous_dtypes as tensor_continuous_dtypes
from pytensor.tensor.type import discrete_dtypes as tensor_discrete_dtypes from pytensor.tensor.type import discrete_dtypes as tensor_discrete_dtypes
from pytensor.tensor.type import iscalar, ivector, scalar, tensor, vector from pytensor.tensor.type import iscalar, ivector, scalar, tensor, vector
from pytensor.tensor.var import TensorConstant, TensorVariable, _tensor_py_operators from pytensor.tensor.variable import (
TensorConstant,
TensorVariable,
_tensor_py_operators,
)
sparse_formats = ["csc", "csr"] sparse_formats = ["csc", "csr"]
......
...@@ -146,7 +146,7 @@ from pytensor.tensor.sort import argsort, argtopk, sort, topk, topk_and_argtopk ...@@ -146,7 +146,7 @@ from pytensor.tensor.sort import argsort, argtopk, sort, topk, topk_and_argtopk
from pytensor.tensor.subtensor import * # noqa from pytensor.tensor.subtensor import * # noqa
from pytensor.tensor.type import * # noqa from pytensor.tensor.type import * # noqa
from pytensor.tensor.type_other import * # noqa from pytensor.tensor.type_other import * # noqa
from pytensor.tensor.var import TensorConstant, TensorVariable # noqa from pytensor.tensor.variable import TensorConstant, TensorVariable # noqa
# Allow accessing numpy constants from pytensor.tensor # Allow accessing numpy constants from pytensor.tensor
from numpy import e, euler_gamma, inf, infty, nan, newaxis, pi # noqa from numpy import e, euler_gamma, inf, infty, nan, newaxis, pi # noqa
......
...@@ -62,7 +62,7 @@ from pytensor.tensor.type import ( ...@@ -62,7 +62,7 @@ from pytensor.tensor.type import (
uint_dtypes, uint_dtypes,
values_eq_approx_always_true, values_eq_approx_always_true,
) )
from pytensor.tensor.var import ( from pytensor.tensor.variable import (
TensorConstant, TensorConstant,
TensorVariable, TensorVariable,
get_unique_constant_value, get_unique_constant_value,
......
...@@ -29,7 +29,7 @@ from pytensor.tensor.basic import ( ...@@ -29,7 +29,7 @@ from pytensor.tensor.basic import (
get_underlying_scalar_constant_value, get_underlying_scalar_constant_value,
) )
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.var import TensorConstant, TensorVariable from pytensor.tensor.variable import TensorConstant, TensorVariable
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
......
...@@ -29,7 +29,7 @@ from pytensor.tensor.type import ( ...@@ -29,7 +29,7 @@ from pytensor.tensor.type import (
float_dtypes, float_dtypes,
lvector, lvector,
) )
from pytensor.tensor.var import TensorVariable from pytensor.tensor.variable import TensorVariable
from pytensor.utils import uniq from pytensor.utils import uniq
......
...@@ -35,7 +35,7 @@ from pytensor.tensor.math import sum as at_sum ...@@ -35,7 +35,7 @@ from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.math import switch from pytensor.tensor.math import switch
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
from pytensor.tensor.var import TensorVariable from pytensor.tensor.variable import TensorVariable
from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH
......
...@@ -16,7 +16,7 @@ from pytensor.tensor.math import exp, lt, outer, tensordot ...@@ -16,7 +16,7 @@ from pytensor.tensor.math import exp, lt, outer, tensordot
from pytensor.tensor.shape import shape from pytensor.tensor.shape import shape
from pytensor.tensor.subtensor import set_subtensor from pytensor.tensor.subtensor import set_subtensor
from pytensor.tensor.type import TensorType, integer_dtypes from pytensor.tensor.type import TensorType, integer_dtypes
from pytensor.tensor.var import TensorConstant from pytensor.tensor.variable import TensorConstant
class Fourier(Op): class Fourier(Op):
......
...@@ -40,7 +40,7 @@ from pytensor.tensor.type import ( ...@@ -40,7 +40,7 @@ from pytensor.tensor.type import (
) )
from pytensor.tensor.type_other import NoneConst from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.utils import as_list from pytensor.tensor.utils import as_list
from pytensor.tensor.var import TensorConstant, _tensor_py_operators from pytensor.tensor.variable import TensorConstant, _tensor_py_operators
if TYPE_CHECKING: if TYPE_CHECKING:
......
import typing
from functools import partial from functools import partial
from typing import Tuple from typing import Callable, Tuple
import numpy as np import numpy as np
...@@ -299,7 +300,7 @@ class Eigh(Eig): ...@@ -299,7 +300,7 @@ class Eigh(Eig):
""" """
_numop = staticmethod(np.linalg.eigh) _numop = typing.cast(Callable, staticmethod(np.linalg.eigh))
__props__ = ("UPLO",) __props__ = ("UPLO",)
def __init__(self, UPLO="L"): def __init__(self, UPLO="L"):
......
...@@ -21,7 +21,7 @@ from pytensor.tensor.random.utils import normalize_size_param, params_broadcast_ ...@@ -21,7 +21,7 @@ from pytensor.tensor.random.utils import normalize_size_param, params_broadcast_
from pytensor.tensor.shape import shape_tuple from pytensor.tensor.shape import shape_tuple
from pytensor.tensor.type import TensorType, all_dtypes from pytensor.tensor.type import TensorType, all_dtypes
from pytensor.tensor.type_other import NoneConst from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.var import TensorVariable from pytensor.tensor.variable import TensorVariable
class RandomVariable(Op): class RandomVariable(Op):
......
...@@ -14,7 +14,7 @@ from pytensor.tensor.extra_ops import broadcast_to ...@@ -14,7 +14,7 @@ from pytensor.tensor.extra_ops import broadcast_to
from pytensor.tensor.math import maximum from pytensor.tensor.math import maximum
from pytensor.tensor.shape import specify_shape from pytensor.tensor.shape import specify_shape
from pytensor.tensor.type import int_dtypes from pytensor.tensor.type import int_dtypes
from pytensor.tensor.var import TensorVariable from pytensor.tensor.variable import TensorVariable
if TYPE_CHECKING: if TYPE_CHECKING:
......
...@@ -72,7 +72,7 @@ from pytensor.tensor.math import eq ...@@ -72,7 +72,7 @@ from pytensor.tensor.math import eq
from pytensor.tensor.shape import Shape_i, shape_padleft from pytensor.tensor.shape import Shape_i, shape_padleft
from pytensor.tensor.sort import TopKOp from pytensor.tensor.sort import TopKOp
from pytensor.tensor.type import DenseTensorType, TensorType from pytensor.tensor.type import DenseTensorType, TensorType
from pytensor.tensor.var import TensorConstant, TensorVariable from pytensor.tensor.variable import TensorConstant, TensorVariable
from pytensor.utils import NoDuplicateOptWarningFilter from pytensor.utils import NoDuplicateOptWarningFilter
......
...@@ -39,7 +39,7 @@ from pytensor.tensor.rewriting.basic import ( ...@@ -39,7 +39,7 @@ from pytensor.tensor.rewriting.basic import (
register_specialize, register_specialize,
) )
from pytensor.tensor.shape import shape_padleft from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.var import TensorConstant, get_unique_constant_value from pytensor.tensor.variable import TensorConstant, get_unique_constant_value
class InplaceElemwiseOptimizer(GraphRewriter): class InplaceElemwiseOptimizer(GraphRewriter):
......
...@@ -6,7 +6,7 @@ from pytensor.tensor.elemwise import DimShuffle ...@@ -6,7 +6,7 @@ from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Sum from pytensor.tensor.math import Sum
from pytensor.tensor.shape import Reshape from pytensor.tensor.shape import Reshape
from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor
from pytensor.tensor.var import TensorVariable from pytensor.tensor.variable import TensorVariable
@node_rewriter([AdvancedIncSubtensor]) @node_rewriter([AdvancedIncSubtensor])
......
...@@ -101,7 +101,7 @@ from pytensor.tensor.type import ( ...@@ -101,7 +101,7 @@ from pytensor.tensor.type import (
values_eq_approx_remove_inf_nan, values_eq_approx_remove_inf_nan,
values_eq_approx_remove_nan, values_eq_approx_remove_nan,
) )
from pytensor.tensor.var import TensorConstant, get_unique_constant_value from pytensor.tensor.variable import TensorConstant, get_unique_constant_value
def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False): def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
......
...@@ -81,7 +81,7 @@ from pytensor.tensor.subtensor import ( ...@@ -81,7 +81,7 @@ from pytensor.tensor.subtensor import (
) )
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneTypeT, SliceConstant, SliceType from pytensor.tensor.type_other import NoneTypeT, SliceConstant, SliceType
from pytensor.tensor.var import TensorConstant, TensorVariable from pytensor.tensor.variable import TensorConstant, TensorVariable
def register_useless(lopt, *tags, **kwargs): def register_useless(lopt, *tags, **kwargs):
......
...@@ -19,7 +19,7 @@ from pytensor.tensor import get_vector_length ...@@ -19,7 +19,7 @@ from pytensor.tensor import get_vector_length
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor from pytensor.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor
from pytensor.tensor.type_other import NoneConst from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.var import TensorConstant, TensorVariable from pytensor.tensor.variable import TensorConstant, TensorVariable
ShapeValueType = Union[None, np.integer, int, Variable] ShapeValueType = Union[None, np.integer, int, Variable]
......
...@@ -6,7 +6,7 @@ from pytensor.compile import SharedVariable, shared_constructor ...@@ -6,7 +6,7 @@ from pytensor.compile import SharedVariable, shared_constructor
from pytensor.misc.safe_asarray import _asarray from pytensor.misc.safe_asarray import _asarray
from pytensor.tensor import _get_vector_length from pytensor.tensor import _get_vector_length
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
from pytensor.tensor.var import _tensor_py_operators from pytensor.tensor.variable import _tensor_py_operators
def load_shared_variable(val): def load_shared_variable(val):
......
...@@ -16,7 +16,7 @@ from pytensor.tensor import math as atm ...@@ -16,7 +16,7 @@ from pytensor.tensor import math as atm
from pytensor.tensor.nlinalg import matrix_dot from pytensor.tensor.nlinalg import matrix_dot
from pytensor.tensor.shape import reshape from pytensor.tensor.shape import reshape
from pytensor.tensor.type import matrix, tensor, vector from pytensor.tensor.type import matrix, tensor, vector
from pytensor.tensor.var import TensorVariable from pytensor.tensor.variable import TensorVariable
if TYPE_CHECKING: if TYPE_CHECKING:
......
...@@ -18,7 +18,7 @@ from pytensor.utils import apply_across_args ...@@ -18,7 +18,7 @@ from pytensor.utils import apply_across_args
if TYPE_CHECKING: if TYPE_CHECKING:
from numpy.typing import DTypeLike from numpy.typing import DTypeLike
from pytensor.tensor.var import TensorVariable from pytensor.tensor.variable import TensorVariable
_logger = logging.getLogger("pytensor.tensor.type") _logger = logging.getLogger("pytensor.tensor.type")
......
差异被折叠。
差异被折叠。
...@@ -8,7 +8,7 @@ from pytensor.graph.op import Op ...@@ -8,7 +8,7 @@ from pytensor.graph.op import Op
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.tensor.type import scalar from pytensor.tensor.type import scalar
from pytensor.tensor.type_other import SliceType from pytensor.tensor.type_other import SliceType
from pytensor.tensor.var import TensorVariable from pytensor.tensor.variable import TensorVariable
from pytensor.typed_list.type import TypedListType from pytensor.typed_list.type import TypedListType
......
...@@ -32,4 +32,5 @@ pytensor/tensor/slinalg.py ...@@ -32,4 +32,5 @@ pytensor/tensor/slinalg.py
pytensor/tensor/subtensor.py pytensor/tensor/subtensor.py
pytensor/tensor/type.py pytensor/tensor/type.py
pytensor/tensor/type_other.py pytensor/tensor/type_other.py
pytensor/tensor/var.py pytensor/tensor/variable.py
\ No newline at end of file pytensor/tensor/nlinalg.py
\ No newline at end of file
...@@ -33,7 +33,7 @@ from pytensor.graph.type import Type ...@@ -33,7 +33,7 @@ from pytensor.graph.type import Type
from pytensor.tensor.math import max_and_argmax from pytensor.tensor.math import max_and_argmax
from pytensor.tensor.type import TensorType, iscalars, matrix, scalars, vector from pytensor.tensor.type import TensorType, iscalars, matrix, scalars, vector
from pytensor.tensor.type_other import NoneConst from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.var import TensorVariable from pytensor.tensor.variable import TensorVariable
from tests.graph.utils import MyInnerGraphOp from tests.graph.utils import MyInnerGraphOp
......
...@@ -18,7 +18,7 @@ from pytensor.link.utils import map_storage ...@@ -18,7 +18,7 @@ from pytensor.link.utils import map_storage
from pytensor.link.vm import VM, Loop, Stack, VMLinker from pytensor.link.vm import VM, Loop, Stack, VMLinker
from pytensor.tensor.math import cosh, tanh from pytensor.tensor.math import cosh, tanh
from pytensor.tensor.type import lscalar, scalar, scalars, vector, vectors from pytensor.tensor.type import lscalar, scalar, scalars, vector, vectors
from pytensor.tensor.var import TensorConstant from pytensor.tensor.variable import TensorConstant
from tests import unittest_tools as utt from tests import unittest_tools as utt
......
...@@ -125,7 +125,7 @@ from pytensor.tensor.type import ( ...@@ -125,7 +125,7 @@ from pytensor.tensor.type import (
vectors, vectors,
zscalar, zscalar,
) )
from pytensor.tensor.var import TensorConstant from pytensor.tensor.variable import TensorConstant
from tests import unittest_tools as utt from tests import unittest_tools as utt
......
...@@ -125,7 +125,7 @@ from pytensor.tensor.type import ( ...@@ -125,7 +125,7 @@ from pytensor.tensor.type import (
vectors, vectors,
wvector, wvector,
) )
from pytensor.tensor.var import TensorConstant from pytensor.tensor.variable import TensorConstant
from pytensor.utils import PYTHON_INT_BITWIDTH from pytensor.utils import PYTHON_INT_BITWIDTH
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.tensor.utils import ( from tests.tensor.utils import (
......
...@@ -47,7 +47,7 @@ from pytensor.tensor.type import ( ...@@ -47,7 +47,7 @@ from pytensor.tensor.type import (
vector, vector,
) )
from pytensor.tensor.type_other import NoneConst from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.var import TensorVariable from pytensor.tensor.variable import TensorVariable
from pytensor.typed_list import make_list from pytensor.typed_list import make_list
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.graph.utils import MyType2 from tests.graph.utils import MyType2
......
...@@ -29,7 +29,7 @@ from pytensor.tensor.type import ( ...@@ -29,7 +29,7 @@ from pytensor.tensor.type import (
tensor3, tensor3,
) )
from pytensor.tensor.type_other import MakeSlice, NoneConst from pytensor.tensor.type_other import MakeSlice, NoneConst
from pytensor.tensor.var import ( from pytensor.tensor.variable import (
DenseTensorConstant, DenseTensorConstant,
DenseTensorVariable, DenseTensorVariable,
TensorConstant, TensorConstant,
...@@ -405,3 +405,15 @@ class TestTensorInstanceMethods: ...@@ -405,3 +405,15 @@ class TestTensorInstanceMethods:
assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1)) assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1))
# Test equivalent advanced indexing # Test equivalent advanced indexing
assert_array_equal(X[:, indices].eval({X: x}), x[:, indices]) assert_array_equal(X[:, indices].eval({X: x}), x[:, indices])
def test_deprecated_import():
with pytest.warns(
DeprecationWarning,
match="The module 'pytensor.tensor.var' has been deprecated.",
):
import pytensor.tensor.var as _var
# Make sure the deprecated import provides access to 'variable' module
assert hasattr(_var, "TensorVariable")
assert hasattr(_var, "TensorConstant")
...@@ -13,7 +13,7 @@ from pytensor.tensor.type import ( ...@@ -13,7 +13,7 @@ from pytensor.tensor.type import (
vector, vector,
) )
from pytensor.tensor.type_other import SliceType from pytensor.tensor.type_other import SliceType
from pytensor.tensor.var import TensorVariable from pytensor.tensor.variable import TensorVariable
from pytensor.typed_list.basic import ( from pytensor.typed_list.basic import (
Append, Append,
Count, Count,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论