提交 b132036d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Make SharedVariable.default_update a persistent property

`Type` checks are also performed when values are assigned to the property.
上级 8968a387
......@@ -103,7 +103,7 @@ def rebuild_collect_shared(
elif isinstance(v, SharedVariable):
if v not in shared_inputs:
shared_inputs.append(v)
if hasattr(v, "default_update"):
if v.default_update is not None:
# Check that v should not be excluded from the default
# updates list
if no_default_updates is False or (
......
......@@ -7,8 +7,6 @@ import copy
from contextlib import contextmanager
from typing import List, Optional
import numpy as np
from pytensor.graph.basic import Variable
from pytensor.graph.utils import add_tag_trace
from pytensor.link.basic import Container
......@@ -103,6 +101,8 @@ class SharedVariable(Variable):
if isinstance(__SHARED_CONTEXT__, list):
__SHARED_CONTEXT__.append(self)
self._default_update: Optional[Variable] = None
def get_value(self, borrow=False, return_internal_type=False):
"""
Get the non-symbolic value associated with this SharedVariable.
......@@ -179,47 +179,23 @@ class SharedVariable(Variable):
cp.tag = copy.copy(self.tag)
return cp
def __getitem__(self, *args):
# __getitem__ is not available for generic SharedVariable objects.
# We raise a TypeError like Python would do if __getitem__ was not
# implemented at all, but with a more explicit error message to help
# PyTensor users figure out the root of the problem more easily.
value = self.get_value(borrow=True)
if isinstance(value, np.ndarray):
# Array probably had an unknown dtype.
msg = (
f"a Numpy array with dtype: '{value.dtype}'. This data type is not "
"currently recognized by PyTensor tensors: please cast "
"your data into a supported numeric type if you need "
"PyTensor tensor functionalities."
)
else:
msg = (
f"an object of type: {type(value)}. Did you forget to cast it into "
"a Numpy array before calling pytensor.shared()?"
)
raise TypeError(
"The generic 'SharedVariable' object is not subscriptable. "
f"This shared variable contains {msg}"
)
def _value_get(self):
raise Exception(
"sharedvar.value does not exist anymore. Use "
"sharedvar.get_value() or sharedvar.set_value()"
" instead."
)
@property
def default_update(self) -> Optional[Variable]:
"""A default update expression for this `Variable`.
def _value_set(self, new_value):
raise Exception(
"sharedvar.value does not exist anymore. Use "
"sharedvar.get_value() or sharedvar.set_value()"
" instead."
)
If this value is non-``None``, its value will be used as the `update`
(see `pytensor.function`) for this `Variable` when no updates are
provided through `pytensor.function` and `no_default_updates` isn't
enabled.
"""
return self._default_update
# We keep this just to raise an error
value = property(_value_get, _value_set)
@default_update.setter
def default_update(self, value):
if value is not None:
self._default_update = self.type.filter_variable(value, allow_convert=True)
else:
self._default_update = value
def shared_constructor(ctor, remove=False):
......
......@@ -996,8 +996,8 @@ def scan(
# We also don't want to remove a default update that applies to
# the scope/context containing this `Scan`, so we only remove
# default updates on "local" variables.
if is_local and hasattr(input.variable, "default_update"):
del input.variable.default_update
if is_local and input.variable.default_update is not None:
input.variable.default_update = None
new_var = safe_new(input.variable)
......
......@@ -432,7 +432,8 @@ class TestPfunc:
f()
assert x.get_value() == 1
del x.default_update
x.default_update = None
f()
assert x.get_value() == 2
......
......@@ -282,10 +282,10 @@ class TestScan:
n_steps=4,
)
assert not hasattr(inner_rng, "default_update")
assert hasattr(inner_inner_rng, "default_update")
assert hasattr(y, "default_update")
assert hasattr(z_rng, "default_update")
assert inner_rng is None
assert inner_inner_rng.default_update is not None
assert y.default_update is not None
assert z_rng.default_update is not None
out_fn = function([], out, mode=Mode(optimizer=None))
res, z_res = out_fn()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论