提交 9e46e6b4 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Handle constants in XTensorType.filter_variable

上级 cc4b77a0
......@@ -103,7 +103,7 @@ class XTensorType(Type, HasDataType, HasShape):
if not isinstance(other, Variable):
# The value is not a Variable: we cast it into
# a Constant of the appropriate Type.
other = xtensor_constant(other)
other = XTensorConstant(type=self, data=other)
if self.is_super(other.type):
return other
......
......@@ -10,7 +10,7 @@ from xarray import DataArray
from pytensor.graph.basic import equal_computations
from pytensor.tensor import as_tensor, specify_shape, tensor
from pytensor.xtensor import xtensor
from pytensor.xtensor.type import XTensorType, as_xtensor
from pytensor.xtensor.type import XTensorConstant, XTensorType, as_xtensor
def test_xtensortype():
......@@ -77,6 +77,37 @@ def test_xtensortype_filter_variable():
x.type.filter_variable(z4)
def test_xtensortype_filter_variable_constant():
x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float32")
valid_x = np.zeros((2, 3), dtype="float32")
res = x.type.filter_variable(valid_x)
assert isinstance(res, XTensorConstant) and res.type == x.type
# Upcasting allowed
valid_x = np.zeros((2, 3), dtype="float16")
res = x.type.filter_variable(valid_x)
assert isinstance(res, XTensorConstant) and res.type == x.type
valid_x = np.zeros((2, 3), dtype="int16")
res = x.type.filter_variable(valid_x)
assert isinstance(res, XTensorConstant) and res.type == x.type
# Downcasting not allowed
invalid_x = np.zeros((2, 3), dtype="float64")
with pytest.raises(TypeError):
x.type.filter_variable(invalid_x)
invalid_x = np.zeros((2, 3), dtype="int32")
with pytest.raises(TypeError):
x.type.filter_variable(invalid_x)
# non_array types are fine
valid_x = [[0, 0, 0], [0, 0, 0]]
res = x.type.filter_variable(valid_x)
assert isinstance(res, XTensorConstant) and res.type == x.type
def test_xtensor_constant():
x = as_xtensor(DataArray(np.ones((2, 3)), dims=("a", "b")))
assert x.type == XTensorType(dtype="float64", dims=("a", "b"), shape=(2, 3))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论