提交 461832c5 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Clean up docstrings and errors relating to SharedVariable

上级 b132036d
......@@ -13,7 +13,7 @@
.. class:: SharedVariable
Variable with Storage that is shared between functions that it appears in.
Variable with storage that is shared between the compiled functions that it appears in.
These variables are meant to be created by registered *shared constructors*
(see :func:`shared_constructor`).
......@@ -68,7 +68,6 @@
A container to use for this SharedVariable when it is an implicit function parameter.
:type: class:`Container`
.. autofunction:: shared
......@@ -76,10 +75,10 @@
Append `ctor` to the list of shared constructors (see :func:`shared`).
Each registered constructor ``ctor`` will be called like this:
Each registered constructor `ctor` will be called like this:
.. code-block:: python
ctor(value, name=name, strict=strict, **kwargs)
If it do not support given value, it must raise a TypeError.
If it do not support given value, it must raise a `TypeError`.
......@@ -78,10 +78,12 @@ def rebuild_collect_shared(
shared_inputs = []
def clone_v_get_shared_updates(v, copy_inputs_over):
"""
Clones a variable and its inputs recursively until all are in clone_d.
Also appends all shared variables met along the way to shared inputs,
and their default_update (if applicable) to update_d and update_expr.
r"""Clones a variable and its inputs recursively until all are in `clone_d`.
Also, it appends all `SharedVariable`\s met along the way to
`shared_inputs` and their corresponding
`SharedVariable.default_update`\s (when applicable) to `update_d` and
`update_expr`.
"""
# this co-recurses with clone_a
......@@ -419,22 +421,24 @@ def construct_pfunc_ins_and_outs(
givens = []
if not isinstance(params, (list, tuple)):
raise Exception("in pfunc() the first argument must be a list or " "a tuple")
raise TypeError("The `params` argument must be a list or a tuple")
if not isinstance(no_default_updates, bool) and not isinstance(
no_default_updates, list
):
raise TypeError("no_default_update should be either a boolean or " "a list")
raise TypeError("The `no_default_update` argument must be a boolean or list")
if len(updates) > 0 and any(
isinstance(v, Variable) for v in iter_over_pairs(updates)
if len(updates) > 0 and not all(
isinstance(pair, (tuple, list))
and len(pair) == 2
and isinstance(pair[0], Variable)
for pair in iter_over_pairs(updates)
):
raise ValueError(
"The updates parameter must be an OrderedDict/dict or a list of "
"lists/tuples with 2 elements"
raise TypeError(
"The `updates` parameter must be an ordered mapping or a list of pairs"
)
# transform params into pytensor.compile.In objects.
# Transform params into pytensor.compile.In objects.
inputs = [
_pfunc_param_to_in(p, allow_downcast=allow_input_downcast) for p in params
]
......
"""
Provide a simple user friendly API to PyTensor-managed memory.
"""
"""Provide a simple user friendly API to PyTensor-managed memory."""
import copy
from contextlib import contextmanager
......@@ -30,52 +27,14 @@ def collect_new_shareds():
class SharedVariable(Variable):
"""
Variable that is (defaults to being) shared between functions that
it appears in.
Parameters
----------
name : str
The name for this variable (see `Variable`).
type : str
The type for this variable (see `Variable`).
value
A value to associate with this variable (a new container will be
created).
strict
True : assignments to .value will not be cast or copied, so they must
have the correct type.
allow_downcast
Only applies if `strict` is False.
True : allow assigned value to lose precision when cast during
assignment.
False : never allow precision loss.
None : only allow downcasting of a Python float to a scalar floatX.
container
The container to use for this variable. Illegal to pass this as well as
a value.
Notes
-----
For more user-friendly constructor, see `shared`.
"""Variable that is shared between compiled functions."""
"""
# Container object
container = None
container: Optional[Container] = None
"""
A container to use for this SharedVariable when it is an implicit
function parameter.
:type: `Container`
"""
# default_update
# If this member is present, its value will be used as the "update" for
# this Variable, unless another update value has been passed to "function",
# or the "no_default_updates" list passed to "function" contains it.
def __init__(self, name, type, value, strict, allow_downcast=None, container=None):
super().__init__(type=type, name=name, owner=None, index=None)
......@@ -207,37 +166,30 @@ def shared_constructor(ctor, remove=False):
def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
"""Return a SharedVariable Variable, initialized with a copy or
reference of `value`.
r"""Create a `SharedVariable` initialized with a copy or reference of `value`.
This function iterates over constructor functions to find a
suitable SharedVariable subclass. The suitable one is the first
suitable `SharedVariable` subclass. The suitable one is the first
constructor that accept the given value. See the documentation of
:func:`shared_constructor` for the definition of a constructor
function.
This function is meant as a convenient default. If you want to use a
specific shared variable constructor, consider calling it directly.
``pytensor.shared`` is a shortcut to this function.
.. attribute:: constructors
specific constructor, consider calling it directly.
A list of shared variable constructors that will be tried in reverse
order.
`pytensor.shared` is a shortcut to this function.
Notes
-----
By passing kwargs, you effectively limit the set of potential constructors
to those that can accept those kwargs.
Some shared variable have ``borrow`` as extra kwargs.
Some shared variable have `borrow` as a kwarg.
Some shared variable have ``broadcastable`` as extra kwargs. As shared
`SharedVariable`\s of `TensorType` have `broadcastable` as a kwarg. As shared
variable shapes can change, all dimensions default to not being
broadcastable, even if ``value`` has a shape of 1 along some dimension.
This parameter allows you to create for example a `row` or `column` 2d
tensor.
broadcastable, even if `value` has a shape of 1 along some dimension.
This parameter allows one to create for example a row or column tensor.
"""
......
......@@ -22,7 +22,7 @@ class RandomGeneratorSharedVariable(SharedVariable):
def randomgen_constructor(
value, name=None, strict=False, allow_downcast=None, borrow=False
):
r"""`SharedVariable` Constructor for NumPy's `Generator` and/or `RandomState`."""
r"""`SharedVariable` constructor for NumPy's `Generator` and/or `RandomState`."""
if isinstance(value, np.random.RandomState):
rng_sv_type = RandomStateSharedVariable
rng_type = random_state_type
......
......@@ -41,8 +41,7 @@ def tensor_constructor(
target="cpu",
broadcastable=None,
):
"""
SharedVariable Constructor for TensorType.
r"""`SharedVariable` constructor for `TensorType`\s.
Notes
-----
......@@ -64,9 +63,8 @@ def tensor_constructor(
if not isinstance(value, np.ndarray):
raise TypeError()
# if no shape is given, then the default is to assume that
# the value might be resized in any dimension in the future.
#
# If no shape is given, then the default is to assume that the value might
# be resized in any dimension in the future.
if shape is None:
shape = (None,) * len(value.shape)
type = TensorType(value.dtype, shape=shape)
......@@ -79,13 +77,6 @@ def tensor_constructor(
)
# TensorSharedVariable brings in the tensor operators, is not ideal, but works
# as long as we don't do purely scalar-scalar operations
# _tensor_py_operators is first to have its version of __{gt,ge,lt,le}__
#
# N.B. THERE IS ANOTHER CLASS CALLED ScalarSharedVariable in the
# pytensor.scalar.sharedvar file. It is not registered as a shared_constructor,
# this one is.
class ScalarSharedVariable(_tensor_py_operators, SharedVariable):
pass
......@@ -94,8 +85,9 @@ class ScalarSharedVariable(_tensor_py_operators, SharedVariable):
def scalar_constructor(
value, name=None, strict=False, allow_downcast=None, borrow=False, target="cpu"
):
"""
SharedVariable constructor for scalar values. Default: int64 or float64.
"""`SharedVariable` constructor for scalar values.
Default: int64 or float64.
Notes
-----
......
......@@ -36,6 +36,22 @@ def data_of(s):
class TestPfunc:
def test_errors(self):
a = lscalar()
b = shared(1)
with pytest.raises(TypeError):
pfunc({a}, a + b)
with pytest.raises(TypeError):
pfunc([a], a + b, no_default_updates=1)
with pytest.raises(TypeError):
pfunc([a], a + b, updates=[{b, a}])
with pytest.raises(TypeError):
pfunc([a], a + b, updates=[(1, b)])
def test_doc(self):
# Ensure the code given in pfunc.txt works as expected
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论