提交 6228d023 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Allow NumPy dtype arguments in cast and tensor/scalar type constructors

上级 4572ae48
...@@ -332,8 +332,11 @@ class Scalar(CType): ...@@ -332,8 +332,11 @@ class Scalar(CType):
ndim = 0 ndim = 0
def __init__(self, dtype): def __init__(self, dtype):
if dtype == "floatX": if isinstance(dtype, str) and dtype == "floatX":
dtype = config.floatX dtype = config.floatX
else:
dtype = np.dtype(dtype).name
self.dtype = dtype self.dtype = dtype
self.dtype_specs() # error checking self.dtype_specs() # error checking
......
...@@ -923,18 +923,21 @@ _cast_mapping = { ...@@ -923,18 +923,21 @@ _cast_mapping = {
def cast(x, dtype): def cast(x, dtype):
"""Symbolically cast `x` to a Tensor of type `dtype`.""" """Symbolically cast `x` to a Tensor of type `dtype`."""
if dtype == "floatX":
if isinstance(dtype, str) and dtype == "floatX":
dtype = config.floatX dtype = config.floatX
dtype_name = np.dtype(dtype).name
_x = as_tensor_variable(x) _x = as_tensor_variable(x)
if _x.type.dtype == dtype: if _x.type.dtype == dtype_name:
return _x return _x
if _x.type.dtype.startswith("complex") and not dtype.startswith("complex"): if _x.type.dtype.startswith("complex") and not dtype_name.startswith("complex"):
raise TypeError( raise TypeError(
"Casting from complex to real is ambiguous: consider real(), " "Casting from complex to real is ambiguous: consider real(), "
"imag(), angle() or abs()" "imag(), angle() or abs()"
) )
return _cast_mapping[dtype](x) return _cast_mapping[dtype_name](x)
########################## ##########################
......
...@@ -75,9 +75,11 @@ class TensorType(CType): ...@@ -75,9 +75,11 @@ class TensorType(CType):
""" """
def __init__(self, dtype, broadcastable, name=None): def __init__(self, dtype, broadcastable, name=None):
self.dtype = str(dtype) if isinstance(dtype, str) and dtype == "floatX":
if self.dtype == "floatX":
self.dtype = config.floatX self.dtype = config.floatX
else:
self.dtype = np.dtype(dtype).name
# broadcastable is immutable, and all elements are either # broadcastable is immutable, and all elements are either
# True or False # True or False
self.broadcastable = tuple(bool(b) for b in broadcastable) self.broadcastable = tuple(bool(b) for b in broadcastable)
......
...@@ -13,6 +13,11 @@ from aesara.scalar.basic import ( ...@@ -13,6 +13,11 @@ from aesara.scalar.basic import (
) )
def test_numpy_dtype():
test_type = Scalar(np.int32)
assert test_type.dtype == "int32"
def test_div_types(): def test_div_types():
a = int8() a = int8()
b = int32() b = int32()
......
...@@ -851,6 +851,12 @@ def test_identity(): ...@@ -851,6 +851,12 @@ def test_identity():
class TestCast: class TestCast:
def test_can_use_numpy_types(self):
x = vector(dtype=np.int32)
y = cast(x, np.int64)
f = function([x], y)
assert f(np.array([1, 2], dtype=np.int32)).dtype == np.int64
def test_good_between_real_types(self): def test_good_between_real_types(self):
good = itertools.chain( good = itertools.chain(
multi_dtype_cast_checks((2,), dtypes=REAL_DTYPES), multi_dtype_cast_checks((2,), dtypes=REAL_DTYPES),
......
...@@ -8,6 +8,11 @@ from aesara.configdefaults import config ...@@ -8,6 +8,11 @@ from aesara.configdefaults import config
from aesara.tensor.type import TensorType from aesara.tensor.type import TensorType
def test_numpy_dtype():
test_type = TensorType(np.int32, [])
assert test_type.dtype == "int32"
def test_filter_variable(): def test_filter_variable():
test_type = TensorType(config.floatX, []) test_type = TensorType(config.floatX, [])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论