提交 36df3798 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Ricardo Vieira

fix(numba): Correlty report the elemwise output type

上级 849c3b86
...@@ -604,12 +604,16 @@ def _vectorized( ...@@ -604,12 +604,16 @@ def _vectorized(
builder, sig.return_type, [out._getvalue() for out in outputs] builder, sig.return_type, [out._getvalue() for out in outputs]
) )
ret_type = types.Tuple( ret_types = [
[ types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C") for dtype in output_dtypes
for dtype in output_dtypes ]
]
) for output_idx, input_idx in inplace_pattern:
ret_types[output_idx] = input_types[input_idx]
ret_type = types.Tuple(ret_types)
if len(output_dtypes) == 1: if len(output_dtypes) == 1:
ret_type = ret_type.types[0] ret_type = ret_type.types[0]
sig = ret_type(*arg_types) sig = ret_type(*arg_types)
......
...@@ -605,3 +605,18 @@ def test_fused_elemwise_benchmark(benchmark): ...@@ -605,3 +605,18 @@ def test_fused_elemwise_benchmark(benchmark):
# JIT compile first # JIT compile first
func() func()
benchmark(func) benchmark(func)
def test_elemwise_out_type():
# Create a graph with an elemwise
# Ravel failes if the elemwise output type is reported incorrectly
x = at.matrix()
y = (2 * x).ravel()
# Pass in the input as mutable, to trigger the inplace rewrites
func = pytensor.function([pytensor.In(x, mutable=True)], y, mode="NUMBA")
# Apply it to a numpy array that is neither C or F contigous
x_val = np.broadcast_to(np.zeros((3,)), (6, 3))
assert func(x_val).shape == (18,)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论