提交 b132036d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Make SharedVariable.default_update a persistent property

`Type` checks are also performed when values are assigned to the property.
上级 8968a387
...@@ -103,7 +103,7 @@ def rebuild_collect_shared( ...@@ -103,7 +103,7 @@ def rebuild_collect_shared(
elif isinstance(v, SharedVariable): elif isinstance(v, SharedVariable):
if v not in shared_inputs: if v not in shared_inputs:
shared_inputs.append(v) shared_inputs.append(v)
if hasattr(v, "default_update"): if v.default_update is not None:
# Check that v should not be excluded from the default # Check that v should not be excluded from the default
# updates list # updates list
if no_default_updates is False or ( if no_default_updates is False or (
......
...@@ -7,8 +7,6 @@ import copy ...@@ -7,8 +7,6 @@ import copy
from contextlib import contextmanager from contextlib import contextmanager
from typing import List, Optional from typing import List, Optional
import numpy as np
from pytensor.graph.basic import Variable from pytensor.graph.basic import Variable
from pytensor.graph.utils import add_tag_trace from pytensor.graph.utils import add_tag_trace
from pytensor.link.basic import Container from pytensor.link.basic import Container
...@@ -103,6 +101,8 @@ class SharedVariable(Variable): ...@@ -103,6 +101,8 @@ class SharedVariable(Variable):
if isinstance(__SHARED_CONTEXT__, list): if isinstance(__SHARED_CONTEXT__, list):
__SHARED_CONTEXT__.append(self) __SHARED_CONTEXT__.append(self)
self._default_update: Optional[Variable] = None
def get_value(self, borrow=False, return_internal_type=False): def get_value(self, borrow=False, return_internal_type=False):
""" """
Get the non-symbolic value associated with this SharedVariable. Get the non-symbolic value associated with this SharedVariable.
...@@ -179,47 +179,23 @@ class SharedVariable(Variable): ...@@ -179,47 +179,23 @@ class SharedVariable(Variable):
cp.tag = copy.copy(self.tag) cp.tag = copy.copy(self.tag)
return cp return cp
def __getitem__(self, *args): @property
# __getitem__ is not available for generic SharedVariable objects. def default_update(self) -> Optional[Variable]:
# We raise a TypeError like Python would do if __getitem__ was not """A default update expression for this `Variable`.
# implemented at all, but with a more explicit error message to help
# PyTensor users figure out the root of the problem more easily.
value = self.get_value(borrow=True)
if isinstance(value, np.ndarray):
# Array probably had an unknown dtype.
msg = (
f"a Numpy array with dtype: '{value.dtype}'. This data type is not "
"currently recognized by PyTensor tensors: please cast "
"your data into a supported numeric type if you need "
"PyTensor tensor functionalities."
)
else:
msg = (
f"an object of type: {type(value)}. Did you forget to cast it into "
"a Numpy array before calling pytensor.shared()?"
)
raise TypeError( If this value is non-``None``, its value will be used as the `update`
"The generic 'SharedVariable' object is not subscriptable. " (see `pytensor.function`) for this `Variable` when no updates are
f"This shared variable contains {msg}" provided through `pytensor.function` and `no_default_updates` isn't
) enabled.
"""
def _value_get(self): return self._default_update
raise Exception(
"sharedvar.value does not exist anymore. Use "
"sharedvar.get_value() or sharedvar.set_value()"
" instead."
)
def _value_set(self, new_value):
raise Exception(
"sharedvar.value does not exist anymore. Use "
"sharedvar.get_value() or sharedvar.set_value()"
" instead."
)
# We keep this just to raise an error @default_update.setter
value = property(_value_get, _value_set) def default_update(self, value):
if value is not None:
self._default_update = self.type.filter_variable(value, allow_convert=True)
else:
self._default_update = value
def shared_constructor(ctor, remove=False): def shared_constructor(ctor, remove=False):
......
...@@ -996,8 +996,8 @@ def scan( ...@@ -996,8 +996,8 @@ def scan(
# We also don't want to remove a default update that applies to # We also don't want to remove a default update that applies to
# the scope/context containing this `Scan`, so we only remove # the scope/context containing this `Scan`, so we only remove
# default updates on "local" variables. # default updates on "local" variables.
if is_local and hasattr(input.variable, "default_update"): if is_local and input.variable.default_update is not None:
del input.variable.default_update input.variable.default_update = None
new_var = safe_new(input.variable) new_var = safe_new(input.variable)
......
...@@ -432,7 +432,8 @@ class TestPfunc: ...@@ -432,7 +432,8 @@ class TestPfunc:
f() f()
assert x.get_value() == 1 assert x.get_value() == 1
del x.default_update x.default_update = None
f() f()
assert x.get_value() == 2 assert x.get_value() == 2
......
...@@ -282,10 +282,10 @@ class TestScan: ...@@ -282,10 +282,10 @@ class TestScan:
n_steps=4, n_steps=4,
) )
assert not hasattr(inner_rng, "default_update") assert inner_rng is None
assert hasattr(inner_inner_rng, "default_update") assert inner_inner_rng.default_update is not None
assert hasattr(y, "default_update") assert y.default_update is not None
assert hasattr(z_rng, "default_update") assert z_rng.default_update is not None
out_fn = function([], out, mode=Mode(optimizer=None)) out_fn = function([], out, mode=Mode(optimizer=None))
res, z_res = out_fn() res, z_res = out_fn()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论