提交 4ce8ef3c authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Fix xtensor broadcast dtypes

上级 9f98757c
...@@ -519,7 +519,6 @@ class Broadcast(XOp): ...@@ -519,7 +519,6 @@ class Broadcast(XOp):
broadcast_dims = tuple(dims_and_shape.keys()) broadcast_dims = tuple(dims_and_shape.keys())
broadcast_shape = tuple(dims_and_shape.values()) broadcast_shape = tuple(dims_and_shape.values())
dtype = upcast(*[x.type.dtype for x in inputs])
outputs = [] outputs = []
for x in inputs: for x in inputs:
...@@ -530,7 +529,7 @@ class Broadcast(XOp): ...@@ -530,7 +529,7 @@ class Broadcast(XOp):
excluded_shape = tuple(x_shape[x_dims.index(d)] for d in excluded_dims) excluded_shape = tuple(x_shape[x_dims.index(d)] for d in excluded_dims)
output = xtensor( output = xtensor(
dtype=dtype, dtype=x.type.dtype,
shape=broadcast_shape + excluded_shape, shape=broadcast_shape + excluded_shape,
dims=broadcast_dims + excluded_dims, dims=broadcast_dims + excluded_dims,
) )
......
...@@ -635,6 +635,27 @@ class TestBroadcast: ...@@ -635,6 +635,27 @@ class TestBroadcast:
for res, expected_res in zip(results, expected_results, strict=True): for res, expected_res in zip(results, expected_results, strict=True):
xr_assert_allclose(res, expected_res) xr_assert_allclose(res, expected_res)
def test_mixed_dtypes(self):
x = xtensor("x", dims=("a", "b"), shape=(3, 4), dtype="float64")
y = xtensor("y", dims=("c", "d"), shape=(5, 6), dtype="int64")
z = xtensor("z", dims=("b", "d"), shape=(4, 6), dtype="int32")
x_bcast, y_bcast, z_bcast = broadcast(x, y, z)
assert x_bcast.dtype == x.dtype
assert y_bcast.dtype == y.dtype
assert z_bcast.dtype == z.dtype
fn = xr_function([x, y, z], [x_bcast, y_bcast, z_bcast])
x_test = xr_arange_like(x)
y_test = xr_arange_like(y)
z_test = xr_arange_like(z)
results = fn(x_test, y_test, z_test)
expected_results = xr_broadcast(x_test, y_test, z_test)
for res, exp_res in zip(results, expected_results, strict=True):
assert res.dtype == exp_res.dtype
xr_assert_allclose(res, exp_res)
def test_full_like(): def test_full_like():
"""Test full_like function, comparing with xarray's full_like.""" """Test full_like function, comparing with xarray's full_like."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论