提交 60e25109 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Michael Osthege

Make Constant and Shared variables subclasses of the respective Variables

上级 97317a50
...@@ -479,7 +479,7 @@ class SparseConstantSignature(tuple): ...@@ -479,7 +479,7 @@ class SparseConstantSignature(tuple):
return hash_from_sparse(d) return hash_from_sparse(d)
class SparseConstant(TensorConstant, _sparse_py_operators): class SparseConstant(SparseVariable, TensorConstant):
format = property(lambda self: self.type.format) format = property(lambda self: self.type.format)
def signature(self): def signature(self):
......
...@@ -3,11 +3,11 @@ import copy ...@@ -3,11 +3,11 @@ import copy
import scipy.sparse import scipy.sparse
from pytensor.compile import shared_constructor from pytensor.compile import shared_constructor
from pytensor.sparse.basic import SparseTensorType, _sparse_py_operators from pytensor.sparse.basic import SparseTensorType, SparseVariable
from pytensor.tensor.sharedvar import TensorSharedVariable from pytensor.tensor.sharedvar import TensorSharedVariable
class SparseTensorSharedVariable(TensorSharedVariable, _sparse_py_operators): class SparseTensorSharedVariable(TensorSharedVariable, SparseVariable):
@property @property
def format(self): def format(self):
return self.type.format return self.type.format
......
...@@ -6,7 +6,7 @@ from pytensor.compile import SharedVariable, shared_constructor ...@@ -6,7 +6,7 @@ from pytensor.compile import SharedVariable, shared_constructor
from pytensor.misc.safe_asarray import _asarray from pytensor.misc.safe_asarray import _asarray
from pytensor.tensor import _get_vector_length from pytensor.tensor import _get_vector_length
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
from pytensor.tensor.variable import _tensor_py_operators from pytensor.tensor.variable import TensorVariable
def __getattr__(name): def __getattr__(name):
...@@ -31,7 +31,7 @@ def load_shared_variable(val): ...@@ -31,7 +31,7 @@ def load_shared_variable(val):
return tensor_constructor(val) return tensor_constructor(val)
class TensorSharedVariable(_tensor_py_operators, SharedVariable): class TensorSharedVariable(SharedVariable, TensorVariable):
def zero(self, borrow: bool = False): def zero(self, borrow: bool = False):
r"""Set the values of a shared variable to 0. r"""Set the values of a shared variable to 0.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论