提交 b132036d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Make SharedVariable.default_update a persistent property

`Type` checks are also performed when values are assigned to the property.
上级 8968a387
...@@ -103,7 +103,7 @@ def rebuild_collect_shared( ...@@ -103,7 +103,7 @@ def rebuild_collect_shared(
elif isinstance(v, SharedVariable): elif isinstance(v, SharedVariable):
if v not in shared_inputs: if v not in shared_inputs:
shared_inputs.append(v) shared_inputs.append(v)
if hasattr(v, "default_update"): if v.default_update is not None:
# Check that v should not be excluded from the default # Check that v should not be excluded from the default
# updates list # updates list
if no_default_updates is False or ( if no_default_updates is False or (
......
...@@ -7,8 +7,6 @@ import copy ...@@ -7,8 +7,6 @@ import copy
from contextlib import contextmanager from contextlib import contextmanager
from typing import List, Optional from typing import List, Optional
import numpy as np
from pytensor.graph.basic import Variable from pytensor.graph.basic import Variable
from pytensor.graph.utils import add_tag_trace from pytensor.graph.utils import add_tag_trace
from pytensor.link.basic import Container from pytensor.link.basic import Container
...@@ -103,6 +101,8 @@ class SharedVariable(Variable): ...@@ -103,6 +101,8 @@ class SharedVariable(Variable):
if isinstance(__SHARED_CONTEXT__, list): if isinstance(__SHARED_CONTEXT__, list):
__SHARED_CONTEXT__.append(self) __SHARED_CONTEXT__.append(self)
self._default_update: Optional[Variable] = None
def get_value(self, borrow=False, return_internal_type=False): def get_value(self, borrow=False, return_internal_type=False):
""" """
Get the non-symbolic value associated with this SharedVariable. Get the non-symbolic value associated with this SharedVariable.
...@@ -179,47 +179,23 @@ class SharedVariable(Variable): ...@@ -179,47 +179,23 @@ class SharedVariable(Variable):
cp.tag = copy.copy(self.tag) cp.tag = copy.copy(self.tag)
return cp return cp
def __getitem__(self, *args): @property
# __getitem__ is not available for generic SharedVariable objects. def default_update(self) -> Optional[Variable]:
# We raise a TypeError like Python would do if __getitem__ was not """A default update expression for this `Variable`.
# implemented at all, but with a more explicit error message to help
# PyTensor users figure out the root of the problem more easily.
value = self.get_value(borrow=True)
if isinstance(value, np.ndarray):
# Array probably had an unknown dtype.
msg = (
f"a Numpy array with dtype: '{value.dtype}'. This data type is not "
"currently recognized by PyTensor tensors: please cast "
"your data into a supported numeric type if you need "
"PyTensor tensor functionalities."
)
else:
msg = (
f"an object of type: {type(value)}. Did you forget to cast it into "
"a Numpy array before calling pytensor.shared()?"
)
raise TypeError(
"The generic 'SharedVariable' object is not subscriptable. "
f"This shared variable contains {msg}"
)
def _value_get(self):
raise Exception(
"sharedvar.value does not exist anymore. Use "
"sharedvar.get_value() or sharedvar.set_value()"
" instead."
)
def _value_set(self, new_value): If this value is non-``None``, its value will be used as the `update`
raise Exception( (see `pytensor.function`) for this `Variable` when no updates are
"sharedvar.value does not exist anymore. Use " provided through `pytensor.function` and `no_default_updates` isn't
"sharedvar.get_value() or sharedvar.set_value()" enabled.
" instead." """
) return self._default_update
# We keep this just to raise an error @default_update.setter
value = property(_value_get, _value_set) def default_update(self, value):
if value is not None:
self._default_update = self.type.filter_variable(value, allow_convert=True)
else:
self._default_update = value
def shared_constructor(ctor, remove=False): def shared_constructor(ctor, remove=False):
......
...@@ -996,8 +996,8 @@ def scan( ...@@ -996,8 +996,8 @@ def scan(
# We also don't want to remove a default update that applies to # We also don't want to remove a default update that applies to
# the scope/context containing this `Scan`, so we only remove # the scope/context containing this `Scan`, so we only remove
# default updates on "local" variables. # default updates on "local" variables.
if is_local and hasattr(input.variable, "default_update"): if is_local and input.variable.default_update is not None:
del input.variable.default_update input.variable.default_update = None
new_var = safe_new(input.variable) new_var = safe_new(input.variable)
......
...@@ -432,7 +432,8 @@ class TestPfunc: ...@@ -432,7 +432,8 @@ class TestPfunc:
f() f()
assert x.get_value() == 1 assert x.get_value() == 1
del x.default_update x.default_update = None
f() f()
assert x.get_value() == 2 assert x.get_value() == 2
......
...@@ -282,10 +282,10 @@ class TestScan: ...@@ -282,10 +282,10 @@ class TestScan:
n_steps=4, n_steps=4,
) )
assert not hasattr(inner_rng, "default_update") assert inner_rng is None
assert hasattr(inner_inner_rng, "default_update") assert inner_inner_rng.default_update is not None
assert hasattr(y, "default_update") assert y.default_update is not None
assert hasattr(z_rng, "default_update") assert z_rng.default_update is not None
out_fn = function([], out, mode=Mode(optimizer=None)) out_fn = function([], out, mode=Mode(optimizer=None))
res, z_res = out_fn() res, z_res = out_fn()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论