Commit 334c86fb, authored by Brandon T. Willard; committed by Brandon T. Willard

Add a dtype option to aesara.tensor.as_tensor_variable

Parent commit: e3fabae3
......@@ -5,34 +5,37 @@ __docformat__ = "restructuredtext en"
import warnings
from functools import singledispatch
from typing import Callable, NoReturn, Optional
from typing import Any, Callable, NoReturn, Optional
def as_tensor_variable(
    x: Any, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs
) -> Callable:
    """Convert `x` into the appropriate ``TensorType``.

    This function is often used by ``make_node`` methods of ``Op`` subclasses
    to turn ndarrays, numbers, ``Scalar`` instances, ``Apply`` instances and
    ``TensorType`` instances into valid input list elements.

    Parameters
    ----------
    x
        The object to be converted into a ``Variable`` type. A
        ``numpy.ndarray`` argument will not be copied, but a list of numbers
        will be copied to make an ``numpy.ndarray``.
    name
        If a new ``Variable`` instance is created, it will be named with this
        string.
    ndim
        Return a ``Variable`` with this many dimensions.
    dtype
        The dtype to use for the resulting ``Variable``. If `x` is already
        a ``Variable`` type, then the dtype will not be changed.

    Raises
    ------
    TypeError
        If `x` cannot be converted to a ``TensorType`` Variable.

    """
    # NOTE(review): the ``Callable`` return annotation looks wrong -- the
    # registered converters return tensor ``Variable``s (e.g. via
    # ``constant``); confirm and change to ``TensorVariable``.
    # Dispatch on ``type(x)``; extra keyword arguments (e.g. ``dtype``) are
    # forwarded to the registered converter.
    return _as_tensor_variable(x, name, ndim, **kwargs)
......@@ -42,7 +45,7 @@ def as_tensor_variable(
def _as_tensor_variable(
x, name: Optional[str], ndim: Optional[int], **kwargs
) -> NoReturn:
raise NotImplementedError("")
raise NotImplementedError(f"Cannot convert {x} to a tensor variable.")
import aesara.tensor.exceptions
......
......@@ -87,7 +87,7 @@ def __oplist_tag(thing, tag):
@_as_tensor_variable.register(Apply)
def _as_tensor_Apply(x, name, ndim):
def _as_tensor_Apply(x, name, ndim, **kwargs):
# use Apply's default output mechanism
if (x.op.default_output is None) and (len(x.outputs) != 1):
raise TypeError(
......@@ -97,17 +97,17 @@ def _as_tensor_Apply(x, name, ndim):
x = x.default_output()
return as_tensor_variable(x, name=name, ndim=ndim)
return as_tensor_variable(x, name=name, ndim=ndim, **kwargs)
@_as_tensor_variable.register(ScalarVariable)
@_as_tensor_variable.register(ScalarConstant)
def _as_tensor_Scalar(x, name, ndim, **kwargs):
    # Lift the `Scalar` into a tensor first, then run the result back
    # through the generic conversion (which handles `name`/`ndim`/`dtype`).
    tensor_x = tensor_from_scalar(x)
    return as_tensor_variable(tensor_x, name=name, ndim=ndim, **kwargs)
@_as_tensor_variable.register(Variable)
def _as_tensor_Variable(x, name, ndim):
def _as_tensor_Variable(x, name, ndim, **kwargs):
if not isinstance(x.type, TensorType):
raise TypeError(
"Tensor type field must be a TensorType; found {}.".format(type(x.type))
......@@ -137,10 +137,10 @@ def _as_tensor_Variable(x, name, ndim):
@_as_tensor_variable.register(list)
@_as_tensor_variable.register(tuple)
def _as_tensor_Sequence(x, name, ndim):
def _as_tensor_Sequence(x, name, ndim, dtype=None, **kwargs):
if len(x) == 0:
return constant(x, name=name, ndim=ndim)
return constant(x, name=name, ndim=ndim, dtype=dtype)
# If a sequence has `Variable`s in it, then we want
# to customize the conversion to a tensor type.
......@@ -161,7 +161,8 @@ def _as_tensor_Sequence(x, name, ndim):
):
# In this instance, we have a sequence of constants with which we
# want to construct a vector, so we can use `MakeVector` directly.
dtype = aes.upcast(*[i.dtype for i in x if hasattr(i, "dtype")])
if dtype is None:
dtype = aes.upcast(*[i.dtype for i in x if hasattr(i, "dtype")])
return MakeVector(dtype)(*x)
# In this case, we have at least one non-`Constant` term, so we
......@@ -169,19 +170,19 @@ def _as_tensor_Sequence(x, name, ndim):
# symbolically join terms.
return stack(x)
return constant(x, name=name, ndim=ndim)
return constant(x, name=name, ndim=ndim, dtype=dtype)
@_as_tensor_variable.register(np.bool_)
@_as_tensor_variable.register(np.number)
@_as_tensor_variable.register(Number)
@_as_tensor_variable.register(np.ndarray)
def _as_tensor_numbers(x, name, ndim, dtype=None, **kwargs):
    # Plain numbers and ndarrays are wrapped as constants directly;
    # `dtype=None` leaves the dtype choice to `constant`.
    return constant(x, dtype=dtype, name=name, ndim=ndim)
@_as_tensor_variable.register(bool)
def _as_tensor_bool(x, name, ndim):
def _as_tensor_bool(x, name, ndim, **kwargs):
raise TypeError(
"Cannot cast True or False as a tensor variable. Please use "
"np.array(True) or np.array(False) if you need these constants. "
......
......@@ -462,6 +462,26 @@ class TestAsTensorVariable:
ten = as_tensor_variable(np.array([True, False, False, True, True]))
assert ten.type.dtype == "bool"
def test_dtype(self):
    # An empty sequence with no `dtype` argument defaults to `floatX`.
    res = as_tensor_variable([])
    assert res.type.dtype == config.floatX

    # An explicit `dtype` overrides that default for empty sequences...
    res = as_tensor_variable([], dtype="int64")
    assert res.type.dtype == "int64"

    # ...and re-types ndarray input as well.
    res = as_tensor_variable(np.array([1], dtype="int32"), dtype="int64")
    assert res.type.dtype == "int64"

    res = as_tensor_variable(np.array([1.0], dtype=config.floatX), dtype="int64")
    # TODO: This cross-type conversion probably shouldn't be the default.
    assert res.type.dtype == "int64"

    x = as_tensor_variable(np.array([1.0, 2.0], dtype="float64"))
    # This shouldn't convert the dtype, because it's already a `Variable`
    # with a set dtype
    res = as_tensor_variable(x, dtype="int64")
    assert res.type.dtype == "float64"
def test_memmap(self):
inp = np.random.rand(4, 3)
_, fname = mkstemp()
......
Markdown is supported
0%
You are mentioning 0 people in this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment