提交 08eaf0e3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Ricardo Vieira

Use singledispatch to register SharedVariable type constructors

上级 461832c5
......@@ -2,6 +2,7 @@
import copy
from contextlib import contextmanager
from functools import singledispatch
from typing import List, Optional
from pytensor.graph.basic import Variable
......@@ -157,14 +158,6 @@ class SharedVariable(Variable):
self._default_update = value
def shared_constructor(ctor, remove=False):
    """Register *ctor* as a constructor candidate for `shared`.

    With ``remove=True`` the constructor is withdrawn from the registry
    instead.  *ctor* is returned unchanged so this function can be used
    as a decorator.
    """
    registry = shared.constructors
    if remove:
        registry.remove(ctor)
    else:
        registry.append(ctor)
    return ctor
def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
r"""Create a `SharedVariable` initialized with a copy or reference of `value`.
......@@ -193,16 +186,11 @@ def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
"""
try:
if isinstance(value, Variable):
raise TypeError(
"Shared variable constructor needs numeric "
"values and not symbolic variables."
)
raise TypeError("Shared variable values can not be symbolic.")
for ctor in reversed(shared.constructors):
try:
var = ctor(
var = shared_constructor(
value,
name=name,
strict=strict,
......@@ -211,35 +199,13 @@ def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
)
add_tag_trace(var)
return var
except TypeError:
continue
# This may happen when kwargs were supplied
# if kwargs were given, the generic_constructor won't be callable.
#
# This was done on purpose, the rationale being that if kwargs
# were supplied, the user didn't want them to be ignored.
except MemoryError as e:
e.args = e.args + ("Consider using `pytensor.shared(..., borrow=True)`",)
raise
raise TypeError(
"No suitable SharedVariable constructor could be found."
" Are you sure all kwargs are supported?"
" We do not support the parameter dtype or type."
f' value="{value}". parameters="{kwargs}"'
)
shared.constructors = []
@shared_constructor
def generic_constructor(value, name=None, strict=False, allow_downcast=None):
"""
SharedVariable Constructor.
"""
@singledispatch
def shared_constructor(value, name=None, strict=False, allow_downcast=None, **kwargs):
return SharedVariable(
type=generic,
value=value,
......
......@@ -11,21 +11,18 @@ class SparseTensorSharedVariable(_sparse_py_operators, SharedVariable):
format = property(lambda self: self.type.format)
# NOTE(review): this span comes from a rendered diff — both the old decorator
# (`@shared_constructor`) and its replacement (`@shared_constructor.register(...)`)
# are shown stacked. In real source only one of the two should remain; confirm
# against the upstream commit before relying on this text.
@shared_constructor
@shared_constructor.register(scipy.sparse.spmatrix)
def sparse_constructor(
    value, name=None, strict=False, allow_downcast=None, borrow=False, format=None
):
    """Build a `SparseTensorSharedVariable` wrapping a SciPy sparse matrix.

    Parameters
    ----------
    value
        The sparse matrix to wrap; must be a `scipy.sparse.spmatrix`.
    name
        Optional name for the resulting shared variable.
    strict, allow_downcast
        Forwarded unchanged to the `SparseTensorSharedVariable` constructor.
    borrow
        If False (the default), a deep copy of ``value`` is stored so the
        caller's object is not aliased by the shared variable.
    format
        Sparse storage format; when None, ``value.format`` is used.

    Raises
    ------
    TypeError
        If ``value`` is not a SciPy sparse matrix.
    """
    if not isinstance(value, scipy.sparse.spmatrix):
        raise TypeError(
            "Expected a sparse matrix in the sparse shared variable constructor. Received: ",
            value.__class__,
        )
    # Default to the input matrix's own storage format (e.g. csr, csc).
    if format is None:
        format = value.format
    type = SparseTensorType(format=format, dtype=value.dtype)
    # Without borrow, keep an independent copy so later mutation of the
    # caller's matrix cannot change the shared variable's value.
    if not borrow:
        value = copy.deepcopy(value)
    return SparseTensorSharedVariable(
        type=type, value=value, name=name, strict=strict, allow_downcast=allow_downcast
    )
......@@ -18,7 +18,8 @@ class RandomGeneratorSharedVariable(SharedVariable):
)
@shared_constructor
@shared_constructor.register(np.random.RandomState)
@shared_constructor.register(np.random.Generator)
def randomgen_constructor(
value, name=None, strict=False, allow_downcast=None, borrow=False
):
......@@ -29,8 +30,6 @@ def randomgen_constructor(
elif isinstance(value, np.random.Generator):
rng_sv_type = RandomGeneratorSharedVariable
rng_type = random_generator_type
else:
raise TypeError()
if not borrow:
value = copy.deepcopy(value)
......
import traceback
import warnings
import numpy as np
......@@ -30,7 +29,7 @@ def _get_vector_length_TensorSharedVariable(var_inst, var):
return len(var.get_value(borrow=True))
@shared_constructor
@shared_constructor.register(np.ndarray)
def tensor_constructor(
value,
name=None,
......@@ -60,14 +59,13 @@ def tensor_constructor(
if target != "cpu":
raise TypeError("not for cpu")
if not isinstance(value, np.ndarray):
raise TypeError()
# If no shape is given, then the default is to assume that the value might
# be resized in any dimension in the future.
if shape is None:
shape = (None,) * len(value.shape)
shape = (None,) * value.ndim
type = TensorType(value.dtype, shape=shape)
return TensorSharedVariable(
type=type,
value=np.array(value, copy=(not borrow)),
......@@ -81,7 +79,10 @@ class ScalarSharedVariable(_tensor_py_operators, SharedVariable):
pass
@shared_constructor
@shared_constructor.register(np.number)
@shared_constructor.register(float)
@shared_constructor.register(int)
@shared_constructor.register(complex)
def scalar_constructor(
value, name=None, strict=False, allow_downcast=None, borrow=False, target="cpu"
):
......@@ -101,18 +102,15 @@ def scalar_constructor(
if target != "cpu":
raise TypeError("not for cpu")
if not isinstance(value, (np.number, float, int, complex)):
raise TypeError()
try:
dtype = value.dtype
except Exception:
except AttributeError:
dtype = np.asarray(value).dtype
dtype = str(dtype)
value = _asarray(value, dtype=dtype)
tensor_type = TensorType(dtype=str(value.dtype), shape=[])
tensor_type = TensorType(dtype=str(value.dtype), shape=())
try:
# Do not pass the dtype to asarray because we want this to fail if
# strict is True and the types do not match.
rval = ScalarSharedVariable(
......@@ -123,6 +121,3 @@ def scalar_constructor(
allow_downcast=allow_downcast,
)
return rval
except Exception:
traceback.print_exc()
raise
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论