Commit 08eaf0e3, authored by Brandon T. Willard, committed by Ricardo Vieira

Use singledispatch to register SharedVariable type constructors

Parent commit: 461832c5
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import copy import copy
from contextlib import contextmanager from contextlib import contextmanager
from functools import singledispatch
from typing import List, Optional from typing import List, Optional
from pytensor.graph.basic import Variable from pytensor.graph.basic import Variable
...@@ -157,14 +158,6 @@ class SharedVariable(Variable): ...@@ -157,14 +158,6 @@ class SharedVariable(Variable):
self._default_update = value self._default_update = value
def shared_constructor(ctor, remove=False):
    """Register `ctor` in the ``shared.constructors`` registry.

    Intended for use as a decorator: the constructor is appended to the
    registry and returned unchanged so the decorated name stays bound.
    With ``remove=True`` the constructor is removed from the registry
    instead (raising ``ValueError`` if it was never registered, as
    ``list.remove`` does).
    """
    registry = shared.constructors
    if not remove:
        registry.append(ctor)
    else:
        registry.remove(ctor)
    return ctor
def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
    r"""Create a `SharedVariable` initialized with a copy or reference of `value`.

    The concrete constructor is selected by the ``shared_constructor``
    ``singledispatch`` registry based on ``type(value)``; any extra keyword
    arguments are forwarded to the selected constructor.

    Raises
    ------
    TypeError
        If `value` is itself a symbolic `Variable`.
    MemoryError
        Re-raised from the constructor, augmented with a hint to use
        ``borrow=True``.
    """
    if isinstance(value, Variable):
        raise TypeError("Shared variable values can not be symbolic.")

    try:
        new_var = shared_constructor(
            value, name=name, strict=strict, allow_downcast=allow_downcast, **kwargs
        )
        add_tag_trace(new_var)
        return new_var
    except MemoryError as e:
        # Augment rather than replace the original error information.
        e.args = e.args + ("Consider using `pytensor.shared(..., borrow=True)`",)
        raise
shared.constructors = []
@singledispatch
@shared_constructor def shared_constructor(value, name=None, strict=False, allow_downcast=None, **kwargs):
def generic_constructor(value, name=None, strict=False, allow_downcast=None):
"""
SharedVariable Constructor.
"""
return SharedVariable( return SharedVariable(
type=generic, type=generic,
value=value, value=value,
......
...@@ -11,21 +11,18 @@ class SparseTensorSharedVariable(_sparse_py_operators, SharedVariable): ...@@ -11,21 +11,18 @@ class SparseTensorSharedVariable(_sparse_py_operators, SharedVariable):
format = property(lambda self: self.type.format) format = property(lambda self: self.type.format)
@shared_constructor.register(scipy.sparse.spmatrix)
def sparse_constructor(
    value, name=None, strict=False, allow_downcast=None, borrow=False, format=None
):
    """`SharedVariable` constructor for SciPy sparse matrices.

    Parameters
    ----------
    value : scipy.sparse.spmatrix
        The sparse matrix to wrap.
    name : str, optional
        Name for the resulting variable.
    strict, allow_downcast
        Forwarded to the `SparseTensorSharedVariable` constructor.
    borrow : bool
        If ``False`` (default), `value` is deep-copied so later mutation of
        the caller's matrix cannot affect the shared variable.
    format : str, optional
        Sparse storage format; defaults to ``value.format``.
    """
    if format is None:
        format = value.format

    # Named `sparse_type` instead of `type` to avoid shadowing the builtin.
    sparse_type = SparseTensorType(format=format, dtype=value.dtype)

    if not borrow:
        value = copy.deepcopy(value)

    return SparseTensorSharedVariable(
        type=sparse_type,
        value=value,
        name=name,
        strict=strict,
        allow_downcast=allow_downcast,
    )
...@@ -18,7 +18,8 @@ class RandomGeneratorSharedVariable(SharedVariable): ...@@ -18,7 +18,8 @@ class RandomGeneratorSharedVariable(SharedVariable):
) )
@shared_constructor @shared_constructor.register(np.random.RandomState)
@shared_constructor.register(np.random.Generator)
def randomgen_constructor( def randomgen_constructor(
value, name=None, strict=False, allow_downcast=None, borrow=False value, name=None, strict=False, allow_downcast=None, borrow=False
): ):
...@@ -29,8 +30,6 @@ def randomgen_constructor( ...@@ -29,8 +30,6 @@ def randomgen_constructor(
elif isinstance(value, np.random.Generator): elif isinstance(value, np.random.Generator):
rng_sv_type = RandomGeneratorSharedVariable rng_sv_type = RandomGeneratorSharedVariable
rng_type = random_generator_type rng_type = random_generator_type
else:
raise TypeError()
if not borrow: if not borrow:
value = copy.deepcopy(value) value = copy.deepcopy(value)
......
import traceback
import warnings import warnings
import numpy as np import numpy as np
...@@ -30,7 +29,7 @@ def _get_vector_length_TensorSharedVariable(var_inst, var): ...@@ -30,7 +29,7 @@ def _get_vector_length_TensorSharedVariable(var_inst, var):
return len(var.get_value(borrow=True)) return len(var.get_value(borrow=True))
@shared_constructor @shared_constructor.register(np.ndarray)
def tensor_constructor( def tensor_constructor(
value, value,
name=None, name=None,
...@@ -60,14 +59,13 @@ def tensor_constructor( ...@@ -60,14 +59,13 @@ def tensor_constructor(
if target != "cpu": if target != "cpu":
raise TypeError("not for cpu") raise TypeError("not for cpu")
if not isinstance(value, np.ndarray):
raise TypeError()
# If no shape is given, then the default is to assume that the value might # If no shape is given, then the default is to assume that the value might
# be resized in any dimension in the future. # be resized in any dimension in the future.
if shape is None: if shape is None:
shape = (None,) * len(value.shape) shape = (None,) * value.ndim
type = TensorType(value.dtype, shape=shape) type = TensorType(value.dtype, shape=shape)
return TensorSharedVariable( return TensorSharedVariable(
type=type, type=type,
value=np.array(value, copy=(not borrow)), value=np.array(value, copy=(not borrow)),
...@@ -81,7 +79,10 @@ class ScalarSharedVariable(_tensor_py_operators, SharedVariable): ...@@ -81,7 +79,10 @@ class ScalarSharedVariable(_tensor_py_operators, SharedVariable):
pass pass
@shared_constructor.register(np.number)
@shared_constructor.register(float)
@shared_constructor.register(int)
@shared_constructor.register(complex)
def scalar_constructor(
    value, name=None, strict=False, allow_downcast=None, borrow=False, target="cpu"
):
    """`SharedVariable` constructor for Python and NumPy scalars.

    Builds a 0-d `ScalarSharedVariable` backed by a copy of `value`.
    NOTE(review): `borrow` is accepted but unused in this body — the value
    is always copied; presumably kept for signature parity with the other
    constructors.
    """
    if target != "cpu":
        raise TypeError("not for cpu")

    # NumPy scalars carry their own dtype; plain Python scalars do not,
    # so infer one via `np.asarray`.
    try:
        dtype = value.dtype
    except AttributeError:
        dtype = np.asarray(value).dtype

    dtype = str(dtype)
    cast_value = _asarray(value, dtype=dtype)
    scalar_type = TensorType(dtype=str(cast_value.dtype), shape=())

    # Do not pass the dtype to `np.array` here: when `strict` is True a
    # type mismatch should fail instead of being silently cast away.
    return ScalarSharedVariable(
        type=scalar_type,
        value=np.array(cast_value, copy=True),
        name=name,
        strict=strict,
        allow_downcast=allow_downcast,
    )
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Finish editing this comment first!
Register or sign in to comment