提交 556816f3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Refactor SharedVariable type and interface

上级 08eaf0e3
......@@ -3,7 +3,7 @@
import copy
from contextlib import contextmanager
from functools import singledispatch
from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional
from pytensor.graph.basic import Variable
from pytensor.graph.utils import add_tag_trace
......@@ -11,6 +11,10 @@ from pytensor.link.basic import Container
from pytensor.link.c.type import generic
if TYPE_CHECKING:
from pytensor.graph.type import Type
__SHARED_CONTEXT__: Optional[List[Variable]] = None
......@@ -30,14 +34,39 @@ def collect_new_shareds():
class SharedVariable(Variable):
"""Variable that is shared between compiled functions."""
container: Optional[Container] = None
"""
A container to use for this SharedVariable when it is an implicit
function parameter.
"""
def __init__(
self,
type: "Type",
value,
strict: bool,
allow_downcast=None,
container: Optional[Container] = None,
name: Optional[str] = None,
):
r"""
Parameters
----------
type
The `Type` for this variable (see `Variable`).
value
A value to associate with this variable (a new container will be
created).
strict
``True`` means that values assigned to this variable will not be
cast or copied, so they must have the correct `Type`\s.
allow_downcast
Only applies if `strict` is ``False``.
``True`` means that the assigned value can lose precision when cast
during assignment. ``None`` means that only down-casting of a Python
float to a scalar ``floatX`` is allowed.
container
The container to use for this variable. Illegal to pass this as well as
a value.
name
The name for this variable (see `Variable`).
def __init__(self, name, type, value, strict, allow_downcast=None, container=None):
super().__init__(type=type, name=name, owner=None, index=None)
"""
super().__init__(type=type, owner=None, index=None, name=name)
if container is not None:
self.container = container
......@@ -107,26 +136,6 @@ class SharedVariable(Variable):
def get_test_value(self):
    # The shared variable's current value doubles as its test value.
    # Borrowing and returning the internal type avoids a defensive copy.
    return self.get_value(borrow=True, return_internal_type=True)
