提交 6228d023 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Allow NumPy dtype arguments in cast and tensor/scalar type constructors

上级 4572ae48
......@@ -332,8 +332,11 @@ class Scalar(CType):
ndim = 0
def __init__(self, dtype):
if dtype == "floatX":
if isinstance(dtype, str) and dtype == "floatX":
dtype = config.floatX
else:
dtype = np.dtype(dtype).name
self.dtype = dtype
self.dtype_specs() # error checking
......
......@@ -923,18 +923,21 @@ _cast_mapping = {
def cast(x, dtype):
"""Symbolically cast `x` to a Tensor of type `dtype`."""
if dtype == "floatX":
if isinstance(dtype, str) and dtype == "floatX":
dtype = config.floatX
dtype_name = np.dtype(dtype).name
_x = as_tensor_variable(x)
if _x.type.dtype == dtype:
if _x.type.dtype == dtype_name:
return _x
if _x.type.dtype.startswith("complex") and not dtype.startswith("complex"):
if _x.type.dtype.startswith("complex") and not dtype_name.startswith("complex"):
raise TypeError(
"Casting from complex to real is ambiguous: consider real(), "
"imag(), angle() or abs()"
)
return _cast_mapping[dtype](x)
return _cast_mapping[dtype_name](x)
##########################
......
......@@ -75,9 +75,11 @@ class TensorType(CType):
"""
def __init__(self, dtype, broadcastable, name=None):
self.dtype = str(dtype)
if self.dtype == "floatX":
if isinstance(dtype, str) and dtype == "floatX":
self.dtype = config.floatX
else:
self.dtype = np.dtype(dtype).name
# broadcastable is immutable, and all elements are either
# True or False
self.broadcastable = tuple(bool(b) for b in broadcastable)
......
......@@ -13,6 +13,11 @@ from aesara.scalar.basic import (
)
def test_numpy_dtype():
test_type = Scalar(np.int32)
assert test_type.dtype == "int32"
def test_div_types():
a = int8()
b = int32()
......
......@@ -851,6 +851,12 @@ def test_identity():
class TestCast:
def test_can_use_numpy_types(self):
x = vector(dtype=np.int32)
y = cast(x, np.int64)
f = function([x], y)
assert f(np.array([1, 2], dtype=np.int32)).dtype == np.int64
def test_good_between_real_types(self):
good = itertools.chain(
multi_dtype_cast_checks((2,), dtypes=REAL_DTYPES),
......
......@@ -8,6 +8,11 @@ from aesara.configdefaults import config
from aesara.tensor.type import TensorType
def test_numpy_dtype():
test_type = TensorType(np.int32, [])
assert test_type.dtype == "int32"
def test_filter_variable():
test_type = TensorType(config.floatX, [])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论