提交 36df3798 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Ricardo Vieira

fix(numba): Correlty report the elemwise output type

上级 849c3b86
......@@ -604,12 +604,16 @@ def _vectorized(
builder, sig.return_type, [out._getvalue() for out in outputs]
)
ret_type = types.Tuple(
[
types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
for dtype in output_dtypes
]
)
ret_types = [
types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
for dtype in output_dtypes
]
for output_idx, input_idx in inplace_pattern:
ret_types[output_idx] = input_types[input_idx]
ret_type = types.Tuple(ret_types)
if len(output_dtypes) == 1:
ret_type = ret_type.types[0]
sig = ret_type(*arg_types)
......
......@@ -605,3 +605,18 @@ def test_fused_elemwise_benchmark(benchmark):
# JIT compile first
func()
benchmark(func)
def test_elemwise_out_type():
# Create a graph with an elemwise
# Ravel failes if the elemwise output type is reported incorrectly
x = at.matrix()
y = (2 * x).ravel()
# Pass in the input as mutable, to trigger the inplace rewrites
func = pytensor.function([pytensor.In(x, mutable=True)], y, mode="NUMBA")
# Apply it to a numpy array that is neither C or F contigous
x_val = np.broadcast_to(np.zeros((3,)), (6, 3))
assert func(x_val).shape == (18,)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论