def zero(self, borrow=False):
"""
Set the values of a shared variable to 0.
Parameters
----------
borrow : bool
True to modify the value of a shared variable directly by using
its previous value. Potentially this can cause problems
regarding the aliased memory.
Changes done with this function will be visible to all functions using
this SharedVariable.
"""
if borrow:
self.container.value[...] = 0
else:
self.container.value = 0 * self.container.value
def clone(self, **kwargs):
name = kwargs.get("name", self.name)
cp = self.__class__(
......@@ -209,7 +218,7 @@ def shared_constructor(value, name=None, strict=False, allow_downcast=None, **kw
return SharedVariable(
type=generic,
value=value,
name=name,
strict=strict,
allow_downcast=allow_downcast,
name=name,
)
......@@ -2,13 +2,15 @@ import copy
import scipy.sparse
from pytensor.compile import SharedVariable, shared_constructor
from pytensor.compile import shared_constructor
from pytensor.sparse.basic import SparseTensorType, _sparse_py_operators
from pytensor.tensor.sharedvar import TensorSharedVariable
class SparseTensorSharedVariable(_sparse_py_operators, SharedVariable):
dtype = property(lambda self: self.type.dtype)
format = property(lambda self: self.type.format)
class SparseTensorSharedVariable(TensorSharedVariable, _sparse_py_operators):
    """A shared variable for sparse tensors.

    Tensor shared-variable behavior (e.g. ``dtype``, ``zero``) is inherited
    from `TensorSharedVariable`; sparse operator methods come from
    `_sparse_py_operators`.
    """

    @property
    def format(self) -> str:
        # Sparse storage format (e.g. ``"csr"``), taken from the variable's type.
        return self.type.format
@shared_constructor.register(scipy.sparse.spmatrix)
......@@ -24,5 +26,5 @@ def sparse_constructor(
value = copy.deepcopy(value)
return SparseTensorSharedVariable(
type=type, value=value, name=name, strict=strict, allow_downcast=allow_downcast
type=type, value=value, strict=strict, allow_downcast=allow_downcast, name=name
)
......@@ -37,7 +37,7 @@ def randomgen_constructor(
return rng_sv_type(
type=rng_type,
value=value,
name=name,
strict=strict,
allow_downcast=allow_downcast,
name=name,
)
......@@ -19,9 +19,25 @@ def load_shared_variable(val):
return tensor_constructor(val)
# _tensor_py_operators is first to have its version of __{gt,ge,lt,le}__
class TensorSharedVariable(_tensor_py_operators, SharedVariable):
pass
def zero(self, borrow: bool = False):
    r"""Reset the value of this shared variable to all zeros.

    Parameters
    ----------
    borrow
        ``True`` overwrites the current value in place, reusing the stored
        array object.  This can cause problems if other code holds an alias
        to that memory.

    Changes made by this method are visible to every function that uses
    this `SharedVariable`.
    """
    target = self.container
    if not borrow:
        # Allocate a fresh zero array of the same type and shape.
        target.value = 0 * target.value
    else:
        # Zero the existing array in place; the object identity is preserved.
        target.value[...] = 0
@_get_vector_length.register(TensorSharedVariable)
......@@ -69,13 +85,13 @@ def tensor_constructor(
return TensorSharedVariable(
type=type,
value=np.array(value, copy=(not borrow)),
name=name,
strict=strict,
allow_downcast=allow_downcast,
name=name,
)
class ScalarSharedVariable(_tensor_py_operators, SharedVariable):
class ScalarSharedVariable(TensorSharedVariable):
    """A shared variable holding a scalar (0-d) tensor.

    Behavior is inherited entirely from `TensorSharedVariable`.
    """

    pass
......
import numpy as np
import scipy as sp
import pytensor
from pytensor.sparse.sharedvar import SparseTensorSharedVariable
def test_shared_basic():
    # A SciPy CSR matrix should produce a sparse shared variable that
    # preserves both the storage format and the dtype.
    data = sp.sparse.csr_matrix(np.eye(100), dtype=np.float64)
    res = pytensor.shared(data, name="blah", borrow=True)

    assert isinstance(res, SparseTensorSharedVariable)
    assert res.format == "csr"
    assert res.dtype == "float64"
import numpy as np
import pytest
import pytensor
from pytensor.compile import SharedVariable
sp = pytest.importorskip("scipy", minversion="0.7.0")
def test_shared_basic():
x = pytensor.shared(sp.sparse.csr_matrix(np.eye(100)), name="blah", borrow=True)
assert isinstance(x, SharedVariable)
......@@ -10,6 +10,7 @@ from pytensor.misc.may_share_memory import may_share_memory
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import MakeVector
from pytensor.tensor.shape import Shape_i, specify_shape
from pytensor.tensor.sharedvar import ScalarSharedVariable, TensorSharedVariable
from tests import unittest_tools as utt
......@@ -657,10 +658,31 @@ class TestSharedOptions:
pass
def test_tensor_shared_zero():
    initial = np.array([1.0, 3.0], dtype=np.float32)
    res = pytensor.shared(value=initial, borrow=True)
    assert isinstance(res, TensorSharedVariable)
    assert res.get_value(borrow=True) is initial

    # `zero(borrow=True)` must zero the array in place, keeping identity.
    res.zero(borrow=True)
    current = res.get_value(borrow=True)
    assert current is initial
    assert np.array_equal(current, np.zeros((2,), dtype=np.float32))

    # `zero(borrow=False)` must allocate a fresh zero array instead.
    res.set_value(initial, borrow=True)
    res.zero(borrow=False)
    current = res.get_value(borrow=True)
    assert current is not initial
    assert np.array_equal(current, np.zeros((2,), dtype=np.float32))
def test_scalar_shared_options():
# Simple test to make sure we do not lose that functionality.
pytensor.shared(value=0.0, name="lk", borrow=True)
pytensor.shared(value=np.float32(0.0), name="lk", borrow=True)
res = pytensor.shared(value=np.float32(0.0), name="lk", borrow=True)
assert isinstance(res, ScalarSharedVariable)
assert res.type.dtype == "float32"
assert res.name == "lk"
assert res.type.shape == ()
def test_get_vector_length():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论