提交 556816f3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Refactor SharedVariable type and interface

上级 08eaf0e3
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import copy import copy
from contextlib import contextmanager from contextlib import contextmanager
from functools import singledispatch from functools import singledispatch
from typing import List, Optional from typing import TYPE_CHECKING, List, Optional
from pytensor.graph.basic import Variable from pytensor.graph.basic import Variable
from pytensor.graph.utils import add_tag_trace from pytensor.graph.utils import add_tag_trace
...@@ -11,6 +11,10 @@ from pytensor.link.basic import Container ...@@ -11,6 +11,10 @@ from pytensor.link.basic import Container
from pytensor.link.c.type import generic from pytensor.link.c.type import generic
if TYPE_CHECKING:
from pytensor.graph.type import Type
__SHARED_CONTEXT__: Optional[List[Variable]] = None __SHARED_CONTEXT__: Optional[List[Variable]] = None
...@@ -30,14 +34,39 @@ def collect_new_shareds(): ...@@ -30,14 +34,39 @@ def collect_new_shareds():
class SharedVariable(Variable): class SharedVariable(Variable):
"""Variable that is shared between compiled functions.""" """Variable that is shared between compiled functions."""
container: Optional[Container] = None def __init__(
""" self,
A container to use for this SharedVariable when it is an implicit type: "Type",
function parameter. value,
""" strict: bool,
allow_downcast=None,
container: Optional[Container] = None,
name: Optional[str] = None,
):
r"""
Parameters
----------
type
The `Type` for this variable (see `Variable`).
value
A value to associate with this variable (a new container will be
created).
strict
``True`` means that values assigned to this variable will not be
cast or copied, so they must have the correct `Type`\s.
allow_downcast
Only applies if `strict` is ``False``.
``True`` means that the assigned value can lose precision when cast
during assignment. ``None`` means that only down-casting of a Python
float to a scalar ``floatX`` is allowed.
container
The container to use for this variable. Illegal to pass this as well as
a value.
name
The name for this variable (see `Variable`).
def __init__(self, name, type, value, strict, allow_downcast=None, container=None): """
super().__init__(type=type, name=name, owner=None, index=None) super().__init__(type=type, owner=None, index=None, name=name)
if container is not None: if container is not None:
self.container = container self.container = container
...@@ -107,26 +136,6 @@ class SharedVariable(Variable): ...@@ -107,26 +136,6 @@ class SharedVariable(Variable):
def get_test_value(self): def get_test_value(self):
return self.get_value(borrow=True, return_internal_type=True) return self.get_value(borrow=True, return_internal_type=True)
def zero(self, borrow=False):
"""
Set the values of a shared variable to 0.
Parameters
----------
borrow : bbol
True to modify the value of a shared variable directly by using
its previous value. Potentially this can cause problems
regarding to the aliased memory.
Changes done with this function will be visible to all functions using
this SharedVariable.
"""
if borrow:
self.container.value[...] = 0
else:
self.container.value = 0 * self.container.value
def clone(self, **kwargs): def clone(self, **kwargs):
name = kwargs.get("name", self.name) name = kwargs.get("name", self.name)
cp = self.__class__( cp = self.__class__(
...@@ -209,7 +218,7 @@ def shared_constructor(value, name=None, strict=False, allow_downcast=None, **kw ...@@ -209,7 +218,7 @@ def shared_constructor(value, name=None, strict=False, allow_downcast=None, **kw
return SharedVariable( return SharedVariable(
type=generic, type=generic,
value=value, value=value,
name=name,
strict=strict, strict=strict,
allow_downcast=allow_downcast, allow_downcast=allow_downcast,
name=name,
) )
...@@ -2,13 +2,15 @@ import copy ...@@ -2,13 +2,15 @@ import copy
import scipy.sparse import scipy.sparse
from pytensor.compile import SharedVariable, shared_constructor from pytensor.compile import shared_constructor
from pytensor.sparse.basic import SparseTensorType, _sparse_py_operators from pytensor.sparse.basic import SparseTensorType, _sparse_py_operators
from pytensor.tensor.sharedvar import TensorSharedVariable
class SparseTensorSharedVariable(_sparse_py_operators, SharedVariable): class SparseTensorSharedVariable(TensorSharedVariable, _sparse_py_operators):
dtype = property(lambda self: self.type.dtype) @property
format = property(lambda self: self.type.format) def format(self):
return self.type.format
@shared_constructor.register(scipy.sparse.spmatrix) @shared_constructor.register(scipy.sparse.spmatrix)
...@@ -24,5 +26,5 @@ def sparse_constructor( ...@@ -24,5 +26,5 @@ def sparse_constructor(
value = copy.deepcopy(value) value = copy.deepcopy(value)
return SparseTensorSharedVariable( return SparseTensorSharedVariable(
type=type, value=value, name=name, strict=strict, allow_downcast=allow_downcast type=type, value=value, strict=strict, allow_downcast=allow_downcast, name=name
) )
...@@ -37,7 +37,7 @@ def randomgen_constructor( ...@@ -37,7 +37,7 @@ def randomgen_constructor(
return rng_sv_type( return rng_sv_type(
type=rng_type, type=rng_type,
value=value, value=value,
name=name,
strict=strict, strict=strict,
allow_downcast=allow_downcast, allow_downcast=allow_downcast,
name=name,
) )
...@@ -19,9 +19,25 @@ def load_shared_variable(val): ...@@ -19,9 +19,25 @@ def load_shared_variable(val):
return tensor_constructor(val) return tensor_constructor(val)
# _tensor_py_operators is first to have its version of __{gt,ge,lt,le}__
class TensorSharedVariable(_tensor_py_operators, SharedVariable): class TensorSharedVariable(_tensor_py_operators, SharedVariable):
pass def zero(self, borrow: bool = False):
r"""Set the values of a shared variable to 0.
Parameters
----------
borrow
``True`` to modify the value of a shared variable directly by using
its previous value. Potentially this can cause problems regarding
to the aliased memory.
Changes done with this function will be visible to all functions using
this `SharedVariable`.
"""
if borrow:
self.container.value[...] = 0
else:
self.container.value = 0 * self.container.value
@_get_vector_length.register(TensorSharedVariable) @_get_vector_length.register(TensorSharedVariable)
...@@ -69,13 +85,13 @@ def tensor_constructor( ...@@ -69,13 +85,13 @@ def tensor_constructor(
return TensorSharedVariable( return TensorSharedVariable(
type=type, type=type,
value=np.array(value, copy=(not borrow)), value=np.array(value, copy=(not borrow)),
name=name,
strict=strict, strict=strict,
allow_downcast=allow_downcast, allow_downcast=allow_downcast,
name=name,
) )
class ScalarSharedVariable(_tensor_py_operators, SharedVariable): class ScalarSharedVariable(TensorSharedVariable):
pass pass
......
import numpy as np
import scipy as sp
import pytensor
from pytensor.sparse.sharedvar import SparseTensorSharedVariable
def test_shared_basic():
x = pytensor.shared(
sp.sparse.csr_matrix(np.eye(100), dtype=np.float64), name="blah", borrow=True
)
assert isinstance(x, SparseTensorSharedVariable)
assert x.format == "csr"
assert x.dtype == "float64"
import numpy as np
import pytest
import pytensor
from pytensor.compile import SharedVariable
sp = pytest.importorskip("scipy", minversion="0.7.0")
def test_shared_basic():
x = pytensor.shared(sp.sparse.csr_matrix(np.eye(100)), name="blah", borrow=True)
assert isinstance(x, SharedVariable)
...@@ -10,6 +10,7 @@ from pytensor.misc.may_share_memory import may_share_memory ...@@ -10,6 +10,7 @@ from pytensor.misc.may_share_memory import may_share_memory
from pytensor.tensor import get_vector_length from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import MakeVector from pytensor.tensor.basic import MakeVector
from pytensor.tensor.shape import Shape_i, specify_shape from pytensor.tensor.shape import Shape_i, specify_shape
from pytensor.tensor.sharedvar import ScalarSharedVariable, TensorSharedVariable
from tests import unittest_tools as utt from tests import unittest_tools as utt
...@@ -657,10 +658,31 @@ class TestSharedOptions: ...@@ -657,10 +658,31 @@ class TestSharedOptions:
pass pass
def test_tensor_shared_zero():
shared_val = np.array([1.0, 3.0], dtype=np.float32)
res = pytensor.shared(value=shared_val, borrow=True)
assert isinstance(res, TensorSharedVariable)
assert res.get_value(borrow=True) is shared_val
res.zero(borrow=True)
new_shared_val = res.get_value(borrow=True)
assert new_shared_val is shared_val
assert np.array_equal(new_shared_val, np.zeros((2,), dtype=np.float32))
res.set_value(shared_val, borrow=True)
res.zero(borrow=False)
new_shared_val = res.get_value(borrow=True)
assert new_shared_val is not shared_val
assert np.array_equal(new_shared_val, np.zeros((2,), dtype=np.float32))
def test_scalar_shared_options(): def test_scalar_shared_options():
# Simple test to make sure we do not loose that fonctionality. res = pytensor.shared(value=np.float32(0.0), name="lk", borrow=True)
pytensor.shared(value=0.0, name="lk", borrow=True) assert isinstance(res, ScalarSharedVariable)
pytensor.shared(value=np.float32(0.0), name="lk", borrow=True) assert res.type.dtype == "float32"
assert res.name == "lk"
assert res.type.shape == ()
def test_get_vector_length(): def test_get_vector_length():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论