提交 461832c5 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Clean up docstrings and errors relating to SharedVariable

上级 b132036d
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
.. class:: SharedVariable .. class:: SharedVariable
Variable with Storage that is shared between functions that it appears in. Variable with storage that is shared between the compiled functions that it appears in.
These variables are meant to be created by registered *shared constructors* These variables are meant to be created by registered *shared constructors*
(see :func:`shared_constructor`). (see :func:`shared_constructor`).
...@@ -68,7 +68,6 @@ ...@@ -68,7 +68,6 @@
A container to use for this SharedVariable when it is an implicit function parameter. A container to use for this SharedVariable when it is an implicit function parameter.
:type: class:`Container`
.. autofunction:: shared .. autofunction:: shared
...@@ -76,10 +75,10 @@ ...@@ -76,10 +75,10 @@
Append `ctor` to the list of shared constructors (see :func:`shared`). Append `ctor` to the list of shared constructors (see :func:`shared`).
Each registered constructor ``ctor`` will be called like this: Each registered constructor `ctor` will be called like this:
.. code-block:: python .. code-block:: python
ctor(value, name=name, strict=strict, **kwargs) ctor(value, name=name, strict=strict, **kwargs)
If it does not support the given value, it must raise a TypeError. If it does not support the given value, it must raise a `TypeError`.
...@@ -78,10 +78,12 @@ def rebuild_collect_shared( ...@@ -78,10 +78,12 @@ def rebuild_collect_shared(
shared_inputs = [] shared_inputs = []
def clone_v_get_shared_updates(v, copy_inputs_over): def clone_v_get_shared_updates(v, copy_inputs_over):
""" r"""Clones a variable and its inputs recursively until all are in `clone_d`.
Clones a variable and its inputs recursively until all are in clone_d.
Also appends all shared variables met along the way to shared inputs, Also, it appends all `SharedVariable`\s met along the way to
and their default_update (if applicable) to update_d and update_expr. `shared_inputs` and their corresponding
`SharedVariable.default_update`\s (when applicable) to `update_d` and
`update_expr`.
""" """
# this co-recurses with clone_a # this co-recurses with clone_a
...@@ -419,22 +421,24 @@ def construct_pfunc_ins_and_outs( ...@@ -419,22 +421,24 @@ def construct_pfunc_ins_and_outs(
givens = [] givens = []
if not isinstance(params, (list, tuple)): if not isinstance(params, (list, tuple)):
raise Exception("in pfunc() the first argument must be a list or " "a tuple") raise TypeError("The `params` argument must be a list or a tuple")
if not isinstance(no_default_updates, bool) and not isinstance( if not isinstance(no_default_updates, bool) and not isinstance(
no_default_updates, list no_default_updates, list
): ):
raise TypeError("no_default_update should be either a boolean or " "a list") raise TypeError("The `no_default_update` argument must be a boolean or list")
if len(updates) > 0 and any( if len(updates) > 0 and not all(
isinstance(v, Variable) for v in iter_over_pairs(updates) isinstance(pair, (tuple, list))
and len(pair) == 2
and isinstance(pair[0], Variable)
for pair in iter_over_pairs(updates)
): ):
raise ValueError( raise TypeError(
"The updates parameter must be an OrderedDict/dict or a list of " "The `updates` parameter must be an ordered mapping or a list of pairs"
"lists/tuples with 2 elements"
) )
# transform params into pytensor.compile.In objects. # Transform params into pytensor.compile.In objects.
inputs = [ inputs = [
_pfunc_param_to_in(p, allow_downcast=allow_input_downcast) for p in params _pfunc_param_to_in(p, allow_downcast=allow_input_downcast) for p in params
] ]
......
""" """Provide a simple user friendly API to PyTensor-managed memory."""
Provide a simple user friendly API to PyTensor-managed memory.
"""
import copy import copy
from contextlib import contextmanager from contextlib import contextmanager
...@@ -30,52 +27,14 @@ def collect_new_shareds(): ...@@ -30,52 +27,14 @@ def collect_new_shareds():
class SharedVariable(Variable): class SharedVariable(Variable):
""" """Variable that is shared between compiled functions."""
Variable that is (defaults to being) shared between functions that
it appears in.
Parameters
----------
name : str
The name for this variable (see `Variable`).
type : str
The type for this variable (see `Variable`).
value
A value to associate with this variable (a new container will be
created).
strict
True : assignments to .value will not be cast or copied, so they must
have the correct type.
allow_downcast
Only applies if `strict` is False.
True : allow assigned value to lose precision when cast during
assignment.
False : never allow precision loss.
None : only allow downcasting of a Python float to a scalar floatX.
container
The container to use for this variable. Illegal to pass this as well as
a value.
Notes
-----
For a more user-friendly constructor, see `shared`.
"""
# Container object container: Optional[Container] = None
container = None
""" """
A container to use for this SharedVariable when it is an implicit A container to use for this SharedVariable when it is an implicit
function parameter. function parameter.
:type: `Container`
""" """
# default_update
# If this member is present, its value will be used as the "update" for
# this Variable, unless another update value has been passed to "function",
# or the "no_default_updates" list passed to "function" contains it.
def __init__(self, name, type, value, strict, allow_downcast=None, container=None): def __init__(self, name, type, value, strict, allow_downcast=None, container=None):
super().__init__(type=type, name=name, owner=None, index=None) super().__init__(type=type, name=name, owner=None, index=None)
...@@ -207,37 +166,30 @@ def shared_constructor(ctor, remove=False): ...@@ -207,37 +166,30 @@ def shared_constructor(ctor, remove=False):
def shared(value, name=None, strict=False, allow_downcast=None, **kwargs): def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
"""Return a SharedVariable Variable, initialized with a copy or r"""Create a `SharedVariable` initialized with a copy or reference of `value`.
reference of `value`.
This function iterates over constructor functions to find a This function iterates over constructor functions to find a
suitable SharedVariable subclass. The suitable one is the first suitable `SharedVariable` subclass. The suitable one is the first
constructor that accepts the given value. See the documentation of constructor that accepts the given value. See the documentation of
:func:`shared_constructor` for the definition of a constructor :func:`shared_constructor` for the definition of a constructor
function. function.
This function is meant as a convenient default. If you want to use a This function is meant as a convenient default. If you want to use a
specific shared variable constructor, consider calling it directly. specific constructor, consider calling it directly.
``pytensor.shared`` is a shortcut to this function.
.. attribute:: constructors
A list of shared variable constructors that will be tried in reverse `pytensor.shared` is a shortcut to this function.
order.
Notes Notes
----- -----
By passing kwargs, you effectively limit the set of potential constructors By passing kwargs, you effectively limit the set of potential constructors
to those that can accept those kwargs. to those that can accept those kwargs.
Some shared variables have ``borrow`` as an extra kwarg. Some shared variables have `borrow` as a kwarg.
Some shared variable have ``broadcastable`` as extra kwargs. As shared `SharedVariable`\s of `TensorType` have `broadcastable` as a kwarg. As shared
variable shapes can change, all dimensions default to not being variable shapes can change, all dimensions default to not being
broadcastable, even if ``value`` has a shape of 1 along some dimension. broadcastable, even if `value` has a shape of 1 along some dimension.
This parameter allows you to create for example a `row` or `column` 2d This parameter allows one to create for example a row or column tensor.
tensor.
""" """
......
...@@ -22,7 +22,7 @@ class RandomGeneratorSharedVariable(SharedVariable): ...@@ -22,7 +22,7 @@ class RandomGeneratorSharedVariable(SharedVariable):
def randomgen_constructor( def randomgen_constructor(
value, name=None, strict=False, allow_downcast=None, borrow=False value, name=None, strict=False, allow_downcast=None, borrow=False
): ):
r"""`SharedVariable` Constructor for NumPy's `Generator` and/or `RandomState`.""" r"""`SharedVariable` constructor for NumPy's `Generator` and/or `RandomState`."""
if isinstance(value, np.random.RandomState): if isinstance(value, np.random.RandomState):
rng_sv_type = RandomStateSharedVariable rng_sv_type = RandomStateSharedVariable
rng_type = random_state_type rng_type = random_state_type
......
...@@ -41,8 +41,7 @@ def tensor_constructor( ...@@ -41,8 +41,7 @@ def tensor_constructor(
target="cpu", target="cpu",
broadcastable=None, broadcastable=None,
): ):
""" r"""`SharedVariable` constructor for `TensorType`\s.
SharedVariable Constructor for TensorType.
Notes Notes
----- -----
...@@ -64,9 +63,8 @@ def tensor_constructor( ...@@ -64,9 +63,8 @@ def tensor_constructor(
if not isinstance(value, np.ndarray): if not isinstance(value, np.ndarray):
raise TypeError() raise TypeError()
# if no shape is given, then the default is to assume that # If no shape is given, then the default is to assume that the value might
# the value might be resized in any dimension in the future. # be resized in any dimension in the future.
#
if shape is None: if shape is None:
shape = (None,) * len(value.shape) shape = (None,) * len(value.shape)
type = TensorType(value.dtype, shape=shape) type = TensorType(value.dtype, shape=shape)
...@@ -79,13 +77,6 @@ def tensor_constructor( ...@@ -79,13 +77,6 @@ def tensor_constructor(
) )
# TensorSharedVariable brings in the tensor operators, is not ideal, but works
# as long as we don't do purely scalar-scalar operations
# _tensor_py_operators is first to have its version of __{gt,ge,lt,le}__
#
# N.B. THERE IS ANOTHER CLASS CALLED ScalarSharedVariable in the
# pytensor.scalar.sharedvar file. It is not registered as a shared_constructor,
# this one is.
class ScalarSharedVariable(_tensor_py_operators, SharedVariable): class ScalarSharedVariable(_tensor_py_operators, SharedVariable):
pass pass
...@@ -94,8 +85,9 @@ class ScalarSharedVariable(_tensor_py_operators, SharedVariable): ...@@ -94,8 +85,9 @@ class ScalarSharedVariable(_tensor_py_operators, SharedVariable):
def scalar_constructor( def scalar_constructor(
value, name=None, strict=False, allow_downcast=None, borrow=False, target="cpu" value, name=None, strict=False, allow_downcast=None, borrow=False, target="cpu"
): ):
""" """`SharedVariable` constructor for scalar values.
SharedVariable constructor for scalar values. Default: int64 or float64.
Default: int64 or float64.
Notes Notes
----- -----
......
...@@ -36,6 +36,22 @@ def data_of(s): ...@@ -36,6 +36,22 @@ def data_of(s):
class TestPfunc: class TestPfunc:
def test_errors(self):
a = lscalar()
b = shared(1)
with pytest.raises(TypeError):
pfunc({a}, a + b)
with pytest.raises(TypeError):
pfunc([a], a + b, no_default_updates=1)
with pytest.raises(TypeError):
pfunc([a], a + b, updates=[{b, a}])
with pytest.raises(TypeError):
pfunc([a], a + b, updates=[(1, b)])
def test_doc(self): def test_doc(self):
# Ensure the code given in pfunc.txt works as expected # Ensure the code given in pfunc.txt works as expected
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